Fast boolean with eval

This commit is contained in:
Shromona MacBook 2016-12-03 01:21:00 -08:00
parent d5f38e27ed
commit 8cae363e61
3 changed files with 15 additions and 14 deletions

View file

@ -4,3 +4,4 @@ from stl.ast import LinEq, Interval, NaryOpSTL, Or, And, F, G, ModalOp, Neg, Var
from stl.parser import parse from stl.parser import parse
from stl.synth import lex_param_project from stl.synth import lex_param_project
from stl.boolean_eval import pointwise_sat from stl.boolean_eval import pointwise_sat
from stl.fastboolean_eval import pointwise_satf

View file

@ -13,59 +13,59 @@ from bitarray import bitarray
import stl.ast import stl.ast
@singledispatch @singledispatch
def pointwise_sat(stl): def pointwise_satf(stl):
raise NotImplementedError raise NotImplementedError
@pointwise_sat.register(stl.Or) @pointwise_satf.register(stl.Or)
def _(stl): def _(stl):
def sat_comp(x,t): def sat_comp(x,t):
sat = bitarray(len(t)) sat = bitarray(len(t))
for arg in stl.args: for arg in stl.args:
sat = pointwise_sat(arg)(x, t) | sat sat = pointwise_satf(arg)(x, t) | sat
return sat return sat
return sat_comp return sat_comp
@pointwise_sat.register(stl.And) @pointwise_satf.register(stl.And)
def _(stl): def _(stl):
def sat_comp(x,t): def sat_comp(x,t):
sat = bitarray(len(t)) sat = bitarray(len(t))
sat.setall('True') sat.setall('True')
for arg in stl.args: for arg in stl.args:
sat = pointwise_sat(arg)(x, t) & sat sat = pointwise_satf(arg)(x, t) & sat
return sat return sat
return sat_comp return sat_comp
@pointwise_sat.register(stl.F) @pointwise_satf.register(stl.F)
def _(stl): def _(stl):
lo, hi = stl.interval lo, hi = stl.interval
def sat_comp(x,t): def sat_comp(x,t):
sat = bitarray() sat = bitarray()
for tau in t: for tau in t:
tau_t = [min(tau + t2, x.index[-1]) for t2 in x[lo:hi].index] tau_t = [min(tau + t2, x.index[-1]) for t2 in x[lo:hi].index]
sat.append((pointwise_sat(stl.arg)(x, tau_t)).count() > 0) sat.append((pointwise_satf(stl.arg)(x, tau_t)).count() > 0)
return sat return sat
return sat_comp return sat_comp
@pointwise_sat.register(stl.G) @pointwise_satf.register(stl.G)
def _(stl): def _(stl):
lo, hi = stl.interval lo, hi = stl.interval
def sat_comp(x,t): def sat_comp(x,t):
sat = bitarray() sat = bitarray()
for tau in t: for tau in t:
tau_t = [min(tau + t2, x.index[-1]) for t2 in x[lo:hi].index] tau_t = [min(tau + t2, x.index[-1]) for t2 in x[lo:hi].index]
point_sat = pointwise_sat(stl.arg)(x, tau_t) point_sat = pointwise_satf(stl.arg)(x, tau_t)
sat.append(point_sat.count() == point_sat.length()) sat.append(point_sat.count() == point_sat.length())
return sat return sat
return sat_comp return sat_comp
@pointwise_sat.register(stl.Neg) @pointwise_satf.register(stl.Neg)
def _(stl): def _(stl):
return lambda x,t: ~pointwise_sat(arg)(x, t) return lambda x,t: ~pointwise_satf(arg)(x, t)
op_lookup = { op_lookup = {
@ -77,7 +77,7 @@ op_lookup = {
} }
@pointwise_sat.register(stl.AtomicPred) @pointwise_satf.register(stl.AtomicPred)
def _(stl): def _(stl):
def sat_comp(x, t): def sat_comp(x, t):
sat = bitarray() sat = bitarray()
@ -86,7 +86,7 @@ def _(stl):
return sat_comp return sat_comp
@pointwise_sat.register(stl.LinEq) @pointwise_satf.register(stl.LinEq)
def _(stl): def _(stl):
op = op_lookup[stl.op] op = op_lookup[stl.op]
def sat_comp(x, t): def sat_comp(x, t):

View file

@ -21,5 +21,5 @@ class TestSTLRobustness(unittest.TestCase):
@params(ex1, ex2, ex3, ex4, ex5, ex6, ex7, ex8, ex9) @params(ex1, ex2, ex3, ex4, ex5, ex6, ex7, ex8, ex9)
def test_stl(self, phi_str, r): def test_stl(self, phi_str, r):
phi = stl.parse(phi_str) phi = stl.parse(phi_str)
stl_eval = stl.fastboolean_eval.pointwise_sat(phi) stl_eval = stl.fastboolean_eval.pointwise_satf(phi)
self.assertEqual(stl_eval(x, [0]), r) self.assertEqual(stl_eval(x, [0]), r)