From 476917860e831e7eaf1075d427b9bc20f7b634e8 Mon Sep 17 00:00:00 2001 From: Marcell Vazquez-Chanlatte Date: Thu, 20 Apr 2017 17:16:15 -0700 Subject: [PATCH] implemented context inlining + fix parsing errors --- stl/ast.py | 4 ++-- stl/boolean_eval.py | 2 +- stl/fastboolean_eval.py | 2 +- stl/parser.py | 17 +++++++++-------- stl/test_boolean_eval.py | 1 + stl/test_parser.py | 2 +- stl/test_utils.py | 17 +++++++++++++++++ stl/utils.py | 8 ++++++++ 8 files changed, 40 insertions(+), 13 deletions(-) diff --git a/stl/ast.py b/stl/ast.py index af1c3a0..80a8d74 100644 --- a/stl/ast.py +++ b/stl/ast.py @@ -61,7 +61,7 @@ BOT = _Bot() class AtomicPred(namedtuple("AP", ["id", "time"]), AST): def __repr__(self): - return f"{self.id}[{self.time}]" + return f"{self.id}({self.time})" def children(self): return [] @@ -81,7 +81,7 @@ class LinEq(namedtuple("LinEquality", ["terms", "op", "const"]), AST): class Var(namedtuple("Var", ["coeff", "id", "time"])): def __repr__(self): - return f"{self.coeff}*{self.id}[{self.time}]" + return f"{self.coeff}*{self.id}({self.time})" class Interval(namedtuple('I', ['lower', 'upper'])): diff --git a/stl/boolean_eval.py b/stl/boolean_eval.py index 1481d6f..502c008 100644 --- a/stl/boolean_eval.py +++ b/stl/boolean_eval.py @@ -54,7 +54,7 @@ op_lookup = { @pointwise_sat.register(stl.AtomicPred) def _(stl): - return lambda x, t: x[stl.id][t] + return lambda x, t: x[str(stl.id)][t] @pointwise_sat.register(stl.LinEq) diff --git a/stl/fastboolean_eval.py b/stl/fastboolean_eval.py index e3fb56a..2fcefed 100644 --- a/stl/fastboolean_eval.py +++ b/stl/fastboolean_eval.py @@ -68,7 +68,7 @@ def _(stl): def _(stl): def sat_comp(x, t): sat = bitarray() - [sat.append(x[stl.id][tau]) for tau in t] + [sat.append(x[str(stl.id)][tau]) for tau in t] return sat return sat_comp diff --git a/stl/parser.py b/stl/parser.py index 4c1d2fa..a6f3310 100644 --- a/stl/parser.py +++ b/stl/parser.py @@ -21,7 +21,7 @@ from stl import ast from stl.utils import implies, xor, iff STL_GRAMMAR = Grammar(u''' -phi = (neg / g / f / until / lineq / AP / or / and / implies / xor / iff / paren_phi) +phi = (until / neg / g / f / lineq / AP / or / and / implies / xor / iff / paren_phi) paren_phi = "(" __ phi __ ")" @@ -35,7 +35,7 @@ neg = ("~" / "¬") phi f = F interval? phi g = G interval? phi -until = "(" __ phi _ U interval? _ phi __ ")" +until = paren_phi __ U interval? __ paren_phi F = "F" / "◇" G = "G" / "□" @@ -52,10 +52,10 @@ terms = (term __ pm __ terms) / term var = id time? time = prime / time_index -time_index = "[" "t" __ pm __ const "]" +time_index = "(" ("t" / const) ")" prime = "'" -AP = ~r"[a-zA-z\d]+" +AP = id time? pm = "+" / "-" dt = "dt" @@ -118,7 +118,7 @@ class STLVisitor(NodeVisitor): visit_implies = partialmethod(sugar_binop_visitor, op=implies) def visit_until(self, _, children): - _, _, phi1, _, _, i, _, phi2, *_ = children + phi1, _, _, i, _, phi2 = children i = self.default_interval if not i else i[0] return ast.Until(i, phi1, phi2) @@ -134,7 +134,8 @@ class STLVisitor(NodeVisitor): return iden, time def visit_time_index(self, _, children): - return children[3]* children[5] + children = list(flatten(children)) + return children[0] if children else ast.t_sym def visit_prime(self, *_): return ast.t_sym + ast.dt_sym @@ -173,8 +174,8 @@ class STLVisitor(NodeVisitor): def visit_pm(self, node, _): return Number(1) if node.text == "+" else Number(-1) - def visit_AP(self, node, _): - return ast.AtomicPred(node.text, ast.t_sym) + def visit_AP(self, *args): + return ast.AtomicPred(*self.visit_var(*args)) def visit_neg(self, _, children): return ast.Neg(children[1]) diff --git a/stl/test_boolean_eval.py b/stl/test_boolean_eval.py index 7d83af4..d0bb532 100644 --- a/stl/test_boolean_eval.py +++ b/stl/test_boolean_eval.py @@ -28,6 +28,7 @@ class TestSTLEval(unittest.TestCase): self.assertEqual(stl_eval2(x, 0), not r) + @params(ex1, ex2, ex3, ex4, ex5, ex6, ex7, ex8, ex9) def test_fasteval(self, phi_str, _): phi = stl.parse(phi_str) diff --git a/stl/test_parser.py b/stl/test_parser.py index 51208a5..28912a5 100644 --- a/stl/test_parser.py +++ b/stl/test_parser.py @@ -15,7 +15,7 @@ ex1_ = ('x1 > a?', stl.LinEq( Symbol("a?") )) -ex1__ = ('x1', stl.AtomicPred('x1', stl.t_sym)) +ex1__ = ('x1', stl.AtomicPred(Symbol('x1'), stl.t_sym)) i1 = stl.Interval(0., 1.) i1_ = stl.Interval(0., Symbol("b?")) diff --git a/stl/test_utils.py b/stl/test_utils.py index 38c2665..44988e8 100644 --- a/stl/test_utils.py +++ b/stl/test_utils.py @@ -67,6 +67,23 @@ class TestSTLUtils(unittest.TestCase): phi = stl.parse("x") self.assertEqual(phi, stl.orf(phi)) + def test_inline_context(self): + context = { + stl.parse("x"): stl.parse("(z) & (y)"), + stl.parse("z"): stl.parse("y - x > 4") + } + context2 = { + stl.parse("x"): stl.parse("x"), + } + phi = stl.parse("x") + self.assertEqual(stl.utils.inline_context(phi, {}), phi) + self.assertEqual(stl.utils.inline_context(phi, context), + stl.parse("(y - x > 4) & (y)")) + + phi2 = stl.parse("((x) & (z)) | (y)") + self.assertEqual(stl.utils.inline_context(phi2, context), + stl.parse("((y - x > 4) & (y) & (y - x > 4)) | (y)")) + # def test_to_from_mtl(self): # raise NotImplementedError diff --git a/stl/utils.py b/stl/utils.py index 79e8d12..c42718f 100644 --- a/stl/utils.py +++ b/stl/utils.py @@ -128,6 +128,14 @@ def linear_stl_lipschitz(phi): """Infinity norm lipschitz bound of linear inequality predicate.""" return float(max(map(_lineq_lipschitz, lineq_lens(phi).get_all()))) +def inline_context(phi, context): + phi2 = None + update = lambda ap: context.get(ap, ap) + while phi2 != phi: + phi2, phi = phi, AP_lens(phi).modify(update) + # TODO: this is hack to flatten the AST. Fix! + return stl.parse(str(phi)) + # EDSL def alw(phi, *, lo, hi):