This commit is contained in:
Marcell Vazquez-Chanlatte 2017-10-30 15:00:55 -07:00
parent c18cd58cb5
commit 2640728288

View file

@ -9,8 +9,8 @@ import traces
import lenses
import stl.ast
from lenses import lens
from stl.ast import (AST, And, F, G, Interval, LinEq, NaryOpSTL,
Neg, Or, Param, ModalOp)
from stl.ast import (AST, And, F, G, Interval, LinEq, NaryOpSTL, Neg, Or,
Param, ModalOp)
from stl.types import STL, STL_Generator
Lens = TypeVar('Lens')
@ -25,8 +25,10 @@ def walk(phi: STL) -> STL_Generator:
yield node
children.extend(node.children)
def list_params(phi: STL):
"""Walk of the AST."""
def get_params(leaf):
if isinstance(leaf, ModalOp):
if isinstance(leaf.interval[0], Param):
@ -36,6 +38,7 @@ def list_params(phi: STL):
elif isinstance(leaf, LinEq):
if isinstance(leaf.const, Param):
yield leaf.const
return set(fn.mapcat(get_params, walk(phi)))
@ -52,10 +55,12 @@ def type_pred(*args: List[Type]) -> Mapping[Type, bool]:
def ast_lens(phi: STL, bind=True, *, pred=None, focus_lens=None,
getter=False) -> Lens:
if focus_lens is None:
def focus_lens(_):
return [lens]
if pred is None:
def pred(_):
return False
@ -102,8 +107,8 @@ def param_lens(phi: STL, *, getter=False) -> Lens:
]
return (x for x in candidates if isinstance(x.get()(leaf), Param))
return ast_lens(phi, pred=type_pred(LinEq, F, G), focus_lens=focus_lens,
getter=getter)
return ast_lens(
phi, pred=type_pred(LinEq, F, G), focus_lens=focus_lens, getter=getter)
def set_params(phi, val) -> STL: