diff --git a/stl/boolean_eval.py b/stl/boolean_eval.py index 6693d56..0ea707b 100644 --- a/stl/boolean_eval.py +++ b/stl/boolean_eval.py @@ -7,7 +7,7 @@ from functools import singledispatch import funcy as fn import stl import stl.ast -from lenses import lens +from lenses import bind oo = float('inf') @@ -129,7 +129,7 @@ def eval_stl_lineq(lineq): def eval_terms(lineq, x, t): - terms = lens(lineq).terms.each_().get_all() + terms = bind(lineq).terms.Each().collect() return sum(eval_term(term, x, t) for term in terms) diff --git a/stl/utils.py b/stl/utils.py index a9ea193..93e06a9 100644 --- a/stl/utils.py +++ b/stl/utils.py @@ -10,7 +10,7 @@ import lenses import stl.ast from lenses import lens from stl.ast import (AST, And, F, G, Interval, LinEq, NaryOpSTL, - Neg, Or, Param) + Neg, Or, Param, ModalOp) from stl.types import STL, STL_Generator Lens = TypeVar('Lens') @@ -25,6 +25,19 @@ def walk(phi: STL) -> STL_Generator: yield node children.extend(node.children) +def list_params(phi: STL): + """Walk of the AST.""" + def get_params(leaf): + if isinstance(leaf, ModalOp): + if isinstance(leaf.interval[0], Param): + yield leaf.interval[0] + if isinstance(leaf.interval[1], Param): + yield leaf.interval[1] + elif isinstance(leaf, LinEq): + if isinstance(leaf.const, Param): + yield leaf.const + return set(fn.mapcat(get_params, walk(phi))) + def vars_in_phi(phi): focus = stl.terms_lens(phi) @@ -81,7 +94,7 @@ def terms_lens(phi: STL, bind: bool = True) -> Lens: return lineq_lens(phi, bind).Each().terms.Each() -def param_lens(phi: STL) -> Lens: +def param_lens(phi: STL, *, getter=False) -> Lens: def focus_lens(leaf): candidates = [lens.const] if isinstance(leaf, LinEq) else [ lens.GetAttr('interval')[0], @@ -89,7 +102,8 @@ def param_lens(phi: STL) -> Lens: ] return (x for x in candidates if isinstance(x.get()(leaf), Param)) - return ast_lens(phi, pred=type_pred(LinEq, F, G), focus_lens=focus_lens) + return ast_lens(phi, pred=type_pred(LinEq, F, G), focus_lens=focus_lens, + getter=getter) def set_params(phi, val) -> STL: