diff --git a/stl/fastboolean_eval.py b/stl/fastboolean_eval.py index 6c8091e..f6d69b4 100644 --- a/stl/fastboolean_eval.py +++ b/stl/fastboolean_eval.py @@ -8,6 +8,7 @@ import numpy as np import sympy as smp from lenses import lens import gmpy2 as gp +from bitarray import bitarray import stl.ast @@ -19,20 +20,21 @@ def pointwise_sat(stl): @pointwise_sat.register(stl.Or) def _(stl): def sat_comp(x,t): - val = 0 + sat = bitarray(len(t)) for arg in stl.args: - val = pointwise_sat(arg)(x, t) | val - return val + sat = pointwise_sat(arg)(x, t) | sat + return sat return sat_comp @pointwise_sat.register(stl.And) def _(stl): def sat_comp(x,t): - val = 2**(len(t))-1 + sat = bitarray(len(t)) + sat.setall('True') for arg in stl.args: - val = pointwise_sat(arg)(x, t) & val - return val + sat = pointwise_sat(arg)(x, t) & sat + return sat return sat_comp @@ -40,11 +42,11 @@ def _(stl): def _(stl): lo, hi = stl.interval def sat_comp(x,t): - val = 0 + sat = bitarray() for tau in t: tau_t = [min(tau + t2, x.index[-1]) for t2 in x[lo:hi].index] - val = (val << 1) | (pointwise_sat(stl.arg)(x, tau_t) > 0) - return val + sat.append((pointwise_sat(stl.arg)(x, tau_t)).count() > 0) + return sat return sat_comp @@ -52,20 +54,18 @@ def _(stl): def _(stl): lo, hi = stl.interval def sat_comp(x,t): - val = 0 + sat = bitarray() for tau in t: tau_t = [min(tau + t2, x.index[-1]) for t2 in x[lo:hi].index] - val = (val << 1) | (gp.popcount(pointwise_sat(stl.arg)(x, tau_t)) == len(tau_t)) - return val + point_sat = pointwise_sat(stl.arg)(x, tau_t) + sat.append(point_sat.count() == point_sat.length()) + return sat return sat_comp @pointwise_sat.register(stl.Neg) def _(stl): - def sat_comp(x,t): - val = pointwise_sat(arg)(x, t) ^ (2**(len(t))-1) - return val - return sat_comp + return lambda x,t: ~pointwise_sat(arg)(x, t) op_lookup = { @@ -80,10 +80,9 @@ op_lookup = { @pointwise_sat.register(stl.AtomicPred) def _(stl): def sat_comp(x, t): - val = 0 - for tau in t: - val = (val << 1) | (1 if x[stl.id][tau] else 0) - return val + sat = bitarray() + [sat.append(x[stl.id][tau]) for tau in t] + return sat return sat_comp @@ -91,10 +90,9 @@ def _(stl): def _(stl): op = op_lookup[stl.op] def sat_comp(x, t): - val = 0 - for tau in t: - val = (val << 1) | (op(eval_terms(stl, x, tau), stl.const) == True) - return val + sat = bitarray() + [sat.append(op(eval_terms(stl, x, tau), stl.const)) for tau in t] + return sat return sat_comp diff --git a/stl/test_fastboolean_eval.py b/stl/test_fastboolean_eval.py new file mode 100644 index 0000000..043d754 --- /dev/null +++ b/stl/test_fastboolean_eval.py @@ -0,0 +1,25 @@ +import stl +import stl.fastboolean_eval +import pandas as pd +from nose2.tools import params +import unittest +from sympy import Symbol + +ex1 = ("2*A > 3", False) +ex2 = ("F[0, 1](2*A > 3)", True) +ex3 = ("F[1, 0](2*A > 3)", False) +ex4 = ("G[1, 0](2*A > 3)", True) +ex5 = ("(A < 0)", False) +ex6 = ("G[0, 0.1](A < 0)", False) +ex7 = ("G[0, 0.1](C)", True) +ex8 = ("G[0, 0.2](C)", False) +ex9 = ("(F[0, 0.2](C)) and (F[0, 1](2*A > 3))", True) +x = pd.DataFrame([[1,2, True], [1,4, True], [4,2, False]], index=[0,0.1,0.2], + columns=["A", "B", "C"]) + +class TestSTLRobustness(unittest.TestCase): + @params(ex1, ex2, ex3, ex4, ex5, ex6, ex7, ex8, ex9) + def test_stl(self, phi_str, r): + phi = stl.parse(phi_str) + stl_eval = stl.fastboolean_eval.pointwise_sat(phi) + self.assertEqual(stl_eval(x, [0]), r)