175 lines
5.1 KiB
Python
175 lines
5.1 KiB
Python
# -*- coding: utf-8 -*-
|
||
|
||
# TODO: allow multiplication to be distributive
|
||
# TODO: support variables on both sides of ineq
|
||
# TODO: Allow -x = -1*x
|
||
|
||
from functools import partialmethod
|
||
from collections import namedtuple
|
||
import operator as op
|
||
|
||
from parsimonious import Grammar, NodeVisitor
|
||
from funcy import flatten
|
||
from lenses import lens
|
||
|
||
from sympy import Symbol, Number
|
||
|
||
from stl import ast
|
||
from stl.utils import implies, xor, iff, env, alw
|
||
|
||
STL_GRAMMAR = Grammar(u'''
|
||
phi = (timed_until / until / neg / next / g / f / lineq / AP / or / and
|
||
/ implies / xor / iff / paren_phi)
|
||
|
||
paren_phi = "(" __ phi __ ")"
|
||
|
||
or = paren_phi _ ("∨" / "or" / "|") _ (or / paren_phi)
|
||
and = paren_phi _ ("∧" / "and" / "&") _ (and / paren_phi)
|
||
implies = paren_phi _ ("→" / "->") _ (and / paren_phi)
|
||
iff = paren_phi _ ("⇔" / "<->" / "iff") _ (and / paren_phi)
|
||
xor = paren_phi _ ("⊕" / "^" / "xor") _ (and / paren_phi)
|
||
|
||
neg = ("~" / "¬") paren_phi
|
||
next = "X" paren_phi
|
||
f = F interval? phi
|
||
g = G interval? phi
|
||
until = paren_phi __ U __ paren_phi
|
||
timed_until = paren_phi __ U interval __ paren_phi
|
||
|
||
F = "F" / "◇"
|
||
G = "G" / "□"
|
||
U = "U"
|
||
|
||
interval = "[" __ const_or_unbound __ "," __ const_or_unbound __ "]"
|
||
|
||
const_or_unbound = unbound / "inf" / const
|
||
|
||
lineq = terms _ op _ const_or_unbound
|
||
term = coeff? var
|
||
coeff = ((dt __ "*" __)? const __ "*" __) / (dt __ "*")
|
||
terms = (term __ pm __ terms) / term
|
||
|
||
var = id
|
||
AP = ~r"[a-zA-z\d]+"
|
||
|
||
pm = "+" / "-"
|
||
dt = "dt"
|
||
unbound = id "?"
|
||
id = ~r"[a-zA-z\d]+"
|
||
const = ~r"[\+\-]?\d*(\.\d+)?"
|
||
op = ">=" / "<=" / "<" / ">" / "="
|
||
_ = ~r"\s"+
|
||
__ = ~r"\s"*
|
||
EOL = "\\n"
|
||
''')
|
||
|
||
|
||
class STLVisitor(NodeVisitor):
    """Parse-tree visitor that builds ``stl.ast`` objects from STL_GRAMMAR.

    Each ``visit_<rule>`` method receives the parse node for ``<rule>`` and
    the list of already-visited children, and returns the value that stands
    for that node. Rules without a dedicated method fall back to
    ``generic_visit``, which simply returns the list of visited children.
    """

    def __init__(self, H=float('inf')):
        """Create a visitor.

        Args:
            H: upper bound of the interval used for ``F``/``G`` operators
               written without an explicit interval. The default interval is
               ``ast.Interval(0.0, H)``.
        """
        super().__init__()
        self.default_interval = ast.Interval(0.0, H)

    def generic_visit(self, _, children):
        """Fallback for rules without a visitor: return the visited children."""
        return children

    def children_getter(self, _, children, i):
        """Return the ``i``-th visited child (used via ``partialmethod``)."""
        return children[i]

    # phi is a single parenthesised choice, so its only child is the formula.
    visit_phi = partialmethod(children_getter, i=0)
    # paren_phi = "(" __ phi __ ")": the formula is the third child.
    visit_paren_phi = partialmethod(children_getter, i=2)

    def visit_interval(self, _, children):
        """Build an ``ast.Interval`` from ``"[" lo "," hi "]"``.

        Each bound arrives wrapped in a one-element list. A bound equal to
        ``[]`` (what ``generic_visit`` returns for a node with no visited
        children -- presumably the ``"inf"`` literal) is replaced by
        ``float("inf")``. Integer bounds are converted to ``float``.
        """
        _, _, (left, ), _, _, _, (right, ), _, _ = children
        left = left if left != [] else float("inf")
        right = right if right != [] else float("inf")
        if isinstance(left, int):
            left = float(left)
        if isinstance(right, int):
            # Fixed: this branch used to assign to ``left``, clobbering the
            # lower bound with the upper bound and leaving ``right`` an int.
            right = float(right)
        return ast.Interval(left, right)

    def get_text(self, node, _):
        """Return the raw text matched by ``node``."""
        return node.text

    def visit_unbound(self, node, _):
        """Return a sympy ``Symbol`` named after the full matched text."""
        return Symbol(node.text)

    # Comparison operators are kept as their source text.
    visit_op = get_text

    def unary_temp_op_visitor(self, _, children, op):
        """Build a unary temporal operator node ``op(interval, phi)``.

        ``children`` is ``(operator token, optional interval, phi)``; when the
        interval is absent (empty/falsy child) ``self.default_interval`` is
        used instead.
        """
        _, i, phi = children
        i = self.default_interval if not i else i[0]
        return op(i, phi)

    def binop_visitor(self, _, children, op):
        """Build an n-ary ``op`` node from ``phi1 <sym> phi2``.

        Operands that are already instances of ``op`` contribute their
        ``args`` instead of themselves, so nested uses of the same operator
        are flattened into one node. ``op`` is called with a single tuple.
        """
        phi1, _, _, _, (phi2, ) = children
        argL = list(phi1.args) if isinstance(phi1, op) else [phi1]
        argR = list(phi2.args) if isinstance(phi2, op) else [phi2]
        return op(tuple(argL + argR))

    def sugar_binop_visitor(self, _, children, op):
        """Apply the helper ``op(phi1, phi2)`` for ``phi1 <sym> phi2``."""
        phi1, _, _, _, (phi2, ) = children
        return op(phi1, phi2)

    visit_f = partialmethod(unary_temp_op_visitor, op=ast.F)
    visit_g = partialmethod(unary_temp_op_visitor, op=ast.G)
    visit_or = partialmethod(binop_visitor, op=ast.Or)
    visit_and = partialmethod(binop_visitor, op=ast.And)
    visit_xor = partialmethod(sugar_binop_visitor, op=xor)
    visit_iff = partialmethod(sugar_binop_visitor, op=iff)
    visit_implies = partialmethod(sugar_binop_visitor, op=implies)

    def visit_until(self, _, children):
        """Build ``ast.Until(phi1, phi2)`` from ``phi1 U phi2``."""
        phi1, _, _, _, phi2 = children
        return ast.Until(phi1, phi2)

    def visit_timed_until(self, _, children):
        """Rewrite ``phi U[lo, hi] psi`` in terms of ``env``, ``alw`` and Until.

        Returns ``env(psi, lo=lo, hi=hi) & alw(ast.Until(phi, psi), lo=0,
        hi=lo)``.
        """
        phi, _, _, (lo, hi), _, psi = children
        return env(psi, lo=lo, hi=hi) & alw(ast.Until(phi, psi), lo=0, hi=lo)

    def visit_id(self, name, _):
        """Return a sympy ``Symbol`` named after the matched identifier."""
        return Symbol(name.text)

    def visit_const(self, const, children):
        """Return the matched numeric text converted with ``float``."""
        return float(const.text)

    def visit_term(self, _, children):
        """Build an ``ast.Var`` from ``coeff? var``.

        A missing coefficient defaults to ``Number(1)``.
        """
        # NOTE(review): this unpacking expects the var child to be an
        # (id, time) pair, but the grammar's ``var = id`` rule and visit_id
        # (which returns a single Symbol) do not obviously produce one --
        # confirm against ast.Var and the intended grammar.
        coeffs, (iden, time) = children
        c = coeffs[0] if coeffs else Number(1)
        return ast.Var(coeff=c, id=iden, time=time)

    def visit_coeff(self, _, children):
        """Return the coefficient value for a ``coeff`` node.

        Takes the first two items of the matched alternative. If the first
        item is already a ``Symbol`` it is returned as is; otherwise it is
        treated as the optional ``dt`` prefix (defaulting to ``Number(1)``
        when empty) and multiplied by the second item.
        """
        dt, coeff, *_ = children[0]
        if not isinstance(dt, Symbol):
            dt = dt[0][0] if dt else Number(1)
            return dt * coeff
        else:
            return dt

    def visit_terms(self, _, children):
        """Return the list of terms for ``(term pm terms) / term``.

        In the recursive case the sign produced by ``pm`` is multiplied into
        the coefficient of the first term of the tail before the head term is
        prepended.
        """
        if isinstance(children[0], list):
            term, _1, sgn, _2, terms = children[0]
            terms = lens(terms)[0].coeff * sgn
            return [term] + terms
        else:
            return children

    def visit_lineq(self, _, children):
        """Build ``ast.LinEq(terms, op, const)`` from ``terms op rhs``.

        The right-hand side arrives wrapped in a list (the visited
        ``const_or_unbound`` node), so its first element is used.
        """
        terms, _1, op, _2, const = children
        return ast.LinEq(tuple(terms), op, const[0])

    def visit_pm(self, node, _):
        """Map ``"+"`` to ``Number(1)`` and anything else to ``Number(-1)``."""
        return Number(1) if node.text == "+" else Number(-1)

    def visit_AP(self, *args):
        """Wrap the identifier symbol in an ``ast.AtomicPred``."""
        return ast.AtomicPred(self.visit_id(*args))

    def visit_neg(self, _, children):
        """Build ``ast.Neg`` around the operand (second child)."""
        return ast.Neg(children[1])

    def visit_next(self, _, children):
        """Build ``ast.Next`` around the operand (second child)."""
        return ast.Next(children[1])
|
||
|
||
|
||
def parse(stl_str: str, rule: str = "phi", H=float('inf')) -> "STL":
    """Parse ``stl_str`` and return the object built by ``STLVisitor``.

    Args:
        stl_str: text of the formula to parse.
        rule: name of the grammar rule in ``STL_GRAMMAR`` to start from.
        H: passed to ``STLVisitor``; upper bound of its default interval.

    Returns:
        Whatever ``STLVisitor.visit`` produces for the resulting parse tree.
    """
    # Build the visitor first, then parse -- same evaluation order as a
    # single chained expression.
    visitor = STLVisitor(H)
    start_rule = STL_GRAMMAR[rule]
    tree = start_rule.parse(stl_str)
    return visitor.visit(tree)
|