From 42efc69556e8f177adcfa7eac46ad9549fd6dd70 Mon Sep 17 00:00:00 2001 From: Marcell Vazquez-Chanlatte Date: Sun, 9 Oct 2016 19:48:26 -0700 Subject: [PATCH] implement smoketest robustness checks --- robustness.py | 5 +++-- test_parser.py | 12 ------------ test_robustness.py | 24 ++++++++++++++++++++++++ 3 files changed, 27 insertions(+), 14 deletions(-) create mode 100644 test_robustness.py diff --git a/robustness.py b/robustness.py index 77943e7..e875ffb 100644 --- a/robustness.py +++ b/robustness.py @@ -6,6 +6,7 @@ from lenses import lens import stl.ast from stl.utils import set_params, param_lens +oo = float('inf') @singledispatch def pointwise_robustness(stl): @@ -26,14 +27,14 @@ def _(stl): def _(stl): lo, hi = stl.interval return lambda x, t: max((pointwise_robustness(stl.arg)(x, t + t2) - for t2 in x[lo:hi].index), default=float('inf')) + for t2 in x[lo:hi].index), default=-oo) @pointwise_robustness.register(stl.G) def _(stl): lo, hi = stl.interval return lambda x, t: min((pointwise_robustness(stl.arg)(x, t + t2) - for t2 in x[lo:hi].index), default=-float('inf')) + for t2 in x[lo:hi].index), default=oo) @pointwise_robustness.register(stl.Neg) diff --git a/test_parser.py b/test_parser.py index 0426df5..b21b258 100644 --- a/test_parser.py +++ b/test_parser.py @@ -1,16 +1,9 @@ # -*- coding: utf-8 -*- import stl -from blustl.game import from_yaml from nose2.tools import params import unittest from sympy import Symbol -from glob import glob - -def main(): - with open('examples/example1.stl', 'r') as f: - print(from_yaml(f)) - ex1 = ('x1 > 2', stl.LinEq( (stl.Var(1, Symbol("x1"), stl.ast.t_sym),), ">", @@ -23,12 +16,7 @@ ex3 = ('□[2,3]◇[0,1](x1 > 2)', stl.G(i2, ex2[1])) ex4 = ('(x1 > 2) or ((x1 > 2) or (x1 > 2))', stl.Or((ex1[1], ex1[1], ex1[1]))) -example_ymls = glob('examples/*') - class TestSTLParser(unittest.TestCase): @params(ex1, ex2, ex3, ex4) def test_stl(self, phi_str, phi): self.assertEqual(stl.parse(phi_str), phi) - - - diff --git a/test_robustness.py b/test_robustness.py new file mode 100644 index 0000000..d803ea0 --- /dev/null +++ b/test_robustness.py @@ -0,0 +1,24 @@ +import stl +import stl.robustness +import pandas as pd +from nose2.tools import params +import unittest +from sympy import Symbol + +oo = float('inf') + +ex1 = ("2*A > 3", -1) +ex2 = ("F[0, 1](2*A > 3)", 5) +ex3 = ("F[1, 0](2*A > 3)", -oo) +ex4 = ("G[1, 0](2*A > 3)", oo) +x = pd.DataFrame([[1,2], [1,4], [4,2]], index=[0,0.1,0.2], + columns=["A", "B"]) + + + +class TestSTLRobustness(unittest.TestCase): + @params(ex1, ex2, ex3, ex4) + def test_stl(self, phi_str, r): + phi = stl.parse(phi_str) + stl_eval = stl.robustness.pointwise_robustness(phi) + self.assertEqual(stl_eval(x, 0), r)