Fast boolean with eval
This commit is contained in:
parent
d5f38e27ed
commit
8cae363e61
3 changed files with 15 additions and 14 deletions
|
|
@ -4,3 +4,4 @@ from stl.ast import LinEq, Interval, NaryOpSTL, Or, And, F, G, ModalOp, Neg, Var
|
|||
from stl.parser import parse
|
||||
from stl.synth import lex_param_project
|
||||
from stl.boolean_eval import pointwise_sat
|
||||
from stl.fastboolean_eval import pointwise_satf
|
||||
|
|
|
|||
|
|
@ -13,59 +13,59 @@ from bitarray import bitarray
|
|||
import stl.ast
|
||||
|
||||
@singledispatch
|
||||
def pointwise_sat(stl):
|
||||
def pointwise_satf(stl):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@pointwise_sat.register(stl.Or)
|
||||
@pointwise_satf.register(stl.Or)
|
||||
def _(stl):
|
||||
def sat_comp(x,t):
|
||||
sat = bitarray(len(t))
|
||||
for arg in stl.args:
|
||||
sat = pointwise_sat(arg)(x, t) | sat
|
||||
sat = pointwise_satf(arg)(x, t) | sat
|
||||
return sat
|
||||
return sat_comp
|
||||
|
||||
|
||||
@pointwise_sat.register(stl.And)
|
||||
@pointwise_satf.register(stl.And)
|
||||
def _(stl):
|
||||
def sat_comp(x,t):
|
||||
sat = bitarray(len(t))
|
||||
sat.setall('True')
|
||||
for arg in stl.args:
|
||||
sat = pointwise_sat(arg)(x, t) & sat
|
||||
sat = pointwise_satf(arg)(x, t) & sat
|
||||
return sat
|
||||
return sat_comp
|
||||
|
||||
|
||||
@pointwise_sat.register(stl.F)
|
||||
@pointwise_satf.register(stl.F)
|
||||
def _(stl):
|
||||
lo, hi = stl.interval
|
||||
def sat_comp(x,t):
|
||||
sat = bitarray()
|
||||
for tau in t:
|
||||
tau_t = [min(tau + t2, x.index[-1]) for t2 in x[lo:hi].index]
|
||||
sat.append((pointwise_sat(stl.arg)(x, tau_t)).count() > 0)
|
||||
sat.append((pointwise_satf(stl.arg)(x, tau_t)).count() > 0)
|
||||
return sat
|
||||
return sat_comp
|
||||
|
||||
|
||||
@pointwise_sat.register(stl.G)
|
||||
@pointwise_satf.register(stl.G)
|
||||
def _(stl):
|
||||
lo, hi = stl.interval
|
||||
def sat_comp(x,t):
|
||||
sat = bitarray()
|
||||
for tau in t:
|
||||
tau_t = [min(tau + t2, x.index[-1]) for t2 in x[lo:hi].index]
|
||||
point_sat = pointwise_sat(stl.arg)(x, tau_t)
|
||||
point_sat = pointwise_satf(stl.arg)(x, tau_t)
|
||||
sat.append(point_sat.count() == point_sat.length())
|
||||
return sat
|
||||
return sat_comp
|
||||
|
||||
|
||||
@pointwise_sat.register(stl.Neg)
|
||||
@pointwise_satf.register(stl.Neg)
|
||||
def _(stl):
|
||||
return lambda x,t: ~pointwise_sat(arg)(x, t)
|
||||
return lambda x,t: ~pointwise_satf(arg)(x, t)
|
||||
|
||||
|
||||
op_lookup = {
|
||||
|
|
@ -77,7 +77,7 @@ op_lookup = {
|
|||
}
|
||||
|
||||
|
||||
@pointwise_sat.register(stl.AtomicPred)
|
||||
@pointwise_satf.register(stl.AtomicPred)
|
||||
def _(stl):
|
||||
def sat_comp(x, t):
|
||||
sat = bitarray()
|
||||
|
|
@ -86,7 +86,7 @@ def _(stl):
|
|||
return sat_comp
|
||||
|
||||
|
||||
@pointwise_sat.register(stl.LinEq)
|
||||
@pointwise_satf.register(stl.LinEq)
|
||||
def _(stl):
|
||||
op = op_lookup[stl.op]
|
||||
def sat_comp(x, t):
|
||||
|
|
|
|||
|
|
@ -21,5 +21,5 @@ class TestSTLRobustness(unittest.TestCase):
|
|||
@params(ex1, ex2, ex3, ex4, ex5, ex6, ex7, ex8, ex9)
|
||||
def test_stl(self, phi_str, r):
|
||||
phi = stl.parse(phi_str)
|
||||
stl_eval = stl.fastboolean_eval.pointwise_sat(phi)
|
||||
stl_eval = stl.fastboolean_eval.pointwise_satf(phi)
|
||||
self.assertEqual(stl_eval(x, [0]), r)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue