From da669f088dd30626a93f20b415d4f7b3a74cde45 Mon Sep 17 00:00:00 2001 From: Marcell Vazquez-Chanlatte Date: Sun, 9 Oct 2016 23:42:22 -0700 Subject: [PATCH] fix bug in robustness calculation + move synth for it's on module --- robustness.py | 47 ++-------------------------------------------- synth.py | 45 ++++++++++++++++++++++++++++++++++++++++++++ test_robustness.py | 4 +++- test_synth.py | 14 ++++++++------ 4 files changed, 58 insertions(+), 52 deletions(-) create mode 100644 synth.py diff --git a/robustness.py b/robustness.py index 86d7842..6faac99 100644 --- a/robustness.py +++ b/robustness.py @@ -4,7 +4,6 @@ from operator import sub, add from lenses import lens import stl.ast -from stl.utils import set_params, param_lens oo = float('inf') @@ -45,8 +44,8 @@ def _(stl): op_lookup = { ">": sub, ">=": sub, - "<": add, - "<=": add, + "<": lambda x, y: sub(y, x), + "<=": lambda x, y: sub(y, x), "=": lambda a, b: -abs(a - b), } @@ -64,45 +63,3 @@ def eval_terms(lineq, x, t): def eval_term(x, t): return lambda term: term.coeff*x[term.id.name][t] - - -def binsearch(stleval, *, tol=1e-3, lo, hi, polarity): - """Only run search if tightest robustness was positive.""" - # Only check low since hi taken care of by precondition. - # TODO: allow for different polarities - rL, rH = stleval(lo), stleval(hi) - # Early termination via bounds checks - posL, posH = rL > 0, rH > 0 - if polarity and posL: - return lo - elif not polarity and posH: - return hi - - while hi - lo > tol: - mid = lo + (hi - lo) / 2 - r = stleval(mid) - if not polarity: # swap direction - r *= -1 - if r < 0: - lo, hi = mid, hi - else: - lo, hi = lo, mid - - # Want satisifiable formula - return hi if polarity else lo - - -def lex_param_project(stl, x, *, order, polarity, ranges, tol=1e-3): - val = {var: (ranges[var][0] if not polarity[var] else ranges[var][1]) for var in order} - p_lens = param_lens(stl) - def stleval_fact(var, val): - l = lens(val)[var] - return lambda p: pointwise_robustness(set_params(stl, l.set(p)))(x, 0) - - for var in order: - stleval = stleval_fact(var, val) - lo, hi = ranges[var] - param = binsearch(stleval, lo=lo, hi=hi, tol=tol, polarity=polarity[var]) - val[var] = param - - return val diff --git a/synth.py b/synth.py new file mode 100644 index 0000000..5ad0804 --- /dev/null +++ b/synth.py @@ -0,0 +1,45 @@ +from stl.utils import set_params, param_lens +from stl.robustness import pointwise_robustness + +from lenses import lens + +def binsearch(stleval, *, tol=1e-3, lo, hi, polarity): + """Only run search if tightest robustness was positive.""" + # Only check low since hi taken care of by precondition. + # TODO: allow for different polarities + rL, rH = stleval(lo), stleval(hi) + # Early termination via bounds checks + posL, posH = rL > 0, rH > 0 + if polarity and posL: + return lo + elif not polarity and posH: + return hi + + while hi - lo > tol: + mid = lo + (hi - lo) / 2 + r = stleval(mid) + if not polarity: # swap direction + r *= -1 + if r < 0: + lo, hi = mid, hi + else: + lo, hi = lo, mid + + # Want satisifiable formula + return hi if polarity else lo + + +def lex_param_project(stl, x, *, order, polarity, ranges, tol=1e-3): + val = {var: (ranges[var][0] if not polarity[var] else ranges[var][1]) for var in order} + p_lens = param_lens(stl) + def stleval_fact(var, val): + l = lens(val)[var] + return lambda p: pointwise_robustness(set_params(stl, l.set(p)))(x, 0) + + for var in order: + stleval = stleval_fact(var, val) + lo, hi = ranges[var] + param = binsearch(stleval, lo=lo, hi=hi, tol=tol, polarity=polarity[var]) + val[var] = param + + return val diff --git a/test_robustness.py b/test_robustness.py index d803ea0..4c5db60 100644 --- a/test_robustness.py +++ b/test_robustness.py @@ -11,13 +11,15 @@ ex1 = ("2*A > 3", -1) ex2 = ("F[0, 1](2*A > 3)", 5) ex3 = ("F[1, 0](2*A > 3)", -oo) ex4 = ("G[1, 0](2*A > 3)", oo) +ex5 = ("(A < 0)", -1) +ex6 = ("G[0, 0.1](A < 0)", -1) x = pd.DataFrame([[1,2], [1,4], [4,2]], index=[0,0.1,0.2], columns=["A", "B"]) class TestSTLRobustness(unittest.TestCase): - @params(ex1, ex2, ex3, ex4) + @params(ex1, ex2, ex3, ex4, ex5, ex6) def test_stl(self, phi_str, r): phi = stl.parse(phi_str) stl_eval = stl.robustness.pointwise_robustness(phi) diff --git a/test_synth.py b/test_synth.py index 7f496e9..02444a2 100644 --- a/test_synth.py +++ b/test_synth.py @@ -1,5 +1,6 @@ import stl import stl.robustness +import stl.synth import pandas as pd from nose2.tools import params import unittest @@ -13,9 +14,10 @@ ex2 = ("F[0, b?](A > a?)", ("a?", "b?"), {"a?": (0, 10), "b?": (0, 5)}, ex3 = ("F[0, b?](A < 0)", ("b?",), {"b?": (0, 5)}, {"b?": True}, {"b?": 5}) ex4 = ("G[0, b?](A < 0)", ("b?",), {"b?": (0.1, 5)}, - {"b?": True}, {"b?": 0.1}) + {"b?": False}, {"b?": 0.1}) ex5 = ("F[0, b?](A > 0)", ("b?",), {"b?": (0.1, 5)}, {"b?": True}, {"b?": 0.1}) + x = pd.DataFrame([[1,2], [1,4], [4,2]], index=[0,0.1,0.2], columns=["A", "B"]) @@ -24,14 +26,14 @@ class TestSTLRobustness(unittest.TestCase): @params(ex1, ex2, ex3, ex4, ex5) def test_lex_synth(self, phi_str, order, ranges, polarity, val): phi = stl.parse(phi_str) - val2 = stl.robustness.lex_param_project( + val2 = stl.synth.lex_param_project( phi, x, order=order, ranges=ranges, polarity=polarity) - phi = stl.robustness.set_params(phi, val2) - phi2 = stl.robustness.set_params(phi, val) + phi2 = stl.utils.set_params(phi, val2) + phi3 = stl.utils.set_params(phi, val) - stl_eval = stl.robustness.pointwise_robustness(phi) - stl_eval2 = stl.robustness.pointwise_robustness(phi2) + stl_eval = stl.robustness.pointwise_robustness(phi2) + stl_eval2 = stl.robustness.pointwise_robustness(phi3) # check that the robustnesses are almost the same self.assertAlmostEqual(stl_eval(x, 0), stl_eval2(x, 0), delta=0.01)