Fast boolean with eval
This commit is contained in:
parent
d5f38e27ed
commit
8cae363e61
3 changed files with 15 additions and 14 deletions
|
|
@ -4,3 +4,4 @@ from stl.ast import LinEq, Interval, NaryOpSTL, Or, And, F, G, ModalOp, Neg, Var
|
||||||
from stl.parser import parse
|
from stl.parser import parse
|
||||||
from stl.synth import lex_param_project
|
from stl.synth import lex_param_project
|
||||||
from stl.boolean_eval import pointwise_sat
|
from stl.boolean_eval import pointwise_sat
|
||||||
|
from stl.fastboolean_eval import pointwise_satf
|
||||||
|
|
|
||||||
|
|
@ -13,59 +13,59 @@ from bitarray import bitarray
|
||||||
import stl.ast
|
import stl.ast
|
||||||
|
|
||||||
@singledispatch
|
@singledispatch
|
||||||
def pointwise_sat(stl):
|
def pointwise_satf(stl):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
@pointwise_sat.register(stl.Or)
|
@pointwise_satf.register(stl.Or)
|
||||||
def _(stl):
|
def _(stl):
|
||||||
def sat_comp(x,t):
|
def sat_comp(x,t):
|
||||||
sat = bitarray(len(t))
|
sat = bitarray(len(t))
|
||||||
for arg in stl.args:
|
for arg in stl.args:
|
||||||
sat = pointwise_sat(arg)(x, t) | sat
|
sat = pointwise_satf(arg)(x, t) | sat
|
||||||
return sat
|
return sat
|
||||||
return sat_comp
|
return sat_comp
|
||||||
|
|
||||||
|
|
||||||
@pointwise_sat.register(stl.And)
|
@pointwise_satf.register(stl.And)
|
||||||
def _(stl):
|
def _(stl):
|
||||||
def sat_comp(x,t):
|
def sat_comp(x,t):
|
||||||
sat = bitarray(len(t))
|
sat = bitarray(len(t))
|
||||||
sat.setall('True')
|
sat.setall('True')
|
||||||
for arg in stl.args:
|
for arg in stl.args:
|
||||||
sat = pointwise_sat(arg)(x, t) & sat
|
sat = pointwise_satf(arg)(x, t) & sat
|
||||||
return sat
|
return sat
|
||||||
return sat_comp
|
return sat_comp
|
||||||
|
|
||||||
|
|
||||||
@pointwise_sat.register(stl.F)
|
@pointwise_satf.register(stl.F)
|
||||||
def _(stl):
|
def _(stl):
|
||||||
lo, hi = stl.interval
|
lo, hi = stl.interval
|
||||||
def sat_comp(x,t):
|
def sat_comp(x,t):
|
||||||
sat = bitarray()
|
sat = bitarray()
|
||||||
for tau in t:
|
for tau in t:
|
||||||
tau_t = [min(tau + t2, x.index[-1]) for t2 in x[lo:hi].index]
|
tau_t = [min(tau + t2, x.index[-1]) for t2 in x[lo:hi].index]
|
||||||
sat.append((pointwise_sat(stl.arg)(x, tau_t)).count() > 0)
|
sat.append((pointwise_satf(stl.arg)(x, tau_t)).count() > 0)
|
||||||
return sat
|
return sat
|
||||||
return sat_comp
|
return sat_comp
|
||||||
|
|
||||||
|
|
||||||
@pointwise_sat.register(stl.G)
|
@pointwise_satf.register(stl.G)
|
||||||
def _(stl):
|
def _(stl):
|
||||||
lo, hi = stl.interval
|
lo, hi = stl.interval
|
||||||
def sat_comp(x,t):
|
def sat_comp(x,t):
|
||||||
sat = bitarray()
|
sat = bitarray()
|
||||||
for tau in t:
|
for tau in t:
|
||||||
tau_t = [min(tau + t2, x.index[-1]) for t2 in x[lo:hi].index]
|
tau_t = [min(tau + t2, x.index[-1]) for t2 in x[lo:hi].index]
|
||||||
point_sat = pointwise_sat(stl.arg)(x, tau_t)
|
point_sat = pointwise_satf(stl.arg)(x, tau_t)
|
||||||
sat.append(point_sat.count() == point_sat.length())
|
sat.append(point_sat.count() == point_sat.length())
|
||||||
return sat
|
return sat
|
||||||
return sat_comp
|
return sat_comp
|
||||||
|
|
||||||
|
|
||||||
@pointwise_sat.register(stl.Neg)
|
@pointwise_satf.register(stl.Neg)
|
||||||
def _(stl):
|
def _(stl):
|
||||||
return lambda x,t: ~pointwise_sat(arg)(x, t)
|
return lambda x,t: ~pointwise_satf(arg)(x, t)
|
||||||
|
|
||||||
|
|
||||||
op_lookup = {
|
op_lookup = {
|
||||||
|
|
@ -77,7 +77,7 @@ op_lookup = {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@pointwise_sat.register(stl.AtomicPred)
|
@pointwise_satf.register(stl.AtomicPred)
|
||||||
def _(stl):
|
def _(stl):
|
||||||
def sat_comp(x, t):
|
def sat_comp(x, t):
|
||||||
sat = bitarray()
|
sat = bitarray()
|
||||||
|
|
@ -86,7 +86,7 @@ def _(stl):
|
||||||
return sat_comp
|
return sat_comp
|
||||||
|
|
||||||
|
|
||||||
@pointwise_sat.register(stl.LinEq)
|
@pointwise_satf.register(stl.LinEq)
|
||||||
def _(stl):
|
def _(stl):
|
||||||
op = op_lookup[stl.op]
|
op = op_lookup[stl.op]
|
||||||
def sat_comp(x, t):
|
def sat_comp(x, t):
|
||||||
|
|
|
||||||
|
|
@ -21,5 +21,5 @@ class TestSTLRobustness(unittest.TestCase):
|
||||||
@params(ex1, ex2, ex3, ex4, ex5, ex6, ex7, ex8, ex9)
|
@params(ex1, ex2, ex3, ex4, ex5, ex6, ex7, ex8, ex9)
|
||||||
def test_stl(self, phi_str, r):
|
def test_stl(self, phi_str, r):
|
||||||
phi = stl.parse(phi_str)
|
phi = stl.parse(phi_str)
|
||||||
stl_eval = stl.fastboolean_eval.pointwise_sat(phi)
|
stl_eval = stl.fastboolean_eval.pointwise_satf(phi)
|
||||||
self.assertEqual(stl_eval(x, [0]), r)
|
self.assertEqual(stl_eval(x, [0]), r)
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue