From 014110cf90917820cebd4ae85b2bdef94b890224 Mon Sep 17 00:00:00 2001 From: Marcell Vazquez-Chanlatte Date: Wed, 21 Dec 2016 19:18:01 -0800 Subject: [PATCH] sketch out tests for smooth robustness --- stl/test_robustness.py | 31 ++++++++++++++++++++++++++++++- stl/utils.py | 2 +- 2 files changed, 31 insertions(+), 2 deletions(-) diff --git a/stl/test_robustness.py b/stl/test_robustness.py index 467db95..d2a515b 100644 --- a/stl/test_robustness.py +++ b/stl/test_robustness.py @@ -1,5 +1,7 @@ import stl +import stl.boolean_eval import stl.robustness +import stl.smooth_robustness import pandas as pd from nose2.tools import params import unittest @@ -19,7 +21,34 @@ x = pd.DataFrame([[1,2], [1,4], [4,2]], index=[0,0.1,0.2], class TestSTLRobustness(unittest.TestCase): @params(ex1, ex2, ex3, ex4, ex5, ex6) - def test_stl(self, phi_str, r): + def test_robustness_sign(self, phi_str, _): + phi = stl.parse(phi_str) + stl_eval = stl.boolean_eval.pointwise_sat(phi) + stl_eval2 = stl.robustness.pointwise_robustness(phi) + r = stl_eval2(x, 0) + assert (r == 0 or ((r > 0) == stl_eval(x, 0))) + + + @params(ex1, ex2, ex3, ex4, ex5, ex6) + def test_robustness_value(self, phi_str, r): phi = stl.parse(phi_str) stl_eval = stl.robustness.pointwise_robustness(phi) self.assertEqual(stl_eval(x, 0), r) + + + @params(ex1, ex2, ex3, ex4, ex5, ex6) + def test_eps_robustness(self, phi_str, r): + phi = stl.parse(phi_str) + r = stl.robustness.pointwise_robustness(phi)(x, 0) + lo, hi = stl.smooth_robustness.smooth_robustness(phi, L=1, eps=0.1) + # hi - lo <= eps + # lo <= r <= hi + raise NotImplementedError + + + @params(ex1, ex2, ex3, ex4, ex5, ex6) + def test_interval_polarity(self, phi_str, r): + phi = stl.parse(phi_str) + lo, hi = stl.smooth_robustness.smooth_robustness(phi, L=1, eps=0.1) + # hi - lo > 0 + raise NotImplementedError diff --git a/stl/utils.py b/stl/utils.py index d4e3b6b..1cdf605 100644 --- a/stl/utils.py +++ b/stl/utils.py @@ -66,7 +66,7 @@ def terms_lens(phi:"STL", bind=True) -> lens: def param_lens(phi): is_sym = lambda x: isinstance(x, sympy.Symbol) def focus_lens(leaf): - return [lens().const, lens().id] if isinstance(leaf, LinEq) else [lens().interval[0], lens().interval[1]] + return [lens().const] if isinstance(leaf, LinEq) else [lens().interval[0], lens().interval[1]] return ast_lens(phi, pred=type_pred(LinEq, F, G), focus_lens=focus_lens).filter_(is_sym)