implemented context inlining + fix parsing errors
This commit is contained in:
parent
0d4a703dcc
commit
476917860e
8 changed files with 40 additions and 13 deletions
|
|
@ -61,7 +61,7 @@ BOT = _Bot()
|
|||
|
||||
class AtomicPred(namedtuple("AP", ["id", "time"]), AST):
|
||||
def __repr__(self):
|
||||
return f"{self.id}[{self.time}]"
|
||||
return f"{self.id}({self.time})"
|
||||
|
||||
def children(self):
|
||||
return []
|
||||
|
|
@ -81,7 +81,7 @@ class LinEq(namedtuple("LinEquality", ["terms", "op", "const"]), AST):
|
|||
|
||||
class Var(namedtuple("Var", ["coeff", "id", "time"])):
|
||||
def __repr__(self):
|
||||
return f"{self.coeff}*{self.id}[{self.time}]"
|
||||
return f"{self.coeff}*{self.id}({self.time})"
|
||||
|
||||
|
||||
class Interval(namedtuple('I', ['lower', 'upper'])):
|
||||
|
|
|
|||
|
|
@ -54,7 +54,7 @@ op_lookup = {
|
|||
|
||||
@pointwise_sat.register(stl.AtomicPred)
|
||||
def _(stl):
|
||||
return lambda x, t: x[stl.id][t]
|
||||
return lambda x, t: x[str(stl.id)][t]
|
||||
|
||||
|
||||
@pointwise_sat.register(stl.LinEq)
|
||||
|
|
|
|||
|
|
@ -68,7 +68,7 @@ def _(stl):
|
|||
def _(stl):
|
||||
def sat_comp(x, t):
|
||||
sat = bitarray()
|
||||
[sat.append(x[stl.id][tau]) for tau in t]
|
||||
[sat.append(x[str(stl.id)][tau]) for tau in t]
|
||||
return sat
|
||||
return sat_comp
|
||||
|
||||
|
|
|
|||
|
|
@ -21,7 +21,7 @@ from stl import ast
|
|||
from stl.utils import implies, xor, iff
|
||||
|
||||
STL_GRAMMAR = Grammar(u'''
|
||||
phi = (neg / g / f / until / lineq / AP / or / and / implies / xor / iff / paren_phi)
|
||||
phi = (until / neg / g / f / lineq / AP / or / and / implies / xor / iff / paren_phi)
|
||||
|
||||
paren_phi = "(" __ phi __ ")"
|
||||
|
||||
|
|
@ -35,7 +35,7 @@ neg = ("~" / "¬") phi
|
|||
|
||||
f = F interval? phi
|
||||
g = G interval? phi
|
||||
until = "(" __ phi _ U interval? _ phi __ ")"
|
||||
until = paren_phi __ U interval? __ paren_phi
|
||||
|
||||
F = "F" / "◇"
|
||||
G = "G" / "□"
|
||||
|
|
@ -52,10 +52,10 @@ terms = (term __ pm __ terms) / term
|
|||
|
||||
var = id time?
|
||||
time = prime / time_index
|
||||
time_index = "[" "t" __ pm __ const "]"
|
||||
time_index = "(" ("t" / const) ")"
|
||||
prime = "'"
|
||||
|
||||
AP = ~r"[a-zA-z\d]+"
|
||||
AP = id time?
|
||||
|
||||
pm = "+" / "-"
|
||||
dt = "dt"
|
||||
|
|
@ -118,7 +118,7 @@ class STLVisitor(NodeVisitor):
|
|||
visit_implies = partialmethod(sugar_binop_visitor, op=implies)
|
||||
|
||||
def visit_until(self, _, children):
|
||||
_, _, phi1, _, _, i, _, phi2, *_ = children
|
||||
phi1, _, _, i, _, phi2 = children
|
||||
i = self.default_interval if not i else i[0]
|
||||
return ast.Until(i, phi1, phi2)
|
||||
|
||||
|
|
@ -134,7 +134,8 @@ class STLVisitor(NodeVisitor):
|
|||
return iden, time
|
||||
|
||||
def visit_time_index(self, _, children):
|
||||
return children[3]* children[5]
|
||||
children = list(flatten(children))
|
||||
return children[0] if children else ast.t_sym
|
||||
|
||||
def visit_prime(self, *_):
|
||||
return ast.t_sym + ast.dt_sym
|
||||
|
|
@ -173,8 +174,8 @@ class STLVisitor(NodeVisitor):
|
|||
def visit_pm(self, node, _):
|
||||
return Number(1) if node.text == "+" else Number(-1)
|
||||
|
||||
def visit_AP(self, node, _):
|
||||
return ast.AtomicPred(node.text, ast.t_sym)
|
||||
def visit_AP(self, *args):
|
||||
return ast.AtomicPred(*self.visit_var(*args))
|
||||
|
||||
def visit_neg(self, _, children):
|
||||
return ast.Neg(children[1])
|
||||
|
|
|
|||
|
|
@ -28,6 +28,7 @@ class TestSTLEval(unittest.TestCase):
|
|||
self.assertEqual(stl_eval2(x, 0), not r)
|
||||
|
||||
|
||||
|
||||
@params(ex1, ex2, ex3, ex4, ex5, ex6, ex7, ex8, ex9)
|
||||
def test_fasteval(self, phi_str, _):
|
||||
phi = stl.parse(phi_str)
|
||||
|
|
|
|||
|
|
@ -15,7 +15,7 @@ ex1_ = ('x1 > a?', stl.LinEq(
|
|||
Symbol("a?")
|
||||
))
|
||||
|
||||
ex1__ = ('x1', stl.AtomicPred('x1', stl.t_sym))
|
||||
ex1__ = ('x1', stl.AtomicPred(Symbol('x1'), stl.t_sym))
|
||||
|
||||
i1 = stl.Interval(0., 1.)
|
||||
i1_ = stl.Interval(0., Symbol("b?"))
|
||||
|
|
|
|||
|
|
@ -67,6 +67,23 @@ class TestSTLUtils(unittest.TestCase):
|
|||
phi = stl.parse("x")
|
||||
self.assertEqual(phi, stl.orf(phi))
|
||||
|
||||
def test_inline_context(self):
|
||||
context = {
|
||||
stl.parse("x"): stl.parse("(z) & (y)"),
|
||||
stl.parse("z"): stl.parse("y - x > 4")
|
||||
}
|
||||
context2 = {
|
||||
stl.parse("x"): stl.parse("x"),
|
||||
}
|
||||
phi = stl.parse("x")
|
||||
self.assertEqual(stl.utils.inline_context(phi, {}), phi)
|
||||
self.assertEqual(stl.utils.inline_context(phi, context),
|
||||
stl.parse("(y - x > 4) & (y)"))
|
||||
|
||||
phi2 = stl.parse("((x) & (z)) | (y)")
|
||||
self.assertEqual(stl.utils.inline_context(phi2, context),
|
||||
stl.parse("((y - x > 4) & (y) & (y - x > 4)) | (y)"))
|
||||
|
||||
# def test_to_from_mtl(self):
|
||||
# raise NotImplementedError
|
||||
|
||||
|
|
|
|||
|
|
@ -128,6 +128,14 @@ def linear_stl_lipschitz(phi):
|
|||
"""Infinity norm lipschitz bound of linear inequality predicate."""
|
||||
return float(max(map(_lineq_lipschitz, lineq_lens(phi).get_all())))
|
||||
|
||||
def inline_context(phi, context):
|
||||
phi2 = None
|
||||
update = lambda ap: context.get(ap, ap)
|
||||
while phi2 != phi:
|
||||
phi2, phi = phi, AP_lens(phi).modify(update)
|
||||
# TODO: this is hack to flatten the AST. Fix!
|
||||
return stl.parse(str(phi))
|
||||
|
||||
# EDSL
|
||||
|
||||
def alw(phi, *, lo, hi):
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue