implemented context inlining + fix parsing errors
This commit is contained in:
parent
0d4a703dcc
commit
476917860e
8 changed files with 40 additions and 13 deletions
|
|
@ -61,7 +61,7 @@ BOT = _Bot()
|
||||||
|
|
||||||
class AtomicPred(namedtuple("AP", ["id", "time"]), AST):
|
class AtomicPred(namedtuple("AP", ["id", "time"]), AST):
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return f"{self.id}[{self.time}]"
|
return f"{self.id}({self.time})"
|
||||||
|
|
||||||
def children(self):
|
def children(self):
|
||||||
return []
|
return []
|
||||||
|
|
@ -81,7 +81,7 @@ class LinEq(namedtuple("LinEquality", ["terms", "op", "const"]), AST):
|
||||||
|
|
||||||
class Var(namedtuple("Var", ["coeff", "id", "time"])):
|
class Var(namedtuple("Var", ["coeff", "id", "time"])):
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return f"{self.coeff}*{self.id}[{self.time}]"
|
return f"{self.coeff}*{self.id}({self.time})"
|
||||||
|
|
||||||
|
|
||||||
class Interval(namedtuple('I', ['lower', 'upper'])):
|
class Interval(namedtuple('I', ['lower', 'upper'])):
|
||||||
|
|
|
||||||
|
|
@ -54,7 +54,7 @@ op_lookup = {
|
||||||
|
|
||||||
@pointwise_sat.register(stl.AtomicPred)
|
@pointwise_sat.register(stl.AtomicPred)
|
||||||
def _(stl):
|
def _(stl):
|
||||||
return lambda x, t: x[stl.id][t]
|
return lambda x, t: x[str(stl.id)][t]
|
||||||
|
|
||||||
|
|
||||||
@pointwise_sat.register(stl.LinEq)
|
@pointwise_sat.register(stl.LinEq)
|
||||||
|
|
|
||||||
|
|
@ -68,7 +68,7 @@ def _(stl):
|
||||||
def _(stl):
|
def _(stl):
|
||||||
def sat_comp(x, t):
|
def sat_comp(x, t):
|
||||||
sat = bitarray()
|
sat = bitarray()
|
||||||
[sat.append(x[stl.id][tau]) for tau in t]
|
[sat.append(x[str(stl.id)][tau]) for tau in t]
|
||||||
return sat
|
return sat
|
||||||
return sat_comp
|
return sat_comp
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -21,7 +21,7 @@ from stl import ast
|
||||||
from stl.utils import implies, xor, iff
|
from stl.utils import implies, xor, iff
|
||||||
|
|
||||||
STL_GRAMMAR = Grammar(u'''
|
STL_GRAMMAR = Grammar(u'''
|
||||||
phi = (neg / g / f / until / lineq / AP / or / and / implies / xor / iff / paren_phi)
|
phi = (until / neg / g / f / lineq / AP / or / and / implies / xor / iff / paren_phi)
|
||||||
|
|
||||||
paren_phi = "(" __ phi __ ")"
|
paren_phi = "(" __ phi __ ")"
|
||||||
|
|
||||||
|
|
@ -35,7 +35,7 @@ neg = ("~" / "¬") phi
|
||||||
|
|
||||||
f = F interval? phi
|
f = F interval? phi
|
||||||
g = G interval? phi
|
g = G interval? phi
|
||||||
until = "(" __ phi _ U interval? _ phi __ ")"
|
until = paren_phi __ U interval? __ paren_phi
|
||||||
|
|
||||||
F = "F" / "◇"
|
F = "F" / "◇"
|
||||||
G = "G" / "□"
|
G = "G" / "□"
|
||||||
|
|
@ -52,10 +52,10 @@ terms = (term __ pm __ terms) / term
|
||||||
|
|
||||||
var = id time?
|
var = id time?
|
||||||
time = prime / time_index
|
time = prime / time_index
|
||||||
time_index = "[" "t" __ pm __ const "]"
|
time_index = "(" ("t" / const) ")"
|
||||||
prime = "'"
|
prime = "'"
|
||||||
|
|
||||||
AP = ~r"[a-zA-z\d]+"
|
AP = id time?
|
||||||
|
|
||||||
pm = "+" / "-"
|
pm = "+" / "-"
|
||||||
dt = "dt"
|
dt = "dt"
|
||||||
|
|
@ -118,7 +118,7 @@ class STLVisitor(NodeVisitor):
|
||||||
visit_implies = partialmethod(sugar_binop_visitor, op=implies)
|
visit_implies = partialmethod(sugar_binop_visitor, op=implies)
|
||||||
|
|
||||||
def visit_until(self, _, children):
|
def visit_until(self, _, children):
|
||||||
_, _, phi1, _, _, i, _, phi2, *_ = children
|
phi1, _, _, i, _, phi2 = children
|
||||||
i = self.default_interval if not i else i[0]
|
i = self.default_interval if not i else i[0]
|
||||||
return ast.Until(i, phi1, phi2)
|
return ast.Until(i, phi1, phi2)
|
||||||
|
|
||||||
|
|
@ -134,7 +134,8 @@ class STLVisitor(NodeVisitor):
|
||||||
return iden, time
|
return iden, time
|
||||||
|
|
||||||
def visit_time_index(self, _, children):
|
def visit_time_index(self, _, children):
|
||||||
return children[3]* children[5]
|
children = list(flatten(children))
|
||||||
|
return children[0] if children else ast.t_sym
|
||||||
|
|
||||||
def visit_prime(self, *_):
|
def visit_prime(self, *_):
|
||||||
return ast.t_sym + ast.dt_sym
|
return ast.t_sym + ast.dt_sym
|
||||||
|
|
@ -173,8 +174,8 @@ class STLVisitor(NodeVisitor):
|
||||||
def visit_pm(self, node, _):
|
def visit_pm(self, node, _):
|
||||||
return Number(1) if node.text == "+" else Number(-1)
|
return Number(1) if node.text == "+" else Number(-1)
|
||||||
|
|
||||||
def visit_AP(self, node, _):
|
def visit_AP(self, *args):
|
||||||
return ast.AtomicPred(node.text, ast.t_sym)
|
return ast.AtomicPred(*self.visit_var(*args))
|
||||||
|
|
||||||
def visit_neg(self, _, children):
|
def visit_neg(self, _, children):
|
||||||
return ast.Neg(children[1])
|
return ast.Neg(children[1])
|
||||||
|
|
|
||||||
|
|
@ -28,6 +28,7 @@ class TestSTLEval(unittest.TestCase):
|
||||||
self.assertEqual(stl_eval2(x, 0), not r)
|
self.assertEqual(stl_eval2(x, 0), not r)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@params(ex1, ex2, ex3, ex4, ex5, ex6, ex7, ex8, ex9)
|
@params(ex1, ex2, ex3, ex4, ex5, ex6, ex7, ex8, ex9)
|
||||||
def test_fasteval(self, phi_str, _):
|
def test_fasteval(self, phi_str, _):
|
||||||
phi = stl.parse(phi_str)
|
phi = stl.parse(phi_str)
|
||||||
|
|
|
||||||
|
|
@ -15,7 +15,7 @@ ex1_ = ('x1 > a?', stl.LinEq(
|
||||||
Symbol("a?")
|
Symbol("a?")
|
||||||
))
|
))
|
||||||
|
|
||||||
ex1__ = ('x1', stl.AtomicPred('x1', stl.t_sym))
|
ex1__ = ('x1', stl.AtomicPred(Symbol('x1'), stl.t_sym))
|
||||||
|
|
||||||
i1 = stl.Interval(0., 1.)
|
i1 = stl.Interval(0., 1.)
|
||||||
i1_ = stl.Interval(0., Symbol("b?"))
|
i1_ = stl.Interval(0., Symbol("b?"))
|
||||||
|
|
|
||||||
|
|
@ -67,6 +67,23 @@ class TestSTLUtils(unittest.TestCase):
|
||||||
phi = stl.parse("x")
|
phi = stl.parse("x")
|
||||||
self.assertEqual(phi, stl.orf(phi))
|
self.assertEqual(phi, stl.orf(phi))
|
||||||
|
|
||||||
|
def test_inline_context(self):
|
||||||
|
context = {
|
||||||
|
stl.parse("x"): stl.parse("(z) & (y)"),
|
||||||
|
stl.parse("z"): stl.parse("y - x > 4")
|
||||||
|
}
|
||||||
|
context2 = {
|
||||||
|
stl.parse("x"): stl.parse("x"),
|
||||||
|
}
|
||||||
|
phi = stl.parse("x")
|
||||||
|
self.assertEqual(stl.utils.inline_context(phi, {}), phi)
|
||||||
|
self.assertEqual(stl.utils.inline_context(phi, context),
|
||||||
|
stl.parse("(y - x > 4) & (y)"))
|
||||||
|
|
||||||
|
phi2 = stl.parse("((x) & (z)) | (y)")
|
||||||
|
self.assertEqual(stl.utils.inline_context(phi2, context),
|
||||||
|
stl.parse("((y - x > 4) & (y) & (y - x > 4)) | (y)"))
|
||||||
|
|
||||||
# def test_to_from_mtl(self):
|
# def test_to_from_mtl(self):
|
||||||
# raise NotImplementedError
|
# raise NotImplementedError
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -128,6 +128,14 @@ def linear_stl_lipschitz(phi):
|
||||||
"""Infinity norm lipschitz bound of linear inequality predicate."""
|
"""Infinity norm lipschitz bound of linear inequality predicate."""
|
||||||
return float(max(map(_lineq_lipschitz, lineq_lens(phi).get_all())))
|
return float(max(map(_lineq_lipschitz, lineq_lens(phi).get_all())))
|
||||||
|
|
||||||
|
def inline_context(phi, context):
|
||||||
|
phi2 = None
|
||||||
|
update = lambda ap: context.get(ap, ap)
|
||||||
|
while phi2 != phi:
|
||||||
|
phi2, phi = phi, AP_lens(phi).modify(update)
|
||||||
|
# TODO: this is hack to flatten the AST. Fix!
|
||||||
|
return stl.parse(str(phi))
|
||||||
|
|
||||||
# EDSL
|
# EDSL
|
||||||
|
|
||||||
def alw(phi, *, lo, hi):
|
def alw(phi, *, lo, hi):
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue