diff --git a/boolean_eval.py b/boolean_eval.py new file mode 100644 index 0000000..fbbe51e --- /dev/null +++ b/boolean_eval.py @@ -0,0 +1,66 @@ +# TODO: figure out how to deduplicate this with robustness +# - Abstract as working on distributive lattice + +from functools import singledispatch +import operator as op + +from lenses import lens + +import stl.ast + +@singledispatch +def pointwise_sat(stl): + raise NotImplementedError + + +@pointwise_sat.register(stl.Or) +def _(stl): + return lambda x, t: any(pointwise_sat(arg)(x, t) for arg in stl.args) + + +@pointwise_sat.register(stl.And) +def _(stl): + return lambda x, t: all(pointwise_sat(arg)(x, t) for arg in stl.args) + + +@pointwise_sat.register(stl.F) +def _(stl): + lo, hi = stl.interval + return lambda x, t: any((pointwise_sat(stl.arg)(x, t + t2) + for t2 in x[lo:hi].index)) + + +@pointwise_sat.register(stl.G) +def _(stl): + lo, hi = stl.interval + return lambda x, t: all((pointwise_sat(stl.arg)(x, t + t2) + for t2 in x[lo:hi].index)) + + +@pointwise_sat.register(stl.Neg) +def _(stl): + return lambda x, t: not pointwise_sat(arg)(x, t) + + +op_lookup = { + ">": op.gt, + ">=": op.ge, + "<": op.lt, + "<=": op.le, + "=": op.eq, +} + + +@pointwise_sat.register(stl.LinEq) +def _(stl): + op = op_lookup[stl.op] + return lambda x, t: op(eval_terms(stl, x, t), stl.const) + + +def eval_terms(lineq, x, t): + psi = lens(lineq).terms.each_().modify(eval_term(x, t)) + return sum(psi.terms) + + +def eval_term(x, t): + return lambda term: term.coeff*x[term.id.name][t] diff --git a/test_boolean_eval.py b/test_boolean_eval.py new file mode 100644 index 0000000..de5cbbe --- /dev/null +++ b/test_boolean_eval.py @@ -0,0 +1,22 @@ +import stl +import stl.boolean_eval +import pandas as pd +from nose2.tools import params +import unittest +from sympy import Symbol + +ex1 = ("2*A > 3", False) +ex2 = ("F[0, 1](2*A > 3)", True) +ex3 = ("F[1, 0](2*A > 3)", False) +ex4 = ("G[1, 0](2*A > 3)", True) +ex5 = ("(A < 0)", False) +ex6 = ("G[0, 0.1](A < 0)", False) +x = pd.DataFrame([[1,2], [1,4], [4,2]], index=[0,0.1,0.2], + columns=["A", "B"]) + +class TestSTLRobustness(unittest.TestCase): + @params(ex1, ex2, ex3, ex4, ex5, ex6) + def test_stl(self, phi_str, r): + phi = stl.parse(phi_str) + stl_eval = stl.boolean_eval.pointwise_sat(phi) + self.assertEqual(stl_eval(x, 0), r)