diff --git a/stl/__init__.py b/stl/__init__.py index b4b073b..a3574eb 100644 --- a/stl/__init__.py +++ b/stl/__init__.py @@ -4,3 +4,4 @@ from stl.ast import LinEq, Interval, NaryOpSTL, Or, And, F, G, ModalOp, Neg, Var from stl.parser import parse from stl.synth import lex_param_project from stl.boolean_eval import pointwise_sat +from stl.fastboolean_eval import pointwise_satf diff --git a/stl/fastboolean_eval.py b/stl/fastboolean_eval.py new file mode 100644 index 0000000..3fb40c2 --- /dev/null +++ b/stl/fastboolean_eval.py @@ -0,0 +1,106 @@ +# TODO: figure out how to deduplicate this with robustness +# - Abstract as working on distributive lattice + +from functools import singledispatch +import operator as op + +import numpy as np +import sympy as smp +from lenses import lens +import gmpy2 as gp +from bitarray import bitarray + +import stl.ast + +@singledispatch +def pointwise_satf(stl): + raise NotImplementedError + + +@pointwise_satf.register(stl.Or) +def _(stl): + def sat_comp(x,t): + sat = bitarray(len(t)) + for arg in stl.args: + sat = pointwise_satf(arg)(x, t) | sat + return sat + return sat_comp + + +@pointwise_satf.register(stl.And) +def _(stl): + def sat_comp(x,t): + sat = bitarray(len(t)) + sat.setall('True') + for arg in stl.args: + sat = pointwise_satf(arg)(x, t) & sat + return sat + return sat_comp + + +@pointwise_satf.register(stl.F) +def _(stl): + lo, hi = stl.interval + def sat_comp(x,t): + sat = bitarray() + for tau in t: + tau_t = [min(tau + t2, x.index[-1]) for t2 in x[lo:hi].index] + sat.append((pointwise_satf(stl.arg)(x, tau_t)).count() > 0) + return sat + return sat_comp + + +@pointwise_satf.register(stl.G) +def _(stl): + lo, hi = stl.interval + def sat_comp(x,t): + sat = bitarray() + for tau in t: + tau_t = [min(tau + t2, x.index[-1]) for t2 in x[lo:hi].index] + point_sat = pointwise_satf(stl.arg)(x, tau_t) + sat.append(point_sat.count() == point_sat.length()) + return sat + return sat_comp + + +@pointwise_satf.register(stl.Neg) +def _(stl): + return lambda x,t: ~pointwise_satf(arg)(x, t) + + +op_lookup = { + ">": op.gt, + ">=": op.ge, + "<": op.lt, + "<=": op.le, + "=": op.eq, +} + + +@pointwise_satf.register(stl.AtomicPred) +def _(stl): + def sat_comp(x, t): + sat = bitarray() + [sat.append(x[stl.id][tau]) for tau in t] + return sat + return sat_comp + + +@pointwise_satf.register(stl.LinEq) +def _(stl): + op = op_lookup[stl.op] + def sat_comp(x, t): + sat = bitarray() + [sat.append(op(eval_terms(stl, x, tau), stl.const)) for tau in t] + return sat + return sat_comp + + +def eval_terms(lineq, x, t): + psi = lens(lineq).terms.each_().modify(eval_term(x, t)) + return sum(psi.terms) + + +def eval_term(x, t): + # TODO(lift interpolation much higher) + return lambda term: term.coeff*np.interp(t, x.index, x[term.id.name]) diff --git a/stl/test_fastboolean_eval.py b/stl/test_fastboolean_eval.py new file mode 100644 index 0000000..2d0467f --- /dev/null +++ b/stl/test_fastboolean_eval.py @@ -0,0 +1,25 @@ +import stl +import stl.fastboolean_eval +import pandas as pd +from nose2.tools import params +import unittest +from sympy import Symbol + +ex1 = ("2*A > 3", False) +ex2 = ("F[0, 1](2*A > 3)", True) +ex3 = ("F[1, 0](2*A > 3)", False) +ex4 = ("G[1, 0](2*A > 3)", True) +ex5 = ("(A < 0)", False) +ex6 = ("G[0, 0.1](A < 0)", False) +ex7 = ("G[0, 0.1](C)", True) +ex8 = ("G[0, 0.2](C)", False) +ex9 = ("(F[0, 0.2](C)) and (F[0, 1](2*A > 3))", True) +x = pd.DataFrame([[1,2, True], [1,4, True], [4,2, False]], index=[0,0.1,0.2], + columns=["A", "B", "C"]) + +class TestSTLRobustness(unittest.TestCase): + @params(ex1, ex2, ex3, ex4, ex5, ex6, ex7, ex8, ex9) + def test_stl(self, phi_str, r): + phi = stl.parse(phi_str) + stl_eval = stl.fastboolean_eval.pointwise_satf(phi) + self.assertEqual(stl_eval(x, [0]), r) diff --git a/test_boolean.py b/test_boolean.py new file mode 100644 index 0000000..1106060 --- /dev/null +++ b/test_boolean.py @@ -0,0 +1,26 @@ +import stl +import stl.fastboolean_eval +import pandas as pd +from nose2.tools import params +import unittest +from sympy import Symbol + +ex1 = ("2*A > 3", False) +ex2 = ("F[0, 1](2*A > 3)", True) +ex3 = ("F[1, 0](2*A > 3)", False) +ex4 = ("G[1, 0](2*A > 3)", True) +ex5 = ("(A < 0)", False) +ex6 = ("G[0, 0.1](A < 0)", False) +ex7 = ("G[0, 0.1](C)", True) +ex8 = ("G[0, 0.2](C)", False) +ex9 = ("(F[0, 0.2](C)) and (F[0, 1](2*A > 3))", True) +x = pd.DataFrame([[1,2, True], [1,4, True], [4,2, False]], index=[0,0.1,0.2], + columns=["A", "B", "C"]) + +tests = [ex1, ex2, ex3, ex4, ex5, ex6, ex7, ex8, ex9] +for test in tests: + phi = stl.parse(test[0]) + print(phi) + stl_eval = stl.fastboolean_eval.pointwise_sat(phi) + print(stl_eval(x, [0])) + \ No newline at end of file