diff --git a/__init__.py b/__init__.py index 99d5389..c70b1db 100644 --- a/__init__.py +++ b/__init__.py @@ -1,4 +1,4 @@ -from stl.stl import time_lens, walk, tree +from stl.stl import terms_lens, walk, tree from stl.stl import dt_sym, t_sym -from stl.stl import LinEq, Var, Interval, NaryOpSTL, Or, And, F, G, ModalOp, Neg +from stl.stl import LinEq, Interval, NaryOpSTL, Or, And, F, G, ModalOp, Neg from stl.parser import parse diff --git a/parser.py b/parser.py index dd204a4..f35f7c1 100644 --- a/parser.py +++ b/parser.py @@ -17,7 +17,7 @@ from funcy import flatten import numpy as np from lenses import lens -from sympy import Symbol +from sympy import Function from stl import stl @@ -45,13 +45,14 @@ terms = (term __ pm __ terms) / term var = id time? time = prime / time_index -time_index = "[" "t" __ pm __ const "]" +time_index = "(" "t" __ pm __ const ")" prime = "'" pm = "+" / "-" dt = "dt" unbound = "?" -id = ("x" / "u" / "w") ~r"[a-zA-z\d]*" +id = ("x" / "u" / "w") (aZ / ~r"\d")* +aZ = (~r"[a-z]" / ~r"A-z") const = ~r"[\+\-]?\d*(\.\d+)?" op = ">=" / "<=" / "<" / ">" / "=" _ = ~r"\s"+ @@ -96,18 +97,17 @@ class STLVisitor(NodeVisitor): visit_and = partialmethod(binop_visitor, op=stl.And) def visit_id(self, name, _): - return Symbol(name.text) + return Function(name.text)(stl.t_sym) def visit_var(self, _, children): iden, time_node = children time_node = list(flatten(time_node)) time = time_node[0] if len(time_node) > 0 else stl.t_sym - - return stl.Var(iden, time) + return iden.subs(stl.t_sym, time) def visit_time_index(self, _, children): - return children[3]* children[5] + return stl.t_sym + children[3]* children[5] def visit_prime(self, *_): return -stl.dt_sym @@ -121,7 +121,7 @@ class STLVisitor(NodeVisitor): def visit_term(self, _, children): coeffs, var = children c = coeffs[0] if coeffs else 1 - return lens(var).id*c + return var*c def visit_coeff(self, _, children): dt, coeff, *_ = children @@ -131,14 +131,14 @@ class STLVisitor(NodeVisitor): def visit_terms(self, _, children): if isinstance(children[0], list): term, _1, sgn ,_2, terms = children[0] - terms = lens(terms)[0].id * sgn + terms = lens(terms)[0]*sgn return [term] + terms else: return children def visit_lineq(self, _, children): terms, _1, op, _2, const = children - return stl.LinEq(tuple(terms), op, const[0]) + return stl.LinEq(sum(terms), op, const[0]) def visit_pm(self, node, _): return 1 if node.text == "+" else -1 diff --git a/stl.py b/stl.py index 5adf66b..c3b3de6 100644 --- a/stl.py +++ b/stl.py @@ -7,6 +7,7 @@ from itertools import repeat from typing import Union from enum import Enum from sympy import Symbol +import funcy as fn from lenses import lens @@ -17,23 +18,13 @@ t_sym = Symbol('t', positive=True) class LinEq(namedtuple("LinEquality", ["terms", "op", "const"])): def __repr__(self): - n = len(self.terms) - rep = "{}" - if n > 1: - rep += " + {}"*(n - 1) - rep += " {op} {c}" - return rep.format(*self.terms, op=self.op, c=self.const) + rep = "{lhs} {op} {c}" + return rep.format(lhs=self.terms, op=self.op, c=self.const) def children(self): return [] -class Var(namedtuple("Var", ["id", "time"])): - def __repr__(self): - time_str = "[{}]".format(self.time) - return "{i}{t}".format(i=self.id, t=time_str) - - class Interval(namedtuple('I', ['lower', 'upper'])): def __repr__(self): return "[{},{}]".format(self.lower, self.upper) @@ -102,18 +93,23 @@ def tree(stl): return {x:set(x.children()) for x in walk(stl) if x.children()} -def time_lens(phi:"STL", bind=True) -> lens: - l = _time_lens(phi) - return l.bind(phi) if bind else l +def terms_lens(phi:"STL", bind=True) -> lens: + tls = list(fn.flatten(_terms_lens(phi))) + tl = lens().tuple_(*tls).each_() + return tl.bind(phi) if bind else tl -def _time_lens(phi): - if isinstance(phi, LinEq): - return lens().terms.each_().time - - if isinstance(phi, NaryOpSTL): - child_lens = [lens()[i].add_lens(_time_lens(c)) for i, c - in enumerate(phi.children())] - return lens().args.tuple_(*child_lens).each_() +def _child_lens(psi, focus): + if isinstance(psi, NaryOpSTL): + for j, _ in enumerate(psi.args): + yield focus.args[j] else: - return lens().arg.add_lens(_time_lens(phi.arg)) + yield focus.arg + + +def _terms_lens(phi, focus=lens()): + psi = focus.get(state=phi) + if isinstance(psi, LinEq): + return [focus.terms] + child_lenses = list(_child_lens(psi, focus=focus)) + return [_terms_lens(phi, focus=cl) for cl in child_lenses]