Merge pull request #1 from shromonag/master

Looks good. Going to merge and update setup.py to have bitarray as a dependency
This commit is contained in:
Marcell Vazquez-Chanlatte 2016-12-03 12:33:06 -08:00 committed by GitHub
commit c41231e964
4 changed files with 158 additions and 0 deletions

View file

@ -4,3 +4,4 @@ from stl.ast import LinEq, Interval, NaryOpSTL, Or, And, F, G, ModalOp, Neg, Var
from stl.parser import parse
from stl.synth import lex_param_project
from stl.boolean_eval import pointwise_sat
from stl.fastboolean_eval import pointwise_satf

106
stl/fastboolean_eval.py Normal file
View file

@ -0,0 +1,106 @@
# TODO: figure out how to deduplicate this with robustness
# - Abstract as working on distributive lattice
from functools import singledispatch
import operator as op
import numpy as np
import sympy as smp
from lenses import lens
import gmpy2 as gp
from bitarray import bitarray
import stl.ast
@singledispatch
def pointwise_satf(stl):
raise NotImplementedError
@pointwise_satf.register(stl.Or)
def _(stl):
def sat_comp(x,t):
sat = bitarray(len(t))
for arg in stl.args:
sat = pointwise_satf(arg)(x, t) | sat
return sat
return sat_comp
@pointwise_satf.register(stl.And)
def _(stl):
def sat_comp(x,t):
sat = bitarray(len(t))
sat.setall('True')
for arg in stl.args:
sat = pointwise_satf(arg)(x, t) & sat
return sat
return sat_comp
@pointwise_satf.register(stl.F)
def _(stl):
lo, hi = stl.interval
def sat_comp(x,t):
sat = bitarray()
for tau in t:
tau_t = [min(tau + t2, x.index[-1]) for t2 in x[lo:hi].index]
sat.append((pointwise_satf(stl.arg)(x, tau_t)).count() > 0)
return sat
return sat_comp
@pointwise_satf.register(stl.G)
def _(stl):
lo, hi = stl.interval
def sat_comp(x,t):
sat = bitarray()
for tau in t:
tau_t = [min(tau + t2, x.index[-1]) for t2 in x[lo:hi].index]
point_sat = pointwise_satf(stl.arg)(x, tau_t)
sat.append(point_sat.count() == point_sat.length())
return sat
return sat_comp
@pointwise_satf.register(stl.Neg)
def _(stl):
return lambda x,t: ~pointwise_satf(arg)(x, t)
op_lookup = {
">": op.gt,
">=": op.ge,
"<": op.lt,
"<=": op.le,
"=": op.eq,
}
@pointwise_satf.register(stl.AtomicPred)
def _(stl):
def sat_comp(x, t):
sat = bitarray()
[sat.append(x[stl.id][tau]) for tau in t]
return sat
return sat_comp
@pointwise_satf.register(stl.LinEq)
def _(stl):
op = op_lookup[stl.op]
def sat_comp(x, t):
sat = bitarray()
[sat.append(op(eval_terms(stl, x, tau), stl.const)) for tau in t]
return sat
return sat_comp
def eval_terms(lineq, x, t):
psi = lens(lineq).terms.each_().modify(eval_term(x, t))
return sum(psi.terms)
def eval_term(x, t):
# TODO(lift interpolation much higher)
return lambda term: term.coeff*np.interp(t, x.index, x[term.id.name])

View file

@ -0,0 +1,25 @@
import stl
import stl.fastboolean_eval
import pandas as pd
from nose2.tools import params
import unittest
from sympy import Symbol
ex1 = ("2*A > 3", False)
ex2 = ("F[0, 1](2*A > 3)", True)
ex3 = ("F[1, 0](2*A > 3)", False)
ex4 = ("G[1, 0](2*A > 3)", True)
ex5 = ("(A < 0)", False)
ex6 = ("G[0, 0.1](A < 0)", False)
ex7 = ("G[0, 0.1](C)", True)
ex8 = ("G[0, 0.2](C)", False)
ex9 = ("(F[0, 0.2](C)) and (F[0, 1](2*A > 3))", True)
x = pd.DataFrame([[1,2, True], [1,4, True], [4,2, False]], index=[0,0.1,0.2],
columns=["A", "B", "C"])
class TestSTLRobustness(unittest.TestCase):
@params(ex1, ex2, ex3, ex4, ex5, ex6, ex7, ex8, ex9)
def test_stl(self, phi_str, r):
phi = stl.parse(phi_str)
stl_eval = stl.fastboolean_eval.pointwise_satf(phi)
self.assertEqual(stl_eval(x, [0]), r)

26
test_boolean.py Normal file
View file

@ -0,0 +1,26 @@
import stl
import stl.fastboolean_eval
import pandas as pd
from nose2.tools import params
import unittest
from sympy import Symbol
ex1 = ("2*A > 3", False)
ex2 = ("F[0, 1](2*A > 3)", True)
ex3 = ("F[1, 0](2*A > 3)", False)
ex4 = ("G[1, 0](2*A > 3)", True)
ex5 = ("(A < 0)", False)
ex6 = ("G[0, 0.1](A < 0)", False)
ex7 = ("G[0, 0.1](C)", True)
ex8 = ("G[0, 0.2](C)", False)
ex9 = ("(F[0, 0.2](C)) and (F[0, 1](2*A > 3))", True)
x = pd.DataFrame([[1,2, True], [1,4, True], [4,2, False]], index=[0,0.1,0.2],
columns=["A", "B", "C"])
tests = [ex1, ex2, ex3, ex4, ex5, ex6, ex7, ex8, ex9]
for test in tests:
phi = stl.parse(test[0])
print(phi)
stl_eval = stl.fastboolean_eval.pointwise_sat(phi)
print(stl_eval(x, [0]))