Fast boolean with eval

This commit is contained in:
Shromona MacBook 2016-12-03 01:21:00 -08:00
parent d5f38e27ed
commit 8cae363e61
3 changed files with 15 additions and 14 deletions

View file

@ -4,3 +4,4 @@ from stl.ast import LinEq, Interval, NaryOpSTL, Or, And, F, G, ModalOp, Neg, Var
from stl.parser import parse
from stl.synth import lex_param_project
from stl.boolean_eval import pointwise_sat
from stl.fastboolean_eval import pointwise_satf

View file

@ -13,59 +13,59 @@ from bitarray import bitarray
import stl.ast
@singledispatch
def pointwise_sat(stl):
def pointwise_satf(stl):
raise NotImplementedError
@pointwise_sat.register(stl.Or)
@pointwise_satf.register(stl.Or)
def _(stl):
def sat_comp(x,t):
sat = bitarray(len(t))
for arg in stl.args:
sat = pointwise_sat(arg)(x, t) | sat
sat = pointwise_satf(arg)(x, t) | sat
return sat
return sat_comp
@pointwise_sat.register(stl.And)
@pointwise_satf.register(stl.And)
def _(stl):
def sat_comp(x,t):
sat = bitarray(len(t))
sat.setall('True')
for arg in stl.args:
sat = pointwise_sat(arg)(x, t) & sat
sat = pointwise_satf(arg)(x, t) & sat
return sat
return sat_comp
@pointwise_sat.register(stl.F)
@pointwise_satf.register(stl.F)
def _(stl):
lo, hi = stl.interval
def sat_comp(x,t):
sat = bitarray()
for tau in t:
tau_t = [min(tau + t2, x.index[-1]) for t2 in x[lo:hi].index]
sat.append((pointwise_sat(stl.arg)(x, tau_t)).count() > 0)
sat.append((pointwise_satf(stl.arg)(x, tau_t)).count() > 0)
return sat
return sat_comp
@pointwise_sat.register(stl.G)
@pointwise_satf.register(stl.G)
def _(stl):
lo, hi = stl.interval
def sat_comp(x,t):
sat = bitarray()
for tau in t:
tau_t = [min(tau + t2, x.index[-1]) for t2 in x[lo:hi].index]
point_sat = pointwise_sat(stl.arg)(x, tau_t)
point_sat = pointwise_satf(stl.arg)(x, tau_t)
sat.append(point_sat.count() == point_sat.length())
return sat
return sat_comp
@pointwise_sat.register(stl.Neg)
@pointwise_satf.register(stl.Neg)
def _(stl):
return lambda x,t: ~pointwise_sat(arg)(x, t)
return lambda x,t: ~pointwise_satf(arg)(x, t)
op_lookup = {
@ -77,7 +77,7 @@ op_lookup = {
}
@pointwise_sat.register(stl.AtomicPred)
@pointwise_satf.register(stl.AtomicPred)
def _(stl):
def sat_comp(x, t):
sat = bitarray()
@ -86,7 +86,7 @@ def _(stl):
return sat_comp
@pointwise_sat.register(stl.LinEq)
@pointwise_satf.register(stl.LinEq)
def _(stl):
op = op_lookup[stl.op]
def sat_comp(x, t):

View file

@ -21,5 +21,5 @@ class TestSTLRobustness(unittest.TestCase):
@params(ex1, ex2, ex3, ex4, ex5, ex6, ex7, ex8, ex9)
def test_stl(self, phi_str, r):
phi = stl.parse(phi_str)
stl_eval = stl.fastboolean_eval.pointwise_sat(phi)
stl_eval = stl.fastboolean_eval.pointwise_satf(phi)
self.assertEqual(stl_eval(x, [0]), r)