implemented context inlining + fix parsing errors

This commit is contained in:
Marcell Vazquez-Chanlatte 2017-04-20 17:16:15 -07:00
parent 0d4a703dcc
commit 476917860e
8 changed files with 40 additions and 13 deletions

View file

@ -61,7 +61,7 @@ BOT = _Bot()
class AtomicPred(namedtuple("AP", ["id", "time"]), AST): class AtomicPred(namedtuple("AP", ["id", "time"]), AST):
def __repr__(self): def __repr__(self):
return f"{self.id}[{self.time}]" return f"{self.id}({self.time})"
def children(self): def children(self):
return [] return []
@ -81,7 +81,7 @@ class LinEq(namedtuple("LinEquality", ["terms", "op", "const"]), AST):
class Var(namedtuple("Var", ["coeff", "id", "time"])): class Var(namedtuple("Var", ["coeff", "id", "time"])):
def __repr__(self): def __repr__(self):
return f"{self.coeff}*{self.id}[{self.time}]" return f"{self.coeff}*{self.id}({self.time})"
class Interval(namedtuple('I', ['lower', 'upper'])): class Interval(namedtuple('I', ['lower', 'upper'])):

View file

@ -54,7 +54,7 @@ op_lookup = {
@pointwise_sat.register(stl.AtomicPred) @pointwise_sat.register(stl.AtomicPred)
def _(stl): def _(stl):
return lambda x, t: x[stl.id][t] return lambda x, t: x[str(stl.id)][t]
@pointwise_sat.register(stl.LinEq) @pointwise_sat.register(stl.LinEq)

View file

@ -68,7 +68,7 @@ def _(stl):
def _(stl): def _(stl):
def sat_comp(x, t): def sat_comp(x, t):
sat = bitarray() sat = bitarray()
[sat.append(x[stl.id][tau]) for tau in t] [sat.append(x[str(stl.id)][tau]) for tau in t]
return sat return sat
return sat_comp return sat_comp

View file

@ -21,7 +21,7 @@ from stl import ast
from stl.utils import implies, xor, iff from stl.utils import implies, xor, iff
STL_GRAMMAR = Grammar(u''' STL_GRAMMAR = Grammar(u'''
phi = (neg / g / f / until / lineq / AP / or / and / implies / xor / iff / paren_phi) phi = (until / neg / g / f / lineq / AP / or / and / implies / xor / iff / paren_phi)
paren_phi = "(" __ phi __ ")" paren_phi = "(" __ phi __ ")"
@ -35,7 +35,7 @@ neg = ("~" / "¬") phi
f = F interval? phi f = F interval? phi
g = G interval? phi g = G interval? phi
until = "(" __ phi _ U interval? _ phi __ ")" until = paren_phi __ U interval? __ paren_phi
F = "F" / "" F = "F" / ""
G = "G" / "" G = "G" / ""
@ -52,10 +52,10 @@ terms = (term __ pm __ terms) / term
var = id time? var = id time?
time = prime / time_index time = prime / time_index
time_index = "[" "t" __ pm __ const "]" time_index = "(" ("t" / const) ")"
prime = "'" prime = "'"
AP = ~r"[a-zA-z\d]+" AP = id time?
pm = "+" / "-" pm = "+" / "-"
dt = "dt" dt = "dt"
@ -118,7 +118,7 @@ class STLVisitor(NodeVisitor):
visit_implies = partialmethod(sugar_binop_visitor, op=implies) visit_implies = partialmethod(sugar_binop_visitor, op=implies)
def visit_until(self, _, children): def visit_until(self, _, children):
_, _, phi1, _, _, i, _, phi2, *_ = children phi1, _, _, i, _, phi2 = children
i = self.default_interval if not i else i[0] i = self.default_interval if not i else i[0]
return ast.Until(i, phi1, phi2) return ast.Until(i, phi1, phi2)
@ -134,7 +134,8 @@ class STLVisitor(NodeVisitor):
return iden, time return iden, time
def visit_time_index(self, _, children): def visit_time_index(self, _, children):
return children[3]* children[5] children = list(flatten(children))
return children[0] if children else ast.t_sym
def visit_prime(self, *_): def visit_prime(self, *_):
return ast.t_sym + ast.dt_sym return ast.t_sym + ast.dt_sym
@ -173,8 +174,8 @@ class STLVisitor(NodeVisitor):
def visit_pm(self, node, _): def visit_pm(self, node, _):
return Number(1) if node.text == "+" else Number(-1) return Number(1) if node.text == "+" else Number(-1)
def visit_AP(self, node, _): def visit_AP(self, *args):
return ast.AtomicPred(node.text, ast.t_sym) return ast.AtomicPred(*self.visit_var(*args))
def visit_neg(self, _, children): def visit_neg(self, _, children):
return ast.Neg(children[1]) return ast.Neg(children[1])

View file

@ -28,6 +28,7 @@ class TestSTLEval(unittest.TestCase):
self.assertEqual(stl_eval2(x, 0), not r) self.assertEqual(stl_eval2(x, 0), not r)
@params(ex1, ex2, ex3, ex4, ex5, ex6, ex7, ex8, ex9) @params(ex1, ex2, ex3, ex4, ex5, ex6, ex7, ex8, ex9)
def test_fasteval(self, phi_str, _): def test_fasteval(self, phi_str, _):
phi = stl.parse(phi_str) phi = stl.parse(phi_str)

View file

@ -15,7 +15,7 @@ ex1_ = ('x1 > a?', stl.LinEq(
Symbol("a?") Symbol("a?")
)) ))
ex1__ = ('x1', stl.AtomicPred('x1', stl.t_sym)) ex1__ = ('x1', stl.AtomicPred(Symbol('x1'), stl.t_sym))
i1 = stl.Interval(0., 1.) i1 = stl.Interval(0., 1.)
i1_ = stl.Interval(0., Symbol("b?")) i1_ = stl.Interval(0., Symbol("b?"))

View file

@ -67,6 +67,23 @@ class TestSTLUtils(unittest.TestCase):
phi = stl.parse("x") phi = stl.parse("x")
self.assertEqual(phi, stl.orf(phi)) self.assertEqual(phi, stl.orf(phi))
def test_inline_context(self):
context = {
stl.parse("x"): stl.parse("(z) & (y)"),
stl.parse("z"): stl.parse("y - x > 4")
}
context2 = {
stl.parse("x"): stl.parse("x"),
}
phi = stl.parse("x")
self.assertEqual(stl.utils.inline_context(phi, {}), phi)
self.assertEqual(stl.utils.inline_context(phi, context),
stl.parse("(y - x > 4) & (y)"))
phi2 = stl.parse("((x) & (z)) | (y)")
self.assertEqual(stl.utils.inline_context(phi2, context),
stl.parse("((y - x > 4) & (y) & (y - x > 4)) | (y)"))
# def test_to_from_mtl(self): # def test_to_from_mtl(self):
# raise NotImplementedError # raise NotImplementedError

View file

@ -128,6 +128,14 @@ def linear_stl_lipschitz(phi):
"""Infinity norm lipschitz bound of linear inequality predicate.""" """Infinity norm lipschitz bound of linear inequality predicate."""
return float(max(map(_lineq_lipschitz, lineq_lens(phi).get_all()))) return float(max(map(_lineq_lipschitz, lineq_lens(phi).get_all())))
def inline_context(phi, context):
phi2 = None
update = lambda ap: context.get(ap, ap)
while phi2 != phi:
phi2, phi = phi, AP_lens(phi).modify(update)
# TODO: this is hack to flatten the AST. Fix!
return stl.parse(str(phi))
# EDSL # EDSL
def alw(phi, *, lo, hi): def alw(phi, *, lo, hi):