From 92ff8c1526cf1ef3944da79c568f4317beb7a294 Mon Sep 17 00:00:00 2001 From: Shromona MacBook Date: Thu, 1 Dec 2016 18:29:50 -0800 Subject: [PATCH] Fast boolean evaluation --- stl/fastboolean_eval.py | 117 ++++++++++++++++++++++++++++++++++++++++ test_boolean.py | 26 +++++++++ 2 files changed, 143 insertions(+) create mode 100644 stl/fastboolean_eval.py create mode 100644 test_boolean.py diff --git a/stl/fastboolean_eval.py b/stl/fastboolean_eval.py new file mode 100644 index 0000000..a22b5ba --- /dev/null +++ b/stl/fastboolean_eval.py @@ -0,0 +1,117 @@ +# TODO: figure out how to deduplicate this with robustness +# - Abstract as working on distributive lattice + +from functools import singledispatch +import operator as op + +import numpy as np +import sympy as smp +from lenses import lens +import gmpy2 as gp + +import stl.ast + +@singledispatch +def pointwise_sat(stl): + raise NotImplementedError + + +@pointwise_sat.register(stl.Or) +def _(stl): + def sat_comp(x,t): + val = 0 + for arg in stl.args: + val = pointwise_sat(arg)(x, t) | val + return val + return sat_comp + #return lambda x, t: any(pointwise_sat(arg)(x, t) for arg in stl.args) + + +@pointwise_sat.register(stl.And) +def _(stl): + def sat_comp(x,t): + val = 2**(len(t))-1 + for arg in stl.args: + val = pointwise_sat(arg)(x, t) & val + return val + return sat_comp + #return lambda x, t: all(pointwise_sat(arg)(x, t) for arg in stl.args) + + +@pointwise_sat.register(stl.F) +def _(stl): + lo, hi = stl.interval + def sat_comp(x,t): + val = 0 + for tau in t: + tau_t = [min(tau + t2, x.index[-1]) for t2 in x[lo:hi].index] + val = (val << 1) | (pointwise_sat(stl.arg)(x, tau_t) > 0) + return val + return sat_comp + #return lambda x, t, val: [pointwise_sat(stl.arg)(x, [min(deltat + t2, x.index[-1]) + # for t2 in x[lo:hi].index], 0) for deltat in t] + + +@pointwise_sat.register(stl.G) +def _(stl): + lo, hi = stl.interval + def sat_comp(x,t): + val = 0 + for tau in t: + tau_t = [min(tau + t2, x.index[-1]) for t2 in x[lo:hi].index] + val = (val << 1) | (gp.popcount(pointwise_sat(stl.arg)(x, tau_t)) == len(tau_t)) + return val + return sat_comp + #return lambda x, t: all((pointwise_sat(stl.arg)(x, min(t + t2, x.index[-1])) + # for t2 in x[lo:hi].index)) + + +@pointwise_sat.register(stl.Neg) +def _(stl): + def sat_comp(x,t): + val = pointwise_sat(arg)(x, t) ^ (2**(len(t))-1) + return val + return sat_comp + #return lambda x, t: pointwise_sat(arg)(x, t, val) + + +op_lookup = { + ">": op.gt, + ">=": op.ge, + "<": op.lt, + "<=": op.le, + "=": op.eq, +} + + +@pointwise_sat.register(stl.AtomicPred) +def _(stl): + def sat_comp(x, t): + val = 0 + for tau in t: + val = (val << 1) | (1 if x[stl.id][tau] else 0) + return val + return sat_comp + #return lambda x, t, val: [(val << 1) | (x[stl.id][deltat] == True) for deltat in t] + + +@pointwise_sat.register(stl.LinEq) +def _(stl): + op = op_lookup[stl.op] + def sat_comp(x, t): + val = 0 + for tau in t: + val = (val << 1) | (op(eval_terms(stl, x, tau), stl.const) == True) + return val + return sat_comp + #return lambda x, t, val: [(val << 1) |(op(eval_terms(stl, x, deltat), stl.const) == True) for deltat in t] + + +def eval_terms(lineq, x, t): + psi = lens(lineq).terms.each_().modify(eval_term(x, t)) + return sum(psi.terms) + + +def eval_term(x, t): + # TODO(lift interpolation much higher) + return lambda term: term.coeff*np.interp(t, x.index, x[term.id.name]) diff --git a/test_boolean.py b/test_boolean.py new file mode 100644 index 0000000..1106060 --- /dev/null +++ b/test_boolean.py @@ -0,0 +1,26 @@ +import stl +import stl.fastboolean_eval +import pandas as pd +from nose2.tools import params +import unittest +from sympy import Symbol + +ex1 = ("2*A > 3", False) +ex2 = ("F[0, 1](2*A > 3)", True) +ex3 = ("F[1, 0](2*A > 3)", False) +ex4 = ("G[1, 0](2*A > 3)", True) +ex5 = ("(A < 0)", False) +ex6 = ("G[0, 0.1](A < 0)", False) +ex7 = ("G[0, 0.1](C)", True) +ex8 = ("G[0, 0.2](C)", False) +ex9 = ("(F[0, 0.2](C)) and (F[0, 1](2*A > 3))", True) +x = pd.DataFrame([[1,2, True], [1,4, True], [4,2, False]], index=[0,0.1,0.2], + columns=["A", "B", "C"]) + +tests = [ex1, ex2, ex3, ex4, ex5, ex6, ex7, ex8, ex9] +for test in tests: + phi = stl.parse(test[0]) + print(phi) + stl_eval = stl.fastboolean_eval.pointwise_sat(phi) + print(stl_eval(x, [0])) + \ No newline at end of file