reverted back to nametuple repr + time filters

This commit is contained in:
Marcell Vazquez-Chanlatte 2016-07-10 16:36:53 -07:00
parent bed8c51756
commit c4cfda6e19
3 changed files with 40 additions and 24 deletions

View file

@ -1,4 +1,4 @@
from stl.stl import terms_lens, walk, tree from stl.stl import terms_lens, lineq_lens, walk, tree
from stl.stl import dt_sym, t_sym from stl.stl import dt_sym, t_sym
from stl.stl import LinEq, Interval, NaryOpSTL, Or, And, F, G, ModalOp, Neg from stl.stl import LinEq, Interval, NaryOpSTL, Or, And, F, G, ModalOp, Neg
from stl.parser import parse from stl.parser import parse

View file

@ -17,7 +17,7 @@ from funcy import flatten
import numpy as np import numpy as np
from lenses import lens from lenses import lens
from sympy import Function from sympy import Symbol, Number
from stl import stl from stl import stl
@ -45,14 +45,13 @@ terms = (term __ pm __ terms) / term
var = id time? var = id time?
time = prime / time_index time = prime / time_index
time_index = "(" "t" __ pm __ const ")" time_index = "[" "t" __ pm __ const "]"
prime = "'" prime = "'"
pm = "+" / "-" pm = "+" / "-"
dt = "dt" dt = "dt"
unbound = "?" unbound = "?"
id = ("x" / "u" / "w") (aZ / ~r"\d")* id = ("x" / "u" / "w") ~r"[a-zA-z\d]*"
aZ = (~r"[a-z]" / ~r"A-z")
const = ~r"[\+\-]?\d*(\.\d+)?" const = ~r"[\+\-]?\d*(\.\d+)?"
op = ">=" / "<=" / "<" / ">" / "=" op = ">=" / "<=" / "<" / ">" / "="
_ = ~r"\s"+ _ = ~r"\s"+
@ -97,20 +96,21 @@ class STLVisitor(NodeVisitor):
visit_and = partialmethod(binop_visitor, op=stl.And) visit_and = partialmethod(binop_visitor, op=stl.And)
def visit_id(self, name, _): def visit_id(self, name, _):
return Function(name.text)(stl.t_sym) return Symbol(name.text)
def visit_var(self, _, children): def visit_var(self, _, children):
iden, time_node = children iden, time_node = children
time_node = list(flatten(time_node)) time_node = list(flatten(time_node))
time = time_node[0] if len(time_node) > 0 else stl.t_sym time = time_node[0] if len(time_node) > 0 else stl.t_sym
return iden.subs(stl.t_sym, time)
return iden, time
def visit_time_index(self, _, children): def visit_time_index(self, _, children):
return stl.t_sym + children[3]* children[5] return children[3]* children[5]
def visit_prime(self, *_): def visit_prime(self, *_):
return -stl.dt_sym return stl.t_sym - stl.dt_sym
def visit_const(self, const, children): def visit_const(self, const, children):
return float(const.text) return float(const.text)
@ -119,29 +119,30 @@ class STLVisitor(NodeVisitor):
return stl.dt_sym return stl.dt_sym
def visit_term(self, _, children): def visit_term(self, _, children):
coeffs, var = children coeffs, (iden, time) = children
c = coeffs[0] if coeffs else 1 c = coeffs[0] if coeffs else Number(1)
return var*c return stl.Var(coeff=c, id=iden, time=time)
def visit_coeff(self, _, children): def visit_coeff(self, _, children):
dt, coeff, *_ = children dt, coeff, *_ = children
dt = dt[0][0] if dt else 1 dt = dt[0][0] if dt else Number(1)
return dt * coeff return dt * coeff
def visit_terms(self, _, children): def visit_terms(self, _, children):
if isinstance(children[0], list): if isinstance(children[0], list):
term, _1, sgn ,_2, terms = children[0] term, _1, sgn ,_2, terms = children[0]
terms = lens(terms)[0]*sgn terms = lens(terms)[0].coeff * sgn
return [term] + terms return [term] + terms
else: else:
return children return children
def visit_lineq(self, _, children): def visit_lineq(self, _, children):
terms, _1, op, _2, const = children terms, _1, op, _2, const = children
return stl.LinEq(sum(terms), op, const[0]) return stl.LinEq(tuple(terms), op, const[0])
def visit_pm(self, node, _): def visit_pm(self, node, _):
return 1 if node.text == "+" else -1 return Number(1) if node.text == "+" else Number(-1)
def parse(stl_str:str, rule:str="phi") -> "STL": def parse(stl_str:str, rule:str="phi") -> "STL":

31
stl.py
View file

@ -7,10 +7,11 @@ from itertools import repeat
from typing import Union from typing import Union
from enum import Enum from enum import Enum
from sympy import Symbol from sympy import Symbol
import funcy as fn
from lenses import lens from lenses import lens
import funcy as fn
VarKind = Enum("VarKind", ["x", "u", "w"]) VarKind = Enum("VarKind", ["x", "u", "w"])
str_to_varkind = {"x": VarKind.x, "u": VarKind.u, "w": VarKind.w} str_to_varkind = {"x": VarKind.x, "u": VarKind.u, "w": VarKind.w}
dt_sym = Symbol('dt', positive=True) dt_sym = Symbol('dt', positive=True)
@ -18,13 +19,23 @@ t_sym = Symbol('t', positive=True)
class LinEq(namedtuple("LinEquality", ["terms", "op", "const"])): class LinEq(namedtuple("LinEquality", ["terms", "op", "const"])):
def __repr__(self): def __repr__(self):
rep = "{lhs} {op} {c}" n = len(self.terms)
return rep.format(lhs=self.terms, op=self.op, c=self.const) rep = "{}"
if n > 1:
rep += " + {}"*(n - 1)
rep += " {op} {c}"
return rep.format(*self.terms, op=self.op, c=self.const)
def children(self): def children(self):
return [] return []
class Var(namedtuple("Var", ["coeff", "id", "time"])):
def __repr__(self):
time_str = "[{}]".format(self.time)
return "{c}*{i}{t}".format(c=self.coeff, i=self.id, t=time_str)
class Interval(namedtuple('I', ['lower', 'upper'])): class Interval(namedtuple('I', ['lower', 'upper'])):
def __repr__(self): def __repr__(self):
return "[{},{}]".format(self.lower, self.upper) return "[{},{}]".format(self.lower, self.upper)
@ -93,12 +104,16 @@ def tree(stl):
return {x:set(x.children()) for x in walk(stl) if x.children()} return {x:set(x.children()) for x in walk(stl) if x.children()}
def terms_lens(phi:"STL", bind=True) -> lens: def lineq_lens(phi:"STL", bind=True) -> lens:
tls = list(fn.flatten(_terms_lens(phi))) tls = list(fn.flatten(_lineq_lens(phi)))
tl = lens().tuple_(*tls).each_() tl = lens().tuple_(*tls).each_()
return tl.bind(phi) if bind else tl return tl.bind(phi) if bind else tl
def terms_lens(phi:"STL", bind=True) -> lens:
return lineq_lens(phi, bind).terms.each_()
def _child_lens(psi, focus): def _child_lens(psi, focus):
if isinstance(psi, NaryOpSTL): if isinstance(psi, NaryOpSTL):
for j, _ in enumerate(psi.args): for j, _ in enumerate(psi.args):
@ -107,9 +122,9 @@ def _child_lens(psi, focus):
yield focus.arg yield focus.arg
def _terms_lens(phi, focus=lens()): def _lineq_lens(phi, focus=lens()):
psi = focus.get(state=phi) psi = focus.get(state=phi)
if isinstance(psi, LinEq): if isinstance(psi, LinEq):
return [focus.terms] return [focus]
child_lenses = list(_child_lens(psi, focus=focus)) child_lenses = list(_child_lens(psi, focus=focus))
return [_terms_lens(phi, focus=cl) for cl in child_lenses] return [_lineq_lens(phi, focus=cl) for cl in child_lenses]