This commit is contained in:
Marcell Vazquez-Chanlatte 2017-10-30 15:00:55 -07:00
parent c18cd58cb5
commit 2640728288

View file

@ -9,8 +9,8 @@ import traces
import lenses import lenses
import stl.ast import stl.ast
from lenses import lens from lenses import lens
from stl.ast import (AST, And, F, G, Interval, LinEq, NaryOpSTL, from stl.ast import (AST, And, F, G, Interval, LinEq, NaryOpSTL, Neg, Or,
Neg, Or, Param, ModalOp) Param, ModalOp)
from stl.types import STL, STL_Generator from stl.types import STL, STL_Generator
Lens = TypeVar('Lens') Lens = TypeVar('Lens')
@ -25,8 +25,10 @@ def walk(phi: STL) -> STL_Generator:
yield node yield node
children.extend(node.children) children.extend(node.children)
def list_params(phi: STL): def list_params(phi: STL):
"""Walk of the AST.""" """Walk of the AST."""
def get_params(leaf): def get_params(leaf):
if isinstance(leaf, ModalOp): if isinstance(leaf, ModalOp):
if isinstance(leaf.interval[0], Param): if isinstance(leaf.interval[0], Param):
@ -36,6 +38,7 @@ def list_params(phi: STL):
elif isinstance(leaf, LinEq): elif isinstance(leaf, LinEq):
if isinstance(leaf.const, Param): if isinstance(leaf.const, Param):
yield leaf.const yield leaf.const
return set(fn.mapcat(get_params, walk(phi))) return set(fn.mapcat(get_params, walk(phi)))
@ -52,10 +55,12 @@ def type_pred(*args: List[Type]) -> Mapping[Type, bool]:
def ast_lens(phi: STL, bind=True, *, pred=None, focus_lens=None, def ast_lens(phi: STL, bind=True, *, pred=None, focus_lens=None,
getter=False) -> Lens: getter=False) -> Lens:
if focus_lens is None: if focus_lens is None:
def focus_lens(_): def focus_lens(_):
return [lens] return [lens]
if pred is None: if pred is None:
def pred(_): def pred(_):
return False return False
@ -102,8 +107,8 @@ def param_lens(phi: STL, *, getter=False) -> Lens:
] ]
return (x for x in candidates if isinstance(x.get()(leaf), Param)) return (x for x in candidates if isinstance(x.get()(leaf), Param))
return ast_lens(phi, pred=type_pred(LinEq, F, G), focus_lens=focus_lens, return ast_lens(
getter=getter) phi, pred=type_pred(LinEq, F, G), focus_lens=focus_lens, getter=getter)
def set_params(phi, val) -> STL: def set_params(phi, val) -> STL: