Fast boolean eval

This commit is contained in:
Shromona MacBook 2016-12-03 01:14:55 -08:00
parent fb80b0c3ce
commit d5f38e27ed
2 changed files with 47 additions and 24 deletions

View file

@ -8,6 +8,7 @@ import numpy as np
import sympy as smp import sympy as smp
from lenses import lens from lenses import lens
import gmpy2 as gp import gmpy2 as gp
from bitarray import bitarray
import stl.ast import stl.ast
@ -19,20 +20,21 @@ def pointwise_sat(stl):
@pointwise_sat.register(stl.Or) @pointwise_sat.register(stl.Or)
def _(stl): def _(stl):
def sat_comp(x,t): def sat_comp(x,t):
val = 0 sat = bitarray(len(t))
for arg in stl.args: for arg in stl.args:
val = pointwise_sat(arg)(x, t) | val sat = pointwise_sat(arg)(x, t) | sat
return val return sat
return sat_comp return sat_comp
@pointwise_sat.register(stl.And) @pointwise_sat.register(stl.And)
def _(stl): def _(stl):
def sat_comp(x,t): def sat_comp(x,t):
val = 2**(len(t))-1 sat = bitarray(len(t))
sat.setall('True')
for arg in stl.args: for arg in stl.args:
val = pointwise_sat(arg)(x, t) & val sat = pointwise_sat(arg)(x, t) & sat
return val return sat
return sat_comp return sat_comp
@ -40,11 +42,11 @@ def _(stl):
def _(stl): def _(stl):
lo, hi = stl.interval lo, hi = stl.interval
def sat_comp(x,t): def sat_comp(x,t):
val = 0 sat = bitarray()
for tau in t: for tau in t:
tau_t = [min(tau + t2, x.index[-1]) for t2 in x[lo:hi].index] tau_t = [min(tau + t2, x.index[-1]) for t2 in x[lo:hi].index]
val = (val << 1) | (pointwise_sat(stl.arg)(x, tau_t) > 0) sat.append((pointwise_sat(stl.arg)(x, tau_t)).count() > 0)
return val return sat
return sat_comp return sat_comp
@ -52,20 +54,18 @@ def _(stl):
def _(stl): def _(stl):
lo, hi = stl.interval lo, hi = stl.interval
def sat_comp(x,t): def sat_comp(x,t):
val = 0 sat = bitarray()
for tau in t: for tau in t:
tau_t = [min(tau + t2, x.index[-1]) for t2 in x[lo:hi].index] tau_t = [min(tau + t2, x.index[-1]) for t2 in x[lo:hi].index]
val = (val << 1) | (gp.popcount(pointwise_sat(stl.arg)(x, tau_t)) == len(tau_t)) point_sat = pointwise_sat(stl.arg)(x, tau_t)
return val sat.append(point_sat.count() == point_sat.length())
return sat
return sat_comp return sat_comp
@pointwise_sat.register(stl.Neg) @pointwise_sat.register(stl.Neg)
def _(stl): def _(stl):
def sat_comp(x,t): return lambda x,t: ~pointwise_sat(arg)(x, t)
val = pointwise_sat(arg)(x, t) ^ (2**(len(t))-1)
return val
return sat_comp
op_lookup = { op_lookup = {
@ -80,10 +80,9 @@ op_lookup = {
@pointwise_sat.register(stl.AtomicPred) @pointwise_sat.register(stl.AtomicPred)
def _(stl): def _(stl):
def sat_comp(x, t): def sat_comp(x, t):
val = 0 sat = bitarray()
for tau in t: [sat.append(x[stl.id][tau]) for tau in t]
val = (val << 1) | (1 if x[stl.id][tau] else 0) return sat
return val
return sat_comp return sat_comp
@ -91,10 +90,9 @@ def _(stl):
def _(stl): def _(stl):
op = op_lookup[stl.op] op = op_lookup[stl.op]
def sat_comp(x, t): def sat_comp(x, t):
val = 0 sat = bitarray()
for tau in t: [sat.append(op(eval_terms(stl, x, tau), stl.const)) for tau in t]
val = (val << 1) | (op(eval_terms(stl, x, tau), stl.const) == True) return sat
return val
return sat_comp return sat_comp

View file

@ -0,0 +1,25 @@
import stl
import stl.fastboolean_eval
import pandas as pd
from nose2.tools import params
import unittest
from sympy import Symbol
ex1 = ("2*A > 3", False)
ex2 = ("F[0, 1](2*A > 3)", True)
ex3 = ("F[1, 0](2*A > 3)", False)
ex4 = ("G[1, 0](2*A > 3)", True)
ex5 = ("(A < 0)", False)
ex6 = ("G[0, 0.1](A < 0)", False)
ex7 = ("G[0, 0.1](C)", True)
ex8 = ("G[0, 0.2](C)", False)
ex9 = ("(F[0, 0.2](C)) and (F[0, 1](2*A > 3))", True)
x = pd.DataFrame([[1,2, True], [1,4, True], [4,2, False]], index=[0,0.1,0.2],
columns=["A", "B", "C"])
class TestSTLRobustness(unittest.TestCase):
@params(ex1, ex2, ex3, ex4, ex5, ex6, ex7, ex8, ex9)
def test_stl(self, phi_str, r):
phi = stl.parse(phi_str)
stl_eval = stl.fastboolean_eval.pointwise_sat(phi)
self.assertEqual(stl_eval(x, [0]), r)