default to fastboolean eval

This commit is contained in:
Marcell Vazquez-Chanlatte 2017-02-28 13:32:54 -08:00
parent a34e4c2b0e
commit a8e84a3761
4 changed files with 17 additions and 6 deletions

View file

@ -4,6 +4,5 @@ from stl.ast import dt_sym, t_sym, TOP, BOT
from stl.ast import LinEq, Interval, NaryOpSTL, Or, And, F, G, ModalOp, Neg, Var, AtomicPred from stl.ast import LinEq, Interval, NaryOpSTL, Or, And, F, G, ModalOp, Neg, Var, AtomicPred
from stl.parser import parse from stl.parser import parse
from stl.synth import lex_param_project from stl.synth import lex_param_project
from stl.boolean_eval import pointwise_sat from stl.fastboolean_eval import pointwise_sat
from stl.fastboolean_eval import pointwise_satf
from stl.types import STL from stl.types import STL

View file

@ -40,7 +40,7 @@ def _(stl):
@pointwise_sat.register(stl.Neg) @pointwise_sat.register(stl.Neg)
def _(stl): def _(stl):
return lambda x, t: not pointwise_sat(arg)(x, t) return lambda x, t: not pointwise_sat(stl.arg)(x, t)
op_lookup = { op_lookup = {

View file

@ -6,6 +6,10 @@ from bitarray import bitarray
import stl.ast import stl.ast
from stl.boolean_eval import eval_terms, op_lookup from stl.boolean_eval import eval_terms, op_lookup
def pointwise_sat(stl):
f = pointwise_satf(stl)
return lambda x, t: bool(int(f(x, [t]).to01()))
@singledispatch @singledispatch
def pointwise_satf(stl): def pointwise_satf(stl):
raise NotImplementedError raise NotImplementedError
@ -57,7 +61,7 @@ def _(stl):
@pointwise_satf.register(stl.Neg) @pointwise_satf.register(stl.Neg)
def _(stl): def _(stl):
return lambda x,t: ~pointwise_satf(arg)(x, t) return lambda x,t: ~pointwise_satf(stl.arg)(x, t)
@pointwise_satf.register(stl.AtomicPred) @pointwise_satf.register(stl.AtomicPred)

View file

@ -23,12 +23,20 @@ class TestSTLEval(unittest.TestCase):
def test_eval(self, phi_str, r): def test_eval(self, phi_str, r):
phi = stl.parse(phi_str) phi = stl.parse(phi_str)
stl_eval = stl.boolean_eval.pointwise_sat(phi) stl_eval = stl.boolean_eval.pointwise_sat(phi)
stl_eval2 = stl.boolean_eval.pointwise_sat(~phi)
self.assertEqual(stl_eval(x, 0), r) self.assertEqual(stl_eval(x, 0), r)
self.assertEqual(stl_eval2(x, 0), not r)
@params(ex1, ex2, ex3, ex4, ex5, ex6, ex7, ex8, ex9) @params(ex1, ex2, ex3, ex4, ex5, ex6, ex7, ex8, ex9)
def test_fasteval(self, phi_str, _): def test_fasteval(self, phi_str, _):
phi = stl.parse(phi_str) phi = stl.parse(phi_str)
stl_eval = stl.boolean_eval.pointwise_sat(phi) stl_eval = stl.boolean_eval.pointwise_sat(phi)
stl_evalf = stl.fastboolean_eval.pointwise_satf(phi) stl_evalf = stl.fastboolean_eval.pointwise_sat(phi)
self.assertEqual(bool(int(stl_evalf(x, [0]).to01())), stl_eval(x, 0)) stl_evalf2 = stl.fastboolean_eval.pointwise_sat(~phi)
b_slow = stl_eval(x, 0)
b_fast = stl_evalf(x, 0)
b_fast2 = stl_evalf2(x, 0)
self.assertEqual(b_slow, b_fast)
self.assertEqual(b_fast, not b_fast2)