added list params utility function

This commit is contained in:
Marcell Vazquez-Chanlatte 2017-10-30 14:31:23 -07:00
parent 73301b801f
commit c18cd58cb5
2 changed files with 19 additions and 5 deletions

View file

@ -7,7 +7,7 @@ from functools import singledispatch
import funcy as fn import funcy as fn
import stl import stl
import stl.ast import stl.ast
from lenses import lens from lenses import bind
oo = float('inf') oo = float('inf')
@ -129,7 +129,7 @@ def eval_stl_lineq(lineq):
def eval_terms(lineq, x, t): def eval_terms(lineq, x, t):
terms = lens(lineq).terms.each_().get_all() terms = bind(lineq).terms.Each().collect()
return sum(eval_term(term, x, t) for term in terms) return sum(eval_term(term, x, t) for term in terms)

View file

@ -10,7 +10,7 @@ import lenses
import stl.ast import stl.ast
from lenses import lens from lenses import lens
from stl.ast import (AST, And, F, G, Interval, LinEq, NaryOpSTL, from stl.ast import (AST, And, F, G, Interval, LinEq, NaryOpSTL,
Neg, Or, Param) Neg, Or, Param, ModalOp)
from stl.types import STL, STL_Generator from stl.types import STL, STL_Generator
Lens = TypeVar('Lens') Lens = TypeVar('Lens')
@ -25,6 +25,19 @@ def walk(phi: STL) -> STL_Generator:
yield node yield node
children.extend(node.children) children.extend(node.children)
def list_params(phi: STL):
"""Walk of the AST."""
def get_params(leaf):
if isinstance(leaf, ModalOp):
if isinstance(leaf.interval[0], Param):
yield leaf.interval[0]
if isinstance(leaf.interval[1], Param):
yield leaf.interval[1]
elif isinstance(leaf, LinEq):
if isinstance(leaf.const, Param):
yield leaf.const
return set(fn.mapcat(get_params, walk(phi)))
def vars_in_phi(phi): def vars_in_phi(phi):
focus = stl.terms_lens(phi) focus = stl.terms_lens(phi)
@ -81,7 +94,7 @@ def terms_lens(phi: STL, bind: bool = True) -> Lens:
return lineq_lens(phi, bind).Each().terms.Each() return lineq_lens(phi, bind).Each().terms.Each()
def param_lens(phi: STL) -> Lens: def param_lens(phi: STL, *, getter=False) -> Lens:
def focus_lens(leaf): def focus_lens(leaf):
candidates = [lens.const] if isinstance(leaf, LinEq) else [ candidates = [lens.const] if isinstance(leaf, LinEq) else [
lens.GetAttr('interval')[0], lens.GetAttr('interval')[0],
@ -89,7 +102,8 @@ def param_lens(phi: STL) -> Lens:
] ]
return (x for x in candidates if isinstance(x.get()(leaf), Param)) return (x for x in candidates if isinstance(x.get()(leaf), Param))
return ast_lens(phi, pred=type_pred(LinEq, F, G), focus_lens=focus_lens) return ast_lens(phi, pred=type_pred(LinEq, F, G), focus_lens=focus_lens,
getter=getter)
def set_params(phi, val) -> STL: def set_params(phi, val) -> STL: