added list params utility function

This commit is contained in:
Marcell Vazquez-Chanlatte 2017-10-30 14:31:23 -07:00
parent 73301b801f
commit c18cd58cb5
2 changed files with 19 additions and 5 deletions

View file

@ -7,7 +7,7 @@ from functools import singledispatch
import funcy as fn
import stl
import stl.ast
from lenses import lens
from lenses import bind
oo = float('inf')
@ -129,7 +129,7 @@ def eval_stl_lineq(lineq):
def eval_terms(lineq, x, t):
terms = lens(lineq).terms.each_().get_all()
terms = bind(lineq).terms.Each().collect()
return sum(eval_term(term, x, t) for term in terms)

View file

@ -10,7 +10,7 @@ import lenses
import stl.ast
from lenses import lens
from stl.ast import (AST, And, F, G, Interval, LinEq, NaryOpSTL,
Neg, Or, Param)
Neg, Or, Param, ModalOp)
from stl.types import STL, STL_Generator
Lens = TypeVar('Lens')
@ -25,6 +25,19 @@ def walk(phi: STL) -> STL_Generator:
yield node
children.extend(node.children)
def list_params(phi: STL):
"""Walk of the AST."""
def get_params(leaf):
if isinstance(leaf, ModalOp):
if isinstance(leaf.interval[0], Param):
yield leaf.interval[0]
if isinstance(leaf.interval[1], Param):
yield leaf.interval[1]
elif isinstance(leaf, LinEq):
if isinstance(leaf.const, Param):
yield leaf.const
return set(fn.mapcat(get_params, walk(phi)))
def vars_in_phi(phi):
focus = stl.terms_lens(phi)
@ -81,7 +94,7 @@ def terms_lens(phi: STL, bind: bool = True) -> Lens:
return lineq_lens(phi, bind).Each().terms.Each()
def param_lens(phi: STL) -> Lens:
def param_lens(phi: STL, *, getter=False) -> Lens:
def focus_lens(leaf):
candidates = [lens.const] if isinstance(leaf, LinEq) else [
lens.GetAttr('interval')[0],
@ -89,7 +102,8 @@ def param_lens(phi: STL) -> Lens:
]
return (x for x in candidates if isinstance(x.get()(leaf), Param))
return ast_lens(phi, pred=type_pred(LinEq, F, G), focus_lens=focus_lens)
return ast_lens(phi, pred=type_pred(LinEq, F, G), focus_lens=focus_lens,
getter=getter)
def set_params(phi, val) -> STL: