Merge pull request #1 from shromonag/master
Looks good. Going to merge and update setup.py to have bitarray as a dependency
This commit is contained in:
commit
c41231e964
4 changed files with 158 additions and 0 deletions
|
|
@ -4,3 +4,4 @@ from stl.ast import LinEq, Interval, NaryOpSTL, Or, And, F, G, ModalOp, Neg, Var
|
|||
from stl.parser import parse
|
||||
from stl.synth import lex_param_project
|
||||
from stl.boolean_eval import pointwise_sat
|
||||
from stl.fastboolean_eval import pointwise_satf
|
||||
|
|
|
|||
106
stl/fastboolean_eval.py
Normal file
106
stl/fastboolean_eval.py
Normal file
|
|
@ -0,0 +1,106 @@
|
|||
# TODO: figure out how to deduplicate this with robustness
|
||||
# - Abstract as working on distributive lattice
|
||||
|
||||
from functools import singledispatch
|
||||
import operator as op
|
||||
|
||||
import numpy as np
|
||||
import sympy as smp
|
||||
from lenses import lens
|
||||
import gmpy2 as gp
|
||||
from bitarray import bitarray
|
||||
|
||||
import stl.ast
|
||||
|
||||
@singledispatch
|
||||
def pointwise_satf(stl):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@pointwise_satf.register(stl.Or)
|
||||
def _(stl):
|
||||
def sat_comp(x,t):
|
||||
sat = bitarray(len(t))
|
||||
for arg in stl.args:
|
||||
sat = pointwise_satf(arg)(x, t) | sat
|
||||
return sat
|
||||
return sat_comp
|
||||
|
||||
|
||||
@pointwise_satf.register(stl.And)
|
||||
def _(stl):
|
||||
def sat_comp(x,t):
|
||||
sat = bitarray(len(t))
|
||||
sat.setall('True')
|
||||
for arg in stl.args:
|
||||
sat = pointwise_satf(arg)(x, t) & sat
|
||||
return sat
|
||||
return sat_comp
|
||||
|
||||
|
||||
@pointwise_satf.register(stl.F)
|
||||
def _(stl):
|
||||
lo, hi = stl.interval
|
||||
def sat_comp(x,t):
|
||||
sat = bitarray()
|
||||
for tau in t:
|
||||
tau_t = [min(tau + t2, x.index[-1]) for t2 in x[lo:hi].index]
|
||||
sat.append((pointwise_satf(stl.arg)(x, tau_t)).count() > 0)
|
||||
return sat
|
||||
return sat_comp
|
||||
|
||||
|
||||
@pointwise_satf.register(stl.G)
|
||||
def _(stl):
|
||||
lo, hi = stl.interval
|
||||
def sat_comp(x,t):
|
||||
sat = bitarray()
|
||||
for tau in t:
|
||||
tau_t = [min(tau + t2, x.index[-1]) for t2 in x[lo:hi].index]
|
||||
point_sat = pointwise_satf(stl.arg)(x, tau_t)
|
||||
sat.append(point_sat.count() == point_sat.length())
|
||||
return sat
|
||||
return sat_comp
|
||||
|
||||
|
||||
@pointwise_satf.register(stl.Neg)
|
||||
def _(stl):
|
||||
return lambda x,t: ~pointwise_satf(arg)(x, t)
|
||||
|
||||
|
||||
op_lookup = {
|
||||
">": op.gt,
|
||||
">=": op.ge,
|
||||
"<": op.lt,
|
||||
"<=": op.le,
|
||||
"=": op.eq,
|
||||
}
|
||||
|
||||
|
||||
@pointwise_satf.register(stl.AtomicPred)
|
||||
def _(stl):
|
||||
def sat_comp(x, t):
|
||||
sat = bitarray()
|
||||
[sat.append(x[stl.id][tau]) for tau in t]
|
||||
return sat
|
||||
return sat_comp
|
||||
|
||||
|
||||
@pointwise_satf.register(stl.LinEq)
|
||||
def _(stl):
|
||||
op = op_lookup[stl.op]
|
||||
def sat_comp(x, t):
|
||||
sat = bitarray()
|
||||
[sat.append(op(eval_terms(stl, x, tau), stl.const)) for tau in t]
|
||||
return sat
|
||||
return sat_comp
|
||||
|
||||
|
||||
def eval_terms(lineq, x, t):
|
||||
psi = lens(lineq).terms.each_().modify(eval_term(x, t))
|
||||
return sum(psi.terms)
|
||||
|
||||
|
||||
def eval_term(x, t):
|
||||
# TODO(lift interpolation much higher)
|
||||
return lambda term: term.coeff*np.interp(t, x.index, x[term.id.name])
|
||||
25
stl/test_fastboolean_eval.py
Normal file
25
stl/test_fastboolean_eval.py
Normal file
|
|
@ -0,0 +1,25 @@
|
|||
import stl
|
||||
import stl.fastboolean_eval
|
||||
import pandas as pd
|
||||
from nose2.tools import params
|
||||
import unittest
|
||||
from sympy import Symbol
|
||||
|
||||
ex1 = ("2*A > 3", False)
|
||||
ex2 = ("F[0, 1](2*A > 3)", True)
|
||||
ex3 = ("F[1, 0](2*A > 3)", False)
|
||||
ex4 = ("G[1, 0](2*A > 3)", True)
|
||||
ex5 = ("(A < 0)", False)
|
||||
ex6 = ("G[0, 0.1](A < 0)", False)
|
||||
ex7 = ("G[0, 0.1](C)", True)
|
||||
ex8 = ("G[0, 0.2](C)", False)
|
||||
ex9 = ("(F[0, 0.2](C)) and (F[0, 1](2*A > 3))", True)
|
||||
x = pd.DataFrame([[1,2, True], [1,4, True], [4,2, False]], index=[0,0.1,0.2],
|
||||
columns=["A", "B", "C"])
|
||||
|
||||
class TestSTLRobustness(unittest.TestCase):
|
||||
@params(ex1, ex2, ex3, ex4, ex5, ex6, ex7, ex8, ex9)
|
||||
def test_stl(self, phi_str, r):
|
||||
phi = stl.parse(phi_str)
|
||||
stl_eval = stl.fastboolean_eval.pointwise_satf(phi)
|
||||
self.assertEqual(stl_eval(x, [0]), r)
|
||||
26
test_boolean.py
Normal file
26
test_boolean.py
Normal file
|
|
@ -0,0 +1,26 @@
|
|||
import stl
|
||||
import stl.fastboolean_eval
|
||||
import pandas as pd
|
||||
from nose2.tools import params
|
||||
import unittest
|
||||
from sympy import Symbol
|
||||
|
||||
ex1 = ("2*A > 3", False)
|
||||
ex2 = ("F[0, 1](2*A > 3)", True)
|
||||
ex3 = ("F[1, 0](2*A > 3)", False)
|
||||
ex4 = ("G[1, 0](2*A > 3)", True)
|
||||
ex5 = ("(A < 0)", False)
|
||||
ex6 = ("G[0, 0.1](A < 0)", False)
|
||||
ex7 = ("G[0, 0.1](C)", True)
|
||||
ex8 = ("G[0, 0.2](C)", False)
|
||||
ex9 = ("(F[0, 0.2](C)) and (F[0, 1](2*A > 3))", True)
|
||||
x = pd.DataFrame([[1,2, True], [1,4, True], [4,2, False]], index=[0,0.1,0.2],
|
||||
columns=["A", "B", "C"])
|
||||
|
||||
tests = [ex1, ex2, ex3, ex4, ex5, ex6, ex7, ex8, ex9]
|
||||
for test in tests:
|
||||
phi = stl.parse(test[0])
|
||||
print(phi)
|
||||
stl_eval = stl.fastboolean_eval.pointwise_sat(phi)
|
||||
print(stl_eval(x, [0]))
|
||||
|
||||
Loading…
Add table
Add a link
Reference in a new issue