diff --git a/stl/__init__.py b/stl/__init__.py index b4b073b..a3574eb 100644 --- a/stl/__init__.py +++ b/stl/__init__.py @@ -4,3 +4,4 @@ from stl.ast import LinEq, Interval, NaryOpSTL, Or, And, F, G, ModalOp, Neg, Var from stl.parser import parse from stl.synth import lex_param_project from stl.boolean_eval import pointwise_sat +from stl.fastboolean_eval import pointwise_satf diff --git a/stl/fastboolean_eval.py b/stl/fastboolean_eval.py index f6d69b4..3fb40c2 100644 --- a/stl/fastboolean_eval.py +++ b/stl/fastboolean_eval.py @@ -13,59 +13,59 @@ from bitarray import bitarray import stl.ast @singledispatch -def pointwise_sat(stl): +def pointwise_satf(stl): raise NotImplementedError -@pointwise_sat.register(stl.Or) +@pointwise_satf.register(stl.Or) def _(stl): def sat_comp(x,t): sat = bitarray(len(t)) for arg in stl.args: - sat = pointwise_sat(arg)(x, t) | sat + sat = pointwise_satf(arg)(x, t) | sat return sat return sat_comp -@pointwise_sat.register(stl.And) +@pointwise_satf.register(stl.And) def _(stl): def sat_comp(x,t): sat = bitarray(len(t)) sat.setall('True') for arg in stl.args: - sat = pointwise_sat(arg)(x, t) & sat + sat = pointwise_satf(arg)(x, t) & sat return sat return sat_comp -@pointwise_sat.register(stl.F) +@pointwise_satf.register(stl.F) def _(stl): lo, hi = stl.interval def sat_comp(x,t): sat = bitarray() for tau in t: tau_t = [min(tau + t2, x.index[-1]) for t2 in x[lo:hi].index] - sat.append((pointwise_sat(stl.arg)(x, tau_t)).count() > 0) + sat.append((pointwise_satf(stl.arg)(x, tau_t)).count() > 0) return sat return sat_comp -@pointwise_sat.register(stl.G) +@pointwise_satf.register(stl.G) def _(stl): lo, hi = stl.interval def sat_comp(x,t): sat = bitarray() for tau in t: tau_t = [min(tau + t2, x.index[-1]) for t2 in x[lo:hi].index] - point_sat = pointwise_sat(stl.arg)(x, tau_t) + point_sat = pointwise_satf(stl.arg)(x, tau_t) sat.append(point_sat.count() == point_sat.length()) return sat return sat_comp -@pointwise_sat.register(stl.Neg) +@pointwise_satf.register(stl.Neg) def _(stl): - return lambda x,t: ~pointwise_sat(arg)(x, t) + return lambda x,t: ~pointwise_satf(arg)(x, t) op_lookup = { @@ -77,7 +77,7 @@ op_lookup = { } -@pointwise_sat.register(stl.AtomicPred) +@pointwise_satf.register(stl.AtomicPred) def _(stl): def sat_comp(x, t): sat = bitarray() @@ -86,7 +86,7 @@ def _(stl): return sat_comp -@pointwise_sat.register(stl.LinEq) +@pointwise_satf.register(stl.LinEq) def _(stl): op = op_lookup[stl.op] def sat_comp(x, t): diff --git a/stl/test_fastboolean_eval.py b/stl/test_fastboolean_eval.py index 043d754..2d0467f 100644 --- a/stl/test_fastboolean_eval.py +++ b/stl/test_fastboolean_eval.py @@ -21,5 +21,5 @@ class TestSTLRobustness(unittest.TestCase): @params(ex1, ex2, ex3, ex4, ex5, ex6, ex7, ex8, ex9) def test_stl(self, phi_str, r): phi = stl.parse(phi_str) - stl_eval = stl.fastboolean_eval.pointwise_sat(phi) + stl_eval = stl.fastboolean_eval.pointwise_satf(phi) self.assertEqual(stl_eval(x, [0]), r)