reimplement terms lens to work around _tuple bug + start milp

This commit is contained in:
Marcell Vazquez-Chanlatte 2016-07-10 12:06:49 -07:00
parent 2e3ff68be6
commit bed8c51756
3 changed files with 32 additions and 36 deletions

View file

@ -1,4 +1,4 @@
from stl.stl import time_lens, walk, tree from stl.stl import terms_lens, walk, tree
from stl.stl import dt_sym, t_sym from stl.stl import dt_sym, t_sym
from stl.stl import LinEq, Var, Interval, NaryOpSTL, Or, And, F, G, ModalOp, Neg from stl.stl import LinEq, Interval, NaryOpSTL, Or, And, F, G, ModalOp, Neg
from stl.parser import parse from stl.parser import parse

View file

@ -17,7 +17,7 @@ from funcy import flatten
import numpy as np import numpy as np
from lenses import lens from lenses import lens
from sympy import Symbol from sympy import Function
from stl import stl from stl import stl
@ -45,13 +45,14 @@ terms = (term __ pm __ terms) / term
var = id time? var = id time?
time = prime / time_index time = prime / time_index
time_index = "[" "t" __ pm __ const "]" time_index = "(" "t" __ pm __ const ")"
prime = "'" prime = "'"
pm = "+" / "-" pm = "+" / "-"
dt = "dt" dt = "dt"
unbound = "?" unbound = "?"
id = ("x" / "u" / "w") ~r"[a-zA-z\d]*" id = ("x" / "u" / "w") (aZ / ~r"\d")*
aZ = (~r"[a-z]" / ~r"A-z")
const = ~r"[\+\-]?\d*(\.\d+)?" const = ~r"[\+\-]?\d*(\.\d+)?"
op = ">=" / "<=" / "<" / ">" / "=" op = ">=" / "<=" / "<" / ">" / "="
_ = ~r"\s"+ _ = ~r"\s"+
@ -96,18 +97,17 @@ class STLVisitor(NodeVisitor):
visit_and = partialmethod(binop_visitor, op=stl.And) visit_and = partialmethod(binop_visitor, op=stl.And)
def visit_id(self, name, _): def visit_id(self, name, _):
return Symbol(name.text) return Function(name.text)(stl.t_sym)
def visit_var(self, _, children): def visit_var(self, _, children):
iden, time_node = children iden, time_node = children
time_node = list(flatten(time_node)) time_node = list(flatten(time_node))
time = time_node[0] if len(time_node) > 0 else stl.t_sym time = time_node[0] if len(time_node) > 0 else stl.t_sym
return iden.subs(stl.t_sym, time)
return stl.Var(iden, time)
def visit_time_index(self, _, children): def visit_time_index(self, _, children):
return children[3]* children[5] return stl.t_sym + children[3]* children[5]
def visit_prime(self, *_): def visit_prime(self, *_):
return -stl.dt_sym return -stl.dt_sym
@ -121,7 +121,7 @@ class STLVisitor(NodeVisitor):
def visit_term(self, _, children): def visit_term(self, _, children):
coeffs, var = children coeffs, var = children
c = coeffs[0] if coeffs else 1 c = coeffs[0] if coeffs else 1
return lens(var).id*c return var*c
def visit_coeff(self, _, children): def visit_coeff(self, _, children):
dt, coeff, *_ = children dt, coeff, *_ = children
@ -131,14 +131,14 @@ class STLVisitor(NodeVisitor):
def visit_terms(self, _, children): def visit_terms(self, _, children):
if isinstance(children[0], list): if isinstance(children[0], list):
term, _1, sgn ,_2, terms = children[0] term, _1, sgn ,_2, terms = children[0]
terms = lens(terms)[0].id * sgn terms = lens(terms)[0]*sgn
return [term] + terms return [term] + terms
else: else:
return children return children
def visit_lineq(self, _, children): def visit_lineq(self, _, children):
terms, _1, op, _2, const = children terms, _1, op, _2, const = children
return stl.LinEq(tuple(terms), op, const[0]) return stl.LinEq(sum(terms), op, const[0])
def visit_pm(self, node, _): def visit_pm(self, node, _):
return 1 if node.text == "+" else -1 return 1 if node.text == "+" else -1

44
stl.py
View file

@ -7,6 +7,7 @@ from itertools import repeat
from typing import Union from typing import Union
from enum import Enum from enum import Enum
from sympy import Symbol from sympy import Symbol
import funcy as fn
from lenses import lens from lenses import lens
@ -17,23 +18,13 @@ t_sym = Symbol('t', positive=True)
class LinEq(namedtuple("LinEquality", ["terms", "op", "const"])): class LinEq(namedtuple("LinEquality", ["terms", "op", "const"])):
def __repr__(self): def __repr__(self):
n = len(self.terms) rep = "{lhs} {op} {c}"
rep = "{}" return rep.format(lhs=self.terms, op=self.op, c=self.const)
if n > 1:
rep += " + {}"*(n - 1)
rep += " {op} {c}"
return rep.format(*self.terms, op=self.op, c=self.const)
def children(self): def children(self):
return [] return []
class Var(namedtuple("Var", ["id", "time"])):
def __repr__(self):
time_str = "[{}]".format(self.time)
return "{i}{t}".format(i=self.id, t=time_str)
class Interval(namedtuple('I', ['lower', 'upper'])): class Interval(namedtuple('I', ['lower', 'upper'])):
def __repr__(self): def __repr__(self):
return "[{},{}]".format(self.lower, self.upper) return "[{},{}]".format(self.lower, self.upper)
@ -102,18 +93,23 @@ def tree(stl):
return {x:set(x.children()) for x in walk(stl) if x.children()} return {x:set(x.children()) for x in walk(stl) if x.children()}
def time_lens(phi:"STL", bind=True) -> lens: def terms_lens(phi:"STL", bind=True) -> lens:
l = _time_lens(phi) tls = list(fn.flatten(_terms_lens(phi)))
return l.bind(phi) if bind else l tl = lens().tuple_(*tls).each_()
return tl.bind(phi) if bind else tl
def _time_lens(phi): def _child_lens(psi, focus):
if isinstance(phi, LinEq): if isinstance(psi, NaryOpSTL):
return lens().terms.each_().time for j, _ in enumerate(psi.args):
yield focus.args[j]
if isinstance(phi, NaryOpSTL):
child_lens = [lens()[i].add_lens(_time_lens(c)) for i, c
in enumerate(phi.children())]
return lens().args.tuple_(*child_lens).each_()
else: else:
return lens().arg.add_lens(_time_lens(phi.arg)) yield focus.arg
def _terms_lens(phi, focus=lens()):
psi = focus.get(state=phi)
if isinstance(psi, LinEq):
return [focus.terms]
child_lenses = list(_child_lens(psi, focus=focus))
return [_terms_lens(phi, focus=cl) for cl in child_lenses]