implement smoketest robustness checks
This commit is contained in:
parent
fb617482bf
commit
42efc69556
3 changed files with 27 additions and 14 deletions
|
|
@ -6,6 +6,7 @@ from lenses import lens
|
||||||
import stl.ast
|
import stl.ast
|
||||||
from stl.utils import set_params, param_lens
|
from stl.utils import set_params, param_lens
|
||||||
|
|
||||||
|
oo = float('inf')
|
||||||
|
|
||||||
@singledispatch
|
@singledispatch
|
||||||
def pointwise_robustness(stl):
|
def pointwise_robustness(stl):
|
||||||
|
|
@ -26,14 +27,14 @@ def _(stl):
|
||||||
def _(stl):
|
def _(stl):
|
||||||
lo, hi = stl.interval
|
lo, hi = stl.interval
|
||||||
return lambda x, t: max((pointwise_robustness(stl.arg)(x, t + t2)
|
return lambda x, t: max((pointwise_robustness(stl.arg)(x, t + t2)
|
||||||
for t2 in x[lo:hi].index), default=float('inf'))
|
for t2 in x[lo:hi].index), default=-oo)
|
||||||
|
|
||||||
|
|
||||||
@pointwise_robustness.register(stl.G)
|
@pointwise_robustness.register(stl.G)
|
||||||
def _(stl):
|
def _(stl):
|
||||||
lo, hi = stl.interval
|
lo, hi = stl.interval
|
||||||
return lambda x, t: min((pointwise_robustness(stl.arg)(x, t + t2)
|
return lambda x, t: min((pointwise_robustness(stl.arg)(x, t + t2)
|
||||||
for t2 in x[lo:hi].index), default=-float('inf'))
|
for t2 in x[lo:hi].index), default=oo)
|
||||||
|
|
||||||
|
|
||||||
@pointwise_robustness.register(stl.Neg)
|
@pointwise_robustness.register(stl.Neg)
|
||||||
|
|
|
||||||
|
|
@ -1,16 +1,9 @@
|
||||||
# -*- coding: utf-8 -*-
|
# -*- coding: utf-8 -*-
|
||||||
import stl
|
import stl
|
||||||
from blustl.game import from_yaml
|
|
||||||
from nose2.tools import params
|
from nose2.tools import params
|
||||||
import unittest
|
import unittest
|
||||||
from sympy import Symbol
|
from sympy import Symbol
|
||||||
|
|
||||||
from glob import glob
|
|
||||||
|
|
||||||
def main():
|
|
||||||
with open('examples/example1.stl', 'r') as f:
|
|
||||||
print(from_yaml(f))
|
|
||||||
|
|
||||||
ex1 = ('x1 > 2', stl.LinEq(
|
ex1 = ('x1 > 2', stl.LinEq(
|
||||||
(stl.Var(1, Symbol("x1"), stl.ast.t_sym),),
|
(stl.Var(1, Symbol("x1"), stl.ast.t_sym),),
|
||||||
">",
|
">",
|
||||||
|
|
@ -23,12 +16,7 @@ ex3 = ('□[2,3]◇[0,1](x1 > 2)', stl.G(i2, ex2[1]))
|
||||||
ex4 = ('(x1 > 2) or ((x1 > 2) or (x1 > 2))',
|
ex4 = ('(x1 > 2) or ((x1 > 2) or (x1 > 2))',
|
||||||
stl.Or((ex1[1], ex1[1], ex1[1])))
|
stl.Or((ex1[1], ex1[1], ex1[1])))
|
||||||
|
|
||||||
example_ymls = glob('examples/*')
|
|
||||||
|
|
||||||
class TestSTLParser(unittest.TestCase):
|
class TestSTLParser(unittest.TestCase):
|
||||||
@params(ex1, ex2, ex3, ex4)
|
@params(ex1, ex2, ex3, ex4)
|
||||||
def test_stl(self, phi_str, phi):
|
def test_stl(self, phi_str, phi):
|
||||||
self.assertEqual(stl.parse(phi_str), phi)
|
self.assertEqual(stl.parse(phi_str), phi)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
24
test_robustness.py
Normal file
24
test_robustness.py
Normal file
|
|
@ -0,0 +1,24 @@
|
||||||
|
import stl
|
||||||
|
import stl.robustness
|
||||||
|
import pandas as pd
|
||||||
|
from nose2.tools import params
|
||||||
|
import unittest
|
||||||
|
from sympy import Symbol
|
||||||
|
|
||||||
|
oo = float('inf')
|
||||||
|
|
||||||
|
ex1 = ("2*A > 3", -1)
|
||||||
|
ex2 = ("F[0, 1](2*A > 3)", 5)
|
||||||
|
ex3 = ("F[1, 0](2*A > 3)", -oo)
|
||||||
|
ex4 = ("G[1, 0](2*A > 3)", oo)
|
||||||
|
x = pd.DataFrame([[1,2], [1,4], [4,2]], index=[0,0.1,0.2],
|
||||||
|
columns=["A", "B"])
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class TestSTLRobustness(unittest.TestCase):
|
||||||
|
@params(ex1, ex2, ex3, ex4)
|
||||||
|
def test_stl(self, phi_str, r):
|
||||||
|
phi = stl.parse(phi_str)
|
||||||
|
stl_eval = stl.robustness.pointwise_robustness(phi)
|
||||||
|
self.assertEqual(stl_eval(x, 0), r)
|
||||||
Loading…
Add table
Add a link
Reference in a new issue