diff --git a/__init__.py b/__init__.py index c70b1db..b89fbdf 100644 --- a/__init__.py +++ b/__init__.py @@ -1,4 +1,4 @@ -from stl.stl import terms_lens, walk, tree +from stl.stl import terms_lens, lineq_lens, walk, tree from stl.stl import dt_sym, t_sym from stl.stl import LinEq, Interval, NaryOpSTL, Or, And, F, G, ModalOp, Neg from stl.parser import parse diff --git a/parser.py b/parser.py index f35f7c1..79ec1da 100644 --- a/parser.py +++ b/parser.py @@ -17,7 +17,7 @@ from funcy import flatten import numpy as np from lenses import lens -from sympy import Function +from sympy import Symbol, Number from stl import stl @@ -45,14 +45,13 @@ terms = (term __ pm __ terms) / term var = id time? time = prime / time_index -time_index = "(" "t" __ pm __ const ")" +time_index = "[" "t" __ pm __ const "]" prime = "'" pm = "+" / "-" dt = "dt" unbound = "?" -id = ("x" / "u" / "w") (aZ / ~r"\d")* -aZ = (~r"[a-z]" / ~r"A-z") +id = ("x" / "u" / "w") ~r"[a-zA-z\d]*" const = ~r"[\+\-]?\d*(\.\d+)?" op = ">=" / "<=" / "<" / ">" / "=" _ = ~r"\s"+ @@ -97,20 +96,21 @@ class STLVisitor(NodeVisitor): visit_and = partialmethod(binop_visitor, op=stl.And) def visit_id(self, name, _): - return Function(name.text)(stl.t_sym) + return Symbol(name.text) def visit_var(self, _, children): iden, time_node = children time_node = list(flatten(time_node)) time = time_node[0] if len(time_node) > 0 else stl.t_sym - return iden.subs(stl.t_sym, time) + + return iden, time def visit_time_index(self, _, children): - return stl.t_sym + children[3]* children[5] + return children[3]* children[5] def visit_prime(self, *_): - return -stl.dt_sym + return stl.t_sym - stl.dt_sym def visit_const(self, const, children): return float(const.text) @@ -119,29 +119,30 @@ class STLVisitor(NodeVisitor): return stl.dt_sym def visit_term(self, _, children): - coeffs, var = children - c = coeffs[0] if coeffs else 1 - return var*c + coeffs, (iden, time) = children + c = coeffs[0] if coeffs else Number(1) + return stl.Var(coeff=c, id=iden, time=time) + def visit_coeff(self, _, children): dt, coeff, *_ = children - dt = dt[0][0] if dt else 1 + dt = dt[0][0] if dt else Number(1) return dt * coeff def visit_terms(self, _, children): if isinstance(children[0], list): term, _1, sgn ,_2, terms = children[0] - terms = lens(terms)[0]*sgn + terms = lens(terms)[0].coeff * sgn return [term] + terms else: return children def visit_lineq(self, _, children): terms, _1, op, _2, const = children - return stl.LinEq(sum(terms), op, const[0]) + return stl.LinEq(tuple(terms), op, const[0]) def visit_pm(self, node, _): - return 1 if node.text == "+" else -1 + return Number(1) if node.text == "+" else Number(-1) def parse(stl_str:str, rule:str="phi") -> "STL": diff --git a/stl.py b/stl.py index c3b3de6..1a7d98a 100644 --- a/stl.py +++ b/stl.py @@ -7,10 +7,11 @@ from itertools import repeat from typing import Union from enum import Enum from sympy import Symbol -import funcy as fn from lenses import lens +import funcy as fn + VarKind = Enum("VarKind", ["x", "u", "w"]) str_to_varkind = {"x": VarKind.x, "u": VarKind.u, "w": VarKind.w} dt_sym = Symbol('dt', positive=True) @@ -18,13 +19,23 @@ t_sym = Symbol('t', positive=True) class LinEq(namedtuple("LinEquality", ["terms", "op", "const"])): def __repr__(self): - rep = "{lhs} {op} {c}" - return rep.format(lhs=self.terms, op=self.op, c=self.const) + n = len(self.terms) + rep = "{}" + if n > 1: + rep += " + {}"*(n - 1) + rep += " {op} {c}" + return rep.format(*self.terms, op=self.op, c=self.const) def children(self): return [] +class Var(namedtuple("Var", ["coeff", "id", "time"])): + def __repr__(self): + time_str = "[{}]".format(self.time) + return "{c}*{i}{t}".format(c=self.coeff, i=self.id, t=time_str) + + class Interval(namedtuple('I', ['lower', 'upper'])): def __repr__(self): return "[{},{}]".format(self.lower, self.upper) @@ -93,12 +104,16 @@ def tree(stl): return {x:set(x.children()) for x in walk(stl) if x.children()} -def terms_lens(phi:"STL", bind=True) -> lens: - tls = list(fn.flatten(_terms_lens(phi))) +def lineq_lens(phi:"STL", bind=True) -> lens: + tls = list(fn.flatten(_lineq_lens(phi))) tl = lens().tuple_(*tls).each_() return tl.bind(phi) if bind else tl +def terms_lens(phi:"STL", bind=True) -> lens: + return lineq_lens(phi, bind).terms.each_() + + def _child_lens(psi, focus): if isinstance(psi, NaryOpSTL): for j, _ in enumerate(psi.args): @@ -107,9 +122,9 @@ def _child_lens(psi, focus): yield focus.arg -def _terms_lens(phi, focus=lens()): +def _lineq_lens(phi, focus=lens()): psi = focus.get(state=phi) if isinstance(psi, LinEq): - return [focus.terms] + return [focus] child_lenses = list(_child_lens(psi, focus=focus)) - return [_terms_lens(phi, focus=cl) for cl in child_lenses] + return [_lineq_lens(phi, focus=cl) for cl in child_lenses]