reimplement terms lens to work around _tuple bug + start milp

This commit is contained in:
Marcell Vazquez-Chanlatte 2016-07-10 12:06:49 -07:00
parent 2e3ff68be6
commit bed8c51756
3 changed files with 32 additions and 36 deletions

View file

@ -1,4 +1,4 @@
from stl.stl import time_lens, walk, tree
from stl.stl import terms_lens, walk, tree
from stl.stl import dt_sym, t_sym
from stl.stl import LinEq, Var, Interval, NaryOpSTL, Or, And, F, G, ModalOp, Neg
from stl.stl import LinEq, Interval, NaryOpSTL, Or, And, F, G, ModalOp, Neg
from stl.parser import parse

View file

@ -17,7 +17,7 @@ from funcy import flatten
import numpy as np
from lenses import lens
from sympy import Symbol
from sympy import Function
from stl import stl
@ -45,13 +45,14 @@ terms = (term __ pm __ terms) / term
var = id time?
time = prime / time_index
time_index = "[" "t" __ pm __ const "]"
time_index = "(" "t" __ pm __ const ")"
prime = "'"
pm = "+" / "-"
dt = "dt"
unbound = "?"
id = ("x" / "u" / "w") ~r"[a-zA-z\d]*"
id = ("x" / "u" / "w") (aZ / ~r"\d")*
aZ = (~r"[a-z]" / ~r"A-z")
const = ~r"[\+\-]?\d*(\.\d+)?"
op = ">=" / "<=" / "<" / ">" / "="
_ = ~r"\s"+
@ -96,18 +97,17 @@ class STLVisitor(NodeVisitor):
visit_and = partialmethod(binop_visitor, op=stl.And)
def visit_id(self, name, _):
return Symbol(name.text)
return Function(name.text)(stl.t_sym)
def visit_var(self, _, children):
iden, time_node = children
time_node = list(flatten(time_node))
time = time_node[0] if len(time_node) > 0 else stl.t_sym
return stl.Var(iden, time)
return iden.subs(stl.t_sym, time)
def visit_time_index(self, _, children):
return children[3]* children[5]
return stl.t_sym + children[3]* children[5]
def visit_prime(self, *_):
return -stl.dt_sym
@ -121,7 +121,7 @@ class STLVisitor(NodeVisitor):
def visit_term(self, _, children):
coeffs, var = children
c = coeffs[0] if coeffs else 1
return lens(var).id*c
return var*c
def visit_coeff(self, _, children):
dt, coeff, *_ = children
@ -131,14 +131,14 @@ class STLVisitor(NodeVisitor):
def visit_terms(self, _, children):
if isinstance(children[0], list):
term, _1, sgn ,_2, terms = children[0]
terms = lens(terms)[0].id * sgn
terms = lens(terms)[0]*sgn
return [term] + terms
else:
return children
def visit_lineq(self, _, children):
terms, _1, op, _2, const = children
return stl.LinEq(tuple(terms), op, const[0])
return stl.LinEq(sum(terms), op, const[0])
def visit_pm(self, node, _):
return 1 if node.text == "+" else -1

44
stl.py
View file

@ -7,6 +7,7 @@ from itertools import repeat
from typing import Union
from enum import Enum
from sympy import Symbol
import funcy as fn
from lenses import lens
@ -17,23 +18,13 @@ t_sym = Symbol('t', positive=True)
class LinEq(namedtuple("LinEquality", ["terms", "op", "const"])):
def __repr__(self):
n = len(self.terms)
rep = "{}"
if n > 1:
rep += " + {}"*(n - 1)
rep += " {op} {c}"
return rep.format(*self.terms, op=self.op, c=self.const)
rep = "{lhs} {op} {c}"
return rep.format(lhs=self.terms, op=self.op, c=self.const)
def children(self):
return []
class Var(namedtuple("Var", ["id", "time"])):
def __repr__(self):
time_str = "[{}]".format(self.time)
return "{i}{t}".format(i=self.id, t=time_str)
class Interval(namedtuple('I', ['lower', 'upper'])):
def __repr__(self):
return "[{},{}]".format(self.lower, self.upper)
@ -102,18 +93,23 @@ def tree(stl):
return {x:set(x.children()) for x in walk(stl) if x.children()}
def time_lens(phi:"STL", bind=True) -> lens:
l = _time_lens(phi)
return l.bind(phi) if bind else l
def terms_lens(phi:"STL", bind=True) -> lens:
tls = list(fn.flatten(_terms_lens(phi)))
tl = lens().tuple_(*tls).each_()
return tl.bind(phi) if bind else tl
def _time_lens(phi):
if isinstance(phi, LinEq):
return lens().terms.each_().time
if isinstance(phi, NaryOpSTL):
child_lens = [lens()[i].add_lens(_time_lens(c)) for i, c
in enumerate(phi.children())]
return lens().args.tuple_(*child_lens).each_()
def _child_lens(psi, focus):
if isinstance(psi, NaryOpSTL):
for j, _ in enumerate(psi.args):
yield focus.args[j]
else:
return lens().arg.add_lens(_time_lens(phi.arg))
yield focus.arg
def _terms_lens(phi, focus=lens()):
psi = focus.get(state=phi)
if isinstance(psi, LinEq):
return [focus.terms]
child_lenses = list(_child_lens(psi, focus=focus))
return [_terms_lens(phi, focus=cl) for cl in child_lenses]