Fixed PSTL construction

This commit is contained in:
Marcell Vazquez-Chanlatte 2017-10-26 16:14:45 -07:00
parent d5985406ad
commit eda63fd6f0
3 changed files with 38 additions and 30 deletions

View file

@ -11,7 +11,7 @@ import traces
import stl.ast
from stl.ast import (LinEq, And, Or, NaryOpSTL, F, G, Interval, Neg,
AtomicPred)
AtomicPred, Param, AST)
from stl.types import STL, STL_Generator, MTL
Lens = TypeVar('Lens')
@ -43,7 +43,7 @@ def ast_lens(phi: STL, bind=True, *, pred=None, focus_lens=None) -> Lens:
if pred is None:
pred = lambda _: False
l = lenses.bind(phi) if bind else lens
return l.Tuple(*_ast_lens(phi, pred=pred, focus_lens=focus_lens))
return l.Fork(*_ast_lens(phi, pred=pred, focus_lens=focus_lens))
def _ast_lens(phi: STL, pred, focus_lens) -> Lens:
@ -77,22 +77,20 @@ def terms_lens(phi: STL, bind: bool = True) -> Lens:
def param_lens(phi: STL) -> Lens:
is_sym = lambda x: isinstance(x, sympy.Symbol)
def focus_lens(leaf):
return [lens.const] if isinstance(leaf, LinEq) else [
candidates = [lens.const] if isinstance(leaf, LinEq) else [
lens.GetAttr('interval')[0],
lens.GetAttr('interval')[1]
]
return (x for x in candidates if isinstance(x.get()(leaf), Param))
return ast_lens(
phi, pred=type_pred(LinEq, F, G),
focus_lens=focus_lens).filter_(is_sym)
focus_lens=focus_lens)
def set_params(stl_or_lens, val) -> STL:
l = stl_or_lens if isinstance(stl_or_lens,
Lens) else param_lens(stl_or_lens)
def set_params(phi, val) -> STL:
l = param_lens(phi) if isinstance(phi, AST) else phi
return l.modify(lambda x: float(val.get(x, val.get(str(x), x))))