# -*- coding: utf-8 -*-
"""Parser for Signal Temporal Logic (STL) formulas.

Formula text is matched against the parsimonious grammar ``STL_GRAMMAR`` and
the resulting parse tree is converted by ``STLVisitor`` into ``stl.ast``
objects.  ``parse`` is the entry point.
"""
# TODO: break out into separate library
# TODO: allow multiplication to be distributive
# TODO: support reference specific time points
# TODO: add support for parsing Until
# TODO: support variables on both sides of ineq
# TODO: Allow -x = -1*x

from functools import partialmethod
from collections import namedtuple
import operator as op

from parsimonious import Grammar, NodeVisitor
from funcy import flatten
from lenses import lens
from sympy import Symbol, Number

from stl import ast
from stl.utils import implies, xor, iff

# Grammar notes:
# * The grammar is a raw string because its regexes contain \d, \s, \+ ...,
#   which are invalid escape sequences in an ordinary string literal.  The
#   resulting grammar text is unchanged (EOL is still the two characters \n).
# * AP/id use ``A-Z``: the former ``A-z`` range also matched ``[ \ ] ^`` and
#   the backtick, so ``x[t+1]`` was lexed as identifiers ``x[t`` and ``1]``
#   and ``time_index`` was unreachable.  ``_`` (also inside that old range) is
#   kept explicitly so underscore names keep parsing.
# * ``const`` must contain at least one digit; it used to be able to match the
#   empty string (or a bare sign), which then failed in ``float('')``.
# * ``dt *`` accepts trailing whitespace, like the ``<const> *`` form does.
STL_GRAMMAR = Grammar(r'''
phi = (neg / g / f / until / lineq / AP / or / and / implies / xor / iff / paren_phi)

paren_phi = "(" __ phi __ ")"

or = paren_phi _ ("∨" / "or" / "|") _ (or / paren_phi)
and = paren_phi _ ("∧" / "and" / "&") _ (and / paren_phi)
implies = paren_phi _ ("→" / "->") _ (and / paren_phi)
iff = paren_phi _ ("⇔" / "<->" / "iff") _ (and / paren_phi)
xor = paren_phi _ ("⊕" / "^" / "xor") _ (and / paren_phi)
neg = ("~" / "¬") phi
f = F interval? phi
g = G interval? phi
until = "(" __ phi _ U interval? _ phi __ ")"

F = "F" / "◇"
G = "G" / "□"
U = "U"

interval = "[" __ const_or_unbound __ "," __ const_or_unbound __ "]"

const_or_unbound = unbound / const

lineq = terms _ op _ const_or_unbound
term = coeff? var
coeff = ((dt __ "*" __)? const __ "*" __) / (dt __ "*" __)
terms = (term __ pm __ terms) / term

var = id time?
time = prime / time_index
time_index = "[" "t" __ pm __ const "]"
prime = "'"

AP = ~r"[a-zA-Z\d_]+"

pm = "+" / "-"
dt = "dt"
unbound = id "?"
id = ~r"[a-zA-Z\d_]+"
const = ~r"[\+\-]?(?:\d+(?:\.\d+)?|\.\d+)"
op = ">=" / "<=" / "<" / ">" / "="
_ = ~r"\s"+
__ = ~r"\s"*
EOL = "\n"
''')


class STLVisitor(NodeVisitor):
    """Convert a parse tree of ``STL_GRAMMAR`` into STL formula objects.

    Each ``visit_<rule>`` method receives the already-visited children of a
    node.  Rules without a dedicated method fall back to ``generic_visit``,
    which returns the list of visited children unchanged.
    """

    def __init__(self, H=float('inf')):
        """Create a visitor whose default temporal interval is ``[0, H]``.

        The default interval is used for ``F``, ``G`` and ``U`` when the
        formula gives no explicit interval.
        """
        super().__init__()
        self.default_interval = ast.Interval(0, H)

    def generic_visit(self, _, children):
        # Pass children through so parent visitors can unpack them by position.
        return children

    def children_getter(self, _, children, i):
        """Return the ``i``-th visited child (bound via ``partialmethod``)."""
        return children[i]

    visit_phi = partialmethod(children_getter, i=0)
    # paren_phi = "(" __ phi __ ")": the inner formula is child 2.
    visit_paren_phi = partialmethod(children_getter, i=2)

    def visit_interval(self, _, children):
        """Build an ``ast.Interval`` from ``[lo, hi]``."""
        _, _, left, _, _, _, right, _, _ = children
        # Each bound is the one-element list produced by const_or_unbound.
        return ast.Interval(left[0], right[0])

    def get_text(self, node, _):
        """Return the raw source text matched by ``node``."""
        return node.text

    def visit_unbound(self, node, _):
        # The symbol name is the full matched text, including the trailing "?".
        return Symbol(node.text)

    visit_op = get_text

    def unary_temp_op_visitor(self, _, children, op):
        """Build ``op(interval, phi)`` for ``F``/``G``.

        Falls back to ``self.default_interval`` when no interval was written.
        """
        _, i, phi = children
        i = self.default_interval if not i else i[0]
        return op(i, phi)

    def binop_visitor(self, _, children, op):
        """Build an n-ary ``op`` node from ``phi1 <op> phi2``.

        Operands that are themselves ``op`` nodes are flattened into their
        arguments, so nested conjunctions/disjunctions become a single node.
        """
        phi1, _, _, _, (phi2,) = children
        argL = list(phi1.args) if isinstance(phi1, op) else [phi1]
        argR = list(phi2.args) if isinstance(phi2, op) else [phi2]
        return op(tuple(argL + argR))

    def sugar_binop_visitor(self, _, children, op):
        """Build ``op(phi1, phi2)`` using one of the ``stl.utils`` helpers."""
        phi1, _, _, _, (phi2,) = children
        return op(phi1, phi2)

    visit_f = partialmethod(unary_temp_op_visitor, op=ast.F)
    visit_g = partialmethod(unary_temp_op_visitor, op=ast.G)

    visit_or = partialmethod(binop_visitor, op=ast.Or)
    visit_and = partialmethod(binop_visitor, op=ast.And)
    visit_xor = partialmethod(sugar_binop_visitor, op=xor)
    visit_iff = partialmethod(sugar_binop_visitor, op=iff)
    visit_implies = partialmethod(sugar_binop_visitor, op=implies)

    def visit_until(self, _, children):
        """Build ``ast.Until(interval, phi1, phi2)`` from ``(phi1 U[..] phi2)``."""
        _, _, phi1, _, _, i, _, phi2, *_ = children
        i = self.default_interval if not i else i[0]
        return ast.Until(i, phi1, phi2)

    def visit_id(self, name, _):
        return Symbol(name.text)

    def visit_var(self, _, children):
        """Return ``(identifier_symbol, time_expression)`` for a variable.

        Without a time suffix the variable refers to ``ast.t_sym``.
        """
        iden, time_node = children
        time_node = list(flatten(time_node))
        time = time_node[0] if len(time_node) > 0 else ast.t_sym
        return iden, time

    def visit_time_index(self, _, children):
        # time_index = "[" "t" __ pm __ const "]": children[3] is the sign
        # (+1/-1) and children[5] the offset.  The offset is relative to t,
        # consistent with visit_prime (t + dt) and the default time (t).
        return ast.t_sym + children[3] * children[5]

    def visit_prime(self, *_):
        return ast.t_sym + ast.dt_sym

    def visit_const(self, const, children):
        return float(const.text)

    def visit_dt(self, *_):
        return ast.dt_sym

    def visit_term(self, _, children):
        """Build an ``ast.Var`` term; the coefficient defaults to 1."""
        coeffs, (iden, time) = children
        c = coeffs[0] if coeffs else Number(1)
        return ast.Var(coeff=c, id=iden, time=time)

    def visit_coeff(self, _, children):
        """Return the coefficient expression of a term.

        Handles ``[dt *] <const> *`` (product of the optional ``dt`` and the
        constant) and ``dt *`` (just ``dt``).
        """
        dt, coeff, *_ = children[0]
        if not isinstance(dt, Symbol):
            # First alternative: ``dt`` is the optional's list of matches.
            dt = dt[0][0] if dt else Number(1)
            return dt * coeff
        else:
            return dt

    def visit_terms(self, _, children):
        """Return the list of ``ast.Var`` terms in a linear expression.

        For ``term pm rest`` the sign is folded into the coefficient of the
        first term of ``rest``.
        """
        if isinstance(children[0], list):
            term, _1, sgn, _2, terms = children[0]
            terms = lens(terms)[0].coeff * sgn
            return [term] + terms
        else:
            return children

    def visit_lineq(self, _, children):
        """Build ``ast.LinEq(terms, op, const)`` from ``terms op const``."""
        terms, _1, op, _2, const = children
        return ast.LinEq(tuple(terms), op, const[0])

    def visit_pm(self, node, _):
        return Number(1) if node.text == "+" else Number(-1)

    def visit_AP(self, node, _):
        return ast.AtomicPred(node.text, ast.t_sym)

    def visit_neg(self, _, children):
        # neg = ("~" / "¬") phi: the negated formula is child 1.
        return ast.Neg(children[1])


def parse(stl_str: str, rule: str = "phi", H=float('inf')) -> "STL":
    """Parse ``stl_str`` into an STL formula.

    Args:
        stl_str: Formula text, matched against grammar rule ``rule``.
        rule: Name of the ``STL_GRAMMAR`` rule to start from.
        H: Upper bound of the default ``[0, H]`` interval used for temporal
            operators written without an explicit interval.

    Returns:
        Whatever ``STLVisitor`` builds for ``rule`` (an STL formula for the
        default ``"phi"``).

    Raises:
        parsimonious's parse exceptions when the text does not match the
        grammar (exact exception class: see the parsimonious documentation).
    """
    return STLVisitor(H).visit(STL_GRAMMAR[rule].parse(stl_str))