diff --git a/stl/boolean_eval.py b/stl/boolean_eval.py index e1884fc..ba12b4d 100644 --- a/stl/boolean_eval.py +++ b/stl/boolean_eval.py @@ -26,28 +26,34 @@ def _(stl): return lambda x, t: all(f(x, t) for f in fs) +def get_times(x, tau, lo=None, hi=None): + indices = x.index if lo is None or hi is None else x[lo:hi].index + return [min(tau + t2, x.index[-1]) for t2 in indices] + + @pointwise_sat.register(stl.Until) def _(stl): def _until(x, t): - phi = (pointwise_sat(phi)(x, t) for t in x.index) - return lambda x, t: any((pointwise_sat(stl.arg)(x, min(t + t2, x.index[-1])) - for t2 in x[lo:hi].index)) + f1, f2 = pointwise_sat(stl.arg1), pointwise_sat(stl.arg2) + for tau in get_times(x, t): + if not f1(x, tau): + return f2(x, tau) + return False + return _until @pointwise_sat.register(stl.F) def _(stl): lo, hi = stl.interval f = pointwise_sat(stl.arg) - return lambda x, t: any((f(x, min(t + t2, x.index[-1])) - for t2 in x[lo:hi].index)) + return lambda x, t: any(f(x, tau) for tau in get_times(x, t, lo, hi)) @pointwise_sat.register(stl.G) def _(stl): lo, hi = stl.interval - f = pointwise_sat(stl.arg) - return lambda x, t: all((pointwise_sat(stl.arg)(x, min(t + t2, x.index[-1])) - for t2 in x[lo:hi].index)) + f = pointwise_sat(stl.arg) + return lambda x, t: all(f(x, tau) for tau in get_times(x, t, lo, hi)) @pointwise_sat.register(stl.Neg) diff --git a/stl/fastboolean_eval.py b/stl/fastboolean_eval.py index 9cbb7be..b2cec06 100644 --- a/stl/fastboolean_eval.py +++ b/stl/fastboolean_eval.py @@ -4,7 +4,7 @@ from operator import and_, or_ from bitarray import bitarray import stl.ast -from stl.boolean_eval import eval_terms, op_lookup +from stl.boolean_eval import eval_terms, op_lookup, get_times def pointwise_sat(stl): f = pointwise_satf(stl) @@ -30,15 +30,11 @@ def _(stl): return bool_op(stl, conjunction=True) -def get_times(x, lo, hi, tau): - return [min(tau + t2, x.index[-1]) for t2 in x[lo:hi].index] - - def temporal_op(stl, lo, hi, conjunction=False): fold = bitarray.all if conjunction else bitarray.any f = pointwise_satf(stl.arg) def sat_comp(x,t): - return bitarray(fold(f(x, get_times(x, lo, hi, tau))) for tau in t) + return bitarray(fold(f(x, get_times(x, tau, lo, hi))) for tau in t) return sat_comp diff --git a/stl/test_boolean_eval.py b/stl/test_boolean_eval.py index d0bb532..25fab18 100644 --- a/stl/test_boolean_eval.py +++ b/stl/test_boolean_eval.py @@ -15,11 +15,13 @@ ex6 = ("G[0, 0.1](A < 0)", False) ex7 = ("G[0, 0.1](C)", True) ex8 = ("G[0, 0.2](C)", False) ex9 = ("(F[0, 0.2](C)) and (F[0, 1](2*A > 3))", True) +ex10 = ("(A = 1) U (A = 4)", True) +ex11 = ("(A < 5) U (A = 4)", False) x = pd.DataFrame([[1,2, True], [1,4, True], [4,2, False]], index=[0,0.1,0.2], columns=["A", "B", "C"]) class TestSTLEval(unittest.TestCase): - @params(ex1, ex2, ex3, ex4, ex5, ex6, ex7, ex8, ex9) + @params(ex1, ex2, ex3, ex4, ex5, ex6, ex7, ex8, ex9, ex10, ex11) def test_eval(self, phi_str, r): phi = stl.parse(phi_str) stl_eval = stl.boolean_eval.pointwise_sat(phi)