implemented context inlining + fix parsing errors

This commit is contained in:
Marcell Vazquez-Chanlatte 2017-04-20 17:16:15 -07:00
parent 0d4a703dcc
commit 476917860e
8 changed files with 40 additions and 13 deletions

View file

@ -61,7 +61,7 @@ BOT = _Bot()
class AtomicPred(namedtuple("AP", ["id", "time"]), AST):
def __repr__(self):
return f"{self.id}[{self.time}]"
return f"{self.id}({self.time})"
def children(self):
return []
@ -81,7 +81,7 @@ class LinEq(namedtuple("LinEquality", ["terms", "op", "const"]), AST):
class Var(namedtuple("Var", ["coeff", "id", "time"])):
def __repr__(self):
return f"{self.coeff}*{self.id}[{self.time}]"
return f"{self.coeff}*{self.id}({self.time})"
class Interval(namedtuple('I', ['lower', 'upper'])):

View file

@ -54,7 +54,7 @@ op_lookup = {
@pointwise_sat.register(stl.AtomicPred)
def _(stl):
return lambda x, t: x[stl.id][t]
return lambda x, t: x[str(stl.id)][t]
@pointwise_sat.register(stl.LinEq)

View file

@ -68,7 +68,7 @@ def _(stl):
def _(stl):
def sat_comp(x, t):
sat = bitarray()
[sat.append(x[stl.id][tau]) for tau in t]
[sat.append(x[str(stl.id)][tau]) for tau in t]
return sat
return sat_comp

View file

@ -21,7 +21,7 @@ from stl import ast
from stl.utils import implies, xor, iff
STL_GRAMMAR = Grammar(u'''
phi = (neg / g / f / until / lineq / AP / or / and / implies / xor / iff / paren_phi)
phi = (until / neg / g / f / lineq / AP / or / and / implies / xor / iff / paren_phi)
paren_phi = "(" __ phi __ ")"
@ -35,7 +35,7 @@ neg = ("~" / "¬") phi
f = F interval? phi
g = G interval? phi
until = "(" __ phi _ U interval? _ phi __ ")"
until = paren_phi __ U interval? __ paren_phi
F = "F" / ""
G = "G" / ""
@ -52,10 +52,10 @@ terms = (term __ pm __ terms) / term
var = id time?
time = prime / time_index
time_index = "[" "t" __ pm __ const "]"
time_index = "(" ("t" / const) ")"
prime = "'"
AP = ~r"[a-zA-z\d]+"
AP = id time?
pm = "+" / "-"
dt = "dt"
@ -118,7 +118,7 @@ class STLVisitor(NodeVisitor):
visit_implies = partialmethod(sugar_binop_visitor, op=implies)
def visit_until(self, _, children):
_, _, phi1, _, _, i, _, phi2, *_ = children
phi1, _, _, i, _, phi2 = children
i = self.default_interval if not i else i[0]
return ast.Until(i, phi1, phi2)
@ -134,7 +134,8 @@ class STLVisitor(NodeVisitor):
return iden, time
def visit_time_index(self, _, children):
return children[3]* children[5]
children = list(flatten(children))
return children[0] if children else ast.t_sym
def visit_prime(self, *_):
return ast.t_sym + ast.dt_sym
@ -173,8 +174,8 @@ class STLVisitor(NodeVisitor):
def visit_pm(self, node, _):
return Number(1) if node.text == "+" else Number(-1)
def visit_AP(self, node, _):
return ast.AtomicPred(node.text, ast.t_sym)
def visit_AP(self, *args):
return ast.AtomicPred(*self.visit_var(*args))
def visit_neg(self, _, children):
return ast.Neg(children[1])

View file

@ -28,6 +28,7 @@ class TestSTLEval(unittest.TestCase):
self.assertEqual(stl_eval2(x, 0), not r)
@params(ex1, ex2, ex3, ex4, ex5, ex6, ex7, ex8, ex9)
def test_fasteval(self, phi_str, _):
phi = stl.parse(phi_str)

View file

@ -15,7 +15,7 @@ ex1_ = ('x1 > a?', stl.LinEq(
Symbol("a?")
))
ex1__ = ('x1', stl.AtomicPred('x1', stl.t_sym))
ex1__ = ('x1', stl.AtomicPred(Symbol('x1'), stl.t_sym))
i1 = stl.Interval(0., 1.)
i1_ = stl.Interval(0., Symbol("b?"))

View file

@ -67,6 +67,23 @@ class TestSTLUtils(unittest.TestCase):
phi = stl.parse("x")
self.assertEqual(phi, stl.orf(phi))
def test_inline_context(self):
context = {
stl.parse("x"): stl.parse("(z) & (y)"),
stl.parse("z"): stl.parse("y - x > 4")
}
context2 = {
stl.parse("x"): stl.parse("x"),
}
phi = stl.parse("x")
self.assertEqual(stl.utils.inline_context(phi, {}), phi)
self.assertEqual(stl.utils.inline_context(phi, context),
stl.parse("(y - x > 4) & (y)"))
phi2 = stl.parse("((x) & (z)) | (y)")
self.assertEqual(stl.utils.inline_context(phi2, context),
stl.parse("((y - x > 4) & (y) & (y - x > 4)) | (y)"))
# def test_to_from_mtl(self):
# raise NotImplementedError

View file

@ -128,6 +128,14 @@ def linear_stl_lipschitz(phi):
"""Infinity norm lipschitz bound of linear inequality predicate."""
return float(max(map(_lineq_lipschitz, lineq_lens(phi).get_all())))
def inline_context(phi, context):
phi2 = None
update = lambda ap: context.get(ap, ap)
while phi2 != phi:
phi2, phi = phi, AP_lens(phi).modify(update)
# TODO: this is hack to flatten the AST. Fix!
return stl.parse(str(phi))
# EDSL
def alw(phi, *, lo, hi):