# -*- coding: utf-8 -*-
"""Parser for STL formulas.

A parsimonious grammar (``STL_GRAMMAR``) recognises the concrete syntax and
``STLVisitor`` turns the resulting parse tree into ``stl.ast`` objects.
``parse`` is the convenience entry point combining the two.
"""

# TODO: allow multiplication to be distributive
# TODO: support variables on both sides of ineq
# TODO: Allow -x = -1*x

from functools import partialmethod
from collections import namedtuple
import operator as op

from parsimonious import Grammar, NodeVisitor
from funcy import flatten
from lenses import lens
from sympy import Symbol, Number

from stl import ast
from stl.utils import implies, xor, iff, env, alw

# The grammar is a raw string so that regex escapes such as ``\d`` and ``\s``
# reach parsimonious untouched without relying on Python's handling of
# unrecognised escape sequences (a SyntaxWarning on recent interpreters).
#
# NOTE(review): the character class ``[a-zA-z\d]`` used by ``AP`` and ``id``
# spans ``A-z``, which also matches the ASCII characters between ``Z`` and
# ``a`` (``[``, ``\``, ``]``, ``^``, ``_`` and the backtick).  It is left as is
# because existing formulas may rely on ``_`` in identifiers -- confirm the
# intended alphabet before tightening it.
STL_GRAMMAR = Grammar(r'''
phi = (timed_until / until / neg / next / g / f / lineq / AP / or / and
     / implies / xor / iff / paren_phi)

paren_phi = "(" __ phi __ ")"

or = paren_phi _ ("∨" / "or" / "|") _ (or / paren_phi)
and = paren_phi _ ("∧" / "and" / "&") _ (and / paren_phi)
implies = paren_phi _ ("→" / "->") _ (and / paren_phi)
iff = paren_phi _ ("⇔" / "<->" / "iff") _ (and / paren_phi)
xor = paren_phi _ ("⊕" / "^" / "xor") _ (and / paren_phi)

neg = ("~" / "¬") paren_phi
next = "X" paren_phi
f = F interval? phi
g = G interval? phi
until = paren_phi __ U __ paren_phi
timed_until = paren_phi __ U interval __ paren_phi

F = "F" / "◇"
G = "G" / "□"
U = "U"

interval = "[" __ const_or_unbound __ "," __ const_or_unbound __ "]"

const_or_unbound = unbound / "inf" / const

lineq = terms _ op _ const_or_unbound
term = coeff? var
coeff = ((dt __ "*" __)? const __ "*" __) / (dt __ "*")
terms = (term __ pm __ terms) / term

var = id
AP = ~r"[a-zA-z\d]+"

pm = "+" / "-"
dt = "dt"
unbound = id "?"
id = ~r"[a-zA-z\d]+"
const = ~r"[\+\-]?\d*(\.\d+)?"
op = ">=" / "<=" / "<" / ">" / "="
_ = ~r"\s"+
__ = ~r"\s"*
EOL = "\n"
''')


class STLVisitor(NodeVisitor):
    """Translate a parse tree produced by ``STL_GRAMMAR`` into ``stl.ast``.

    Each ``visit_<rule>`` method receives the parse node for ``<rule>`` and the
    already-visited results of its children.
    """

    def __init__(self, H=float('inf')):
        """Create a visitor.

        Args:
            H: Upper end of the interval given to ``F``/``G`` operators that
                are written without an explicit interval.
        """
        super().__init__()
        self.default_interval = ast.Interval(0.0, H)

    def generic_visit(self, _, children):
        """Fallback for rules without a dedicated visitor: pass children up."""
        return children

    def children_getter(self, _, children, i):
        """Return the ``i``-th visited child, discarding the rest."""
        return children[i]

    # ``phi`` wraps a single alternative; ``paren_phi`` is "(" __ phi __ ")".
    visit_phi = partialmethod(children_getter, i=0)
    visit_paren_phi = partialmethod(children_getter, i=2)

    def visit_interval(self, _, children):
        """Build an ``ast.Interval`` from ``[left, right]``.

        An end point that arrives as an empty list is replaced by
        ``float("inf")`` (presumably this is the literal ``"inf"``
        alternative, which has no visitor of its own).  Integer end points are
        converted to ``float``.
        """
        _, _, (left, ), _, _, _, (right, ), _, _ = children
        left = left if left != [] else float("inf")
        right = right if right != [] else float("inf")
        if isinstance(left, int):
            left = float(left)
        if isinstance(right, int):
            # Previously this assigned to ``left``, clobbering the lower
            # bound with the upper one.
            right = float(right)
        return ast.Interval(left, right)

    def get_text(self, node, _):
        """Return the raw source text matched by ``node``."""
        return node.text

    def visit_unbound(self, node, _):
        """Return a sympy ``Symbol`` named after the matched text."""
        return Symbol(node.text)

    visit_op = get_text

    def unary_temp_op_visitor(self, _, children, op):
        """Build ``op(interval, phi)`` for the ``f`` and ``g`` rules.

        When the optional interval is absent, ``self.default_interval`` is
        used instead.
        """
        _, i, phi = children
        i = self.default_interval if not i else i[0]
        return op(i, phi)

    def binop_visitor(self, _, children, op):
        """Build an n-ary ``op`` node from a binary rule.

        Operands that are already instances of ``op`` contribute their
        ``args`` directly, so nested applications of the same operator end up
        as one flat node.
        """
        phi1, _, _, _, (phi2, ) = children
        argL = list(phi1.args) if isinstance(phi1, op) else [phi1]
        argR = list(phi2.args) if isinstance(phi2, op) else [phi2]
        return op(tuple(argL + argR))

    def sugar_binop_visitor(self, _, children, op):
        """Apply the helper ``op`` (xor / iff / implies) to both operands."""
        phi1, _, _, _, (phi2, ) = children
        return op(phi1, phi2)

    visit_f = partialmethod(unary_temp_op_visitor, op=ast.F)
    visit_g = partialmethod(unary_temp_op_visitor, op=ast.G)
    visit_or = partialmethod(binop_visitor, op=ast.Or)
    visit_and = partialmethod(binop_visitor, op=ast.And)
    visit_xor = partialmethod(sugar_binop_visitor, op=xor)
    visit_iff = partialmethod(sugar_binop_visitor, op=iff)
    visit_implies = partialmethod(sugar_binop_visitor, op=implies)

    def visit_until(self, _, children):
        """Build ``ast.Until(phi1, phi2)`` for the untimed ``U`` operator."""
        phi1, _, _, _, phi2 = children
        return ast.Until(phi1, phi2)

    def visit_timed_until(self, _, children):
        """Rewrite ``phi U[lo, hi] psi`` in terms of ``env``, ``alw``, ``Until``.

        The result is ``env(psi, lo, hi) & alw(Until(phi, psi), 0, lo)``.
        """
        phi, _, _, (lo, hi), _, psi = children
        return env(psi, lo=lo, hi=hi) & alw(ast.Until(phi, psi), lo=0, hi=lo)

    def visit_id(self, name, _):
        """Return a sympy ``Symbol`` for an identifier."""
        return Symbol(name.text)

    def visit_const(self, const, children):
        """Return the matched constant as a ``float``.

        Raises:
            ValueError: If the matched text is not a valid float literal
                (the ``const`` pattern permits an empty match).
        """
        return float(const.text)

    def visit_term(self, _, children):
        """Build an ``ast.Var`` from an optional coefficient and a variable.

        A missing coefficient defaults to ``Number(1)``.
        """
        # NOTE(review): this unpacks the variable child into (id, time), but
        # the grammar's ``var = id`` shows no time component -- confirm.
        coeffs, (iden, time) = children
        c = coeffs[0] if coeffs else Number(1)
        return ast.Var(coeff=c, id=iden, time=time)

    def visit_coeff(self, _, children):
        """Combine the optional ``dt`` factor with the numeric coefficient."""
        dt, coeff, *_ = children[0]
        if isinstance(dt, Symbol):
            return dt
        dt = dt[0][0] if dt else Number(1)
        return dt * coeff

    def visit_terms(self, _, children):
        """Return the list of terms of a linear expression.

        For ``term pm terms`` the sign is multiplied into the coefficient of
        the first term of the remaining list before the leading term is
        prepended.
        """
        if isinstance(children[0], list):
            term, _1, sgn, _2, terms = children[0]
            terms = lens(terms)[0].coeff * sgn
            return [term] + terms
        else:
            return children

    def visit_lineq(self, _, children):
        """Build ``ast.LinEq(terms, op, const)`` for a linear (in)equality."""
        terms, _1, cmp_op, _2, const = children
        return ast.LinEq(tuple(terms), cmp_op, const[0])

    def visit_pm(self, node, _):
        """Map ``+`` to ``Number(1)`` and anything else to ``Number(-1)``."""
        return Number(1) if node.text == "+" else Number(-1)

    def visit_AP(self, *args):
        """Wrap the identifier symbol in an ``ast.AtomicPred``."""
        return ast.AtomicPred(self.visit_id(*args))

    def visit_neg(self, _, children):
        """Build ``ast.Neg`` of the parenthesised operand."""
        return ast.Neg(children[1])

    def visit_next(self, _, children):
        """Build ``ast.Next`` of the parenthesised operand."""
        return ast.Next(children[1])


def parse(stl_str: str, rule: str = "phi", H=float('inf')) -> "STL":
    """Parse ``stl_str`` and return the corresponding AST.

    Args:
        stl_str: Formula text to parse.
        rule: Name of the grammar rule to start from.
        H: Upper end of the default interval, forwarded to ``STLVisitor``.

    Returns:
        Whatever ``STLVisitor`` produces for the chosen start rule.
    """
    return STLVisitor(H).visit(STL_GRAMMAR[rule].parse(stl_str))