diff --git a/stl/ast.py b/stl/ast.py index 4e0d3e4..da0537e 100644 --- a/stl/ast.py +++ b/stl/ast.py @@ -11,6 +11,7 @@ from sympy import Symbol dt_sym = Symbol('dt', positive=True) t_sym = Symbol('t', positive=True) + def flatten_binary(phi, op, dropT, shortT): f = lambda x: x.args if isinstance(x, op) else [x] args = [arg for arg in phi.args if arg is not dropT] @@ -23,7 +24,7 @@ def flatten_binary(phi, op, dropT, shortT): return args[0] else: return op(tuple(fn.mapcat(f, phi.args))) - + class AST(object): __slots__ = () @@ -44,7 +45,7 @@ class AST(object): class _Top(AST): __slots__ = () - + def __repr__(self): return "⊤" @@ -61,6 +62,7 @@ class _Bot(AST): def __invert__(self): return TOP + TOP = _Top() BOT = _Bot() @@ -70,7 +72,7 @@ class AtomicPred(namedtuple("AP", ["id"]), AST): def __repr__(self): return f"{self.id}" - + @property def children(self): return set() @@ -81,7 +83,7 @@ class LinEq(namedtuple("LinEquality", ["terms", "op", "const"]), AST): def __repr__(self): return " + ".join(map(str, self.terms)) + f" {self.op} {self.const}" - + @property def children(self): return set() @@ -103,7 +105,7 @@ class Interval(namedtuple('I', ['lower', 'upper'])): def __repr__(self): return f"[{self.lower},{self.upper}]" - + @property def children(self): return {self.lower, self.upper} @@ -113,9 +115,10 @@ class NaryOpSTL(namedtuple('NaryOp', ['args']), AST): __slots__ = () OP = "?" + def __repr__(self): return f" {self.OP} ".join(f"({x})" for x in self.args) - + @property def children(self): return set(self.args) @@ -125,10 +128,12 @@ class Or(NaryOpSTL): __slots__ = () OP = "∨" + def __hash__(self): # TODO: compute hash based on contents return hash(repr(self)) + class And(NaryOpSTL): __slots__ = () @@ -144,7 +149,7 @@ class ModalOp(namedtuple('ModalOp', ['interval', 'arg']), AST): def __repr__(self): return f"{self.OP}{self.interval}({self.arg})" - + @property def children(self): return {self.arg} @@ -158,6 +163,7 @@ class F(ModalOp): # TODO: compute hash based on contents return hash(repr(self)) + class G(ModalOp): __slots__ = () OP = "□" @@ -172,7 +178,7 @@ class Until(namedtuple('ModalOp', ['arg1', 'arg2']), AST): def __repr__(self): return f"({self.arg1}) U ({self.arg2})" - + @property def children(self): return {self.arg1, self.arg2} @@ -187,7 +193,7 @@ class Neg(namedtuple('Neg', ['arg']), AST): def __repr__(self): return f"¬({self.arg})" - + @property def children(self): return {self.arg} @@ -202,7 +208,7 @@ class Next(namedtuple('Next', ['arg']), AST): def __repr__(self): return f"X({self.arg})" - + @property def children(self): return {self.arg} diff --git a/stl/parser.py b/stl/parser.py index 67cc30b..3689973 100644 --- a/stl/parser.py +++ b/stl/parser.py @@ -62,7 +62,8 @@ _ = ~r"\s"+ __ = ~r"\s"* EOL = "\\n" ''') - + + class STLVisitor(NodeVisitor): def __init__(self, H=float('inf')): super().__init__() @@ -78,7 +79,7 @@ class STLVisitor(NodeVisitor): visit_paren_phi = partialmethod(children_getter, i=2) def visit_interval(self, _, children): - _, _, (left,), _, _, _, (right,), _, _ = children + _, _, (left, ), _, _, _, (right, ), _, _ = children left = left if left != [] else float("inf") right = right if right != [] else float("inf") if isinstance(left, int): @@ -92,7 +93,7 @@ class STLVisitor(NodeVisitor): def visit_unbound(self, node, _): return Symbol(node.text) - + visit_op = get_text def unary_temp_op_visitor(self, _, children, op): @@ -101,13 +102,13 @@ class STLVisitor(NodeVisitor): return op(i, phi) def binop_visitor(self, _, children, op): - phi1, _, _, _, (phi2,) = children + phi1, _, _, _, (phi2, ) = children argL = list(phi1.args) if isinstance(phi1, op) else [phi1] argR = list(phi2.args) if isinstance(phi2, op) else [phi2] return op(tuple(argL + argR)) def sugar_binop_visitor(self, _, children, op): - phi1, _, _, _, (phi2,) = children + phi1, _, _, _, (phi2, ) = children return op(phi1, phi2) visit_f = partialmethod(unary_temp_op_visitor, op=ast.F) @@ -136,7 +137,7 @@ class STLVisitor(NodeVisitor): coeffs, (iden, time) = children c = coeffs[0] if coeffs else Number(1) return ast.Var(coeff=c, id=iden, time=time) - + def visit_coeff(self, _, children): dt, coeff, *_ = children[0] if not isinstance(dt, Symbol): @@ -147,7 +148,7 @@ class STLVisitor(NodeVisitor): def visit_terms(self, _, children): if isinstance(children[0], list): - term, _1, sgn ,_2, terms = children[0] + term, _1, sgn, _2, terms = children[0] terms = lens(terms)[0].coeff * sgn return [term] + terms else: @@ -170,5 +171,5 @@ class STLVisitor(NodeVisitor): return ast.Next(children[1]) -def parse(stl_str:str, rule:str="phi", H=float('inf')) -> "STL": +def parse(stl_str: str, rule: str = "phi", H=float('inf')) -> "STL": return STLVisitor(H).visit(STL_GRAMMAR[rule].parse(stl_str)) diff --git a/stl/utils.py b/stl/utils.py index fda3f22..41d1ce7 100644 --- a/stl/utils.py +++ b/stl/utils.py @@ -16,7 +16,8 @@ from stl.types import STL, STL_Generator, MTL Lens = TypeVar('Lens') -def walk(phi:STL) -> STL_Generator: + +def walk(phi: STL) -> STL_Generator: """Walk of the AST.""" pop = deque.pop children = deque([phi]) @@ -25,16 +26,18 @@ def walk(phi:STL) -> STL_Generator: yield node children.extend(node.children) + def vars_in_phi(phi): focus = stl.terms_lens(phi) return set(focus.tuple_(lens.id, lens.time).get_all()) -def type_pred(*args:List[Type]) -> Mapping[Type, bool]: + +def type_pred(*args: List[Type]) -> Mapping[Type, bool]: ast_types = set(args) return lambda x: type(x) in ast_types -def ast_lens(phi:STL, bind=True, *, pred=None, focus_lens=None) -> Lens: +def ast_lens(phi: STL, bind=True, *, pred=None, focus_lens=None) -> Lens: if focus_lens is None: focus_lens = lambda _: [lens] if pred is None: @@ -42,10 +45,11 @@ def ast_lens(phi:STL, bind=True, *, pred=None, focus_lens=None) -> Lens: l = lenses.bind(phi) if bind else lens return l.Tuple(*_ast_lens(phi, pred=pred, focus_lens=focus_lens)) -def _ast_lens(phi:STL, pred, focus_lens) -> Lens: + +def _ast_lens(phi: STL, pred, focus_lens) -> Lens: if pred(phi): yield from focus_lens(phi) - + if phi is None or not phi.children: return @@ -54,36 +58,45 @@ def _ast_lens(phi:STL, pred, focus_lens) -> Lens: elif isinstance(phi, stl.ast.Until): child_lenses = [lens.GetAttr('arg1'), lens.GetAttr('arg2')] elif isinstance(phi, NaryOpSTL): - child_lenses = [lens.GetAttr('args')[j] for j, _ in enumerate(phi.args)] + child_lenses = [ + lens.GetAttr('args')[j] for j, _ in enumerate(phi.args) + ] else: child_lenses = [lens.GetAttr('arg')] for l in child_lenses: yield from [l & cl for cl in _ast_lens(l.get()(phi), pred, focus_lens)] - + lineq_lens = fn.partial(ast_lens, pred=type_pred(LinEq)) AP_lens = fn.partial(ast_lens, pred=type_pred(stl.ast.AtomicPred)) and_or_lens = fn.partial(ast_lens, pred=type_pred(And, Or)) -def terms_lens(phi:STL, bind:bool=True) -> Lens: - return lineq_lens(phi, bind).terms.each_() + +def terms_lens(phi: STL, bind: bool = True) -> Lens: + return lineq_lens(phi, bind).Each().terms.Each() -def param_lens(phi:STL) -> Lens: +def param_lens(phi: STL) -> Lens: is_sym = lambda x: isinstance(x, sympy.Symbol) - def focus_lens(leaf): - return [lens.const] if isinstance(leaf, LinEq) else [lens().interval[0], lens().interval[1]] - return ast_lens(phi, pred=type_pred(LinEq, F, G), - focus_lens=focus_lens).filter_(is_sym) + def focus_lens(leaf): + return [lens.const] if isinstance(leaf, LinEq) else [ + lens.GetAttr('interval')[0], + lens.GetAttr('interval')[1] + ] + + return ast_lens( + phi, pred=type_pred(LinEq, F, G), + focus_lens=focus_lens).filter_(is_sym) def set_params(stl_or_lens, val) -> STL: - l = stl_or_lens if isinstance(stl_or_lens, Lens) else param_lens(stl_or_lens) + l = stl_or_lens if isinstance(stl_or_lens, + Lens) else param_lens(stl_or_lens) return l.modify(lambda x: float(val.get(x, val.get(str(x), x)))) -def f_neg_or_canonical_form(phi:STL) -> STL: +def f_neg_or_canonical_form(phi: STL) -> STL: if isinstance(phi, LinEq): return phi @@ -106,24 +119,14 @@ def f_neg_or_canonical_form(phi:STL) -> STL: raise NotImplementedError -def to_mtl(phi:STL) -> MTL: - focus = lineq_lens(phi) - to_ap = lambda i: stl.ast.AtomicPred("AP{}".format(i)) - ap_map = {to_ap(i): leq for i, leq in enumerate(focus.get_all())} - lineq_map = {v:k for k,v in ap_map.items()} - return focus.modify(lineq_map.get), ap_map - - -def from_mtl(phi:MTL, ap_map:Dict[AtomicPred, LinEq]) -> STL: - focus = AP_lens(phi) - return focus.modify(ap_map.get) - def _lineq_lipschitz(lineq): - return sum(map(abs, lens(lineq).terms.each_().coeff.get_all())) + return sum(map(abs, lens(lineq).Each().terms.Each().coeff.collect())) + def linear_stl_lipschitz(phi): """Infinity norm lipschitz bound of linear inequality predicate.""" - return float(max(map(_lineq_lipschitz, lineq_lens(phi).get_all()))) + return float(max(map(_lineq_lipschitz, lineq_lens(phi).Each().collect()))) + def inline_context(phi, context): phi2 = None @@ -133,6 +136,7 @@ def inline_context(phi, context): # TODO: this is hack to flatten the AST. Fix! return stl.parse(str(phi)) + op_lookup = { ">": op.gt, ">=": op.ge, @@ -141,6 +145,7 @@ op_lookup = { "=": op.eq, } + def get_times(x): times = set.union(*({t for t, _ in v.items()} for v in x.values())) return sorted(times) @@ -151,18 +156,19 @@ def eval_lineq(lineq, x, times=None, compact=True): times = get_times(x) def eval_term(term, t): - return float(term.coeff)*x[term.id.name][t] + return float(term.coeff) * x[term.id.name][t] output = traces.TimeSeries(domain=traces.Domain(times[0], times[-1])) terms = lens(lineq).Each().terms.Each().collect() for t in times: lhs = sum(eval_term(term, t) for term in terms) output[t] = op_lookup[lineq.op](lhs, lineq.const) - + if compact: output.compact() return output + def eval_lineqs(phi, x, times=None): if times is None: times = get_times(x) @@ -172,26 +178,34 @@ def eval_lineqs(phi, x, times=None): # EDSL + def alw(phi, *, lo, hi): return G(Interval(lo, hi), phi) + def env(phi, *, lo, hi): return F(Interval(lo, hi), phi) + def until(phi1, phi2, *, lo, hi): return stl.ast.Until(Interval(lo, hi), phi1, phi2) + def andf(*args): return reduce(op.and_, args) if args else stl.TOP + def orf(*args): return reduce(op.or_, args) if args else stl.TOP + def implies(x, y): return ~x | y + def xor(x, y): return (x | y) & ~(x & y) + def iff(x, y): return (x & y) | (~x & ~y)