From c6810d8d758712735467827774b2e7b2c0a4ac40 Mon Sep 17 00:00:00 2001 From: Marcell Vazquez-Chanlatte Date: Sun, 9 Oct 2016 22:34:28 -0700 Subject: [PATCH] fix bugs in binary search (still broken) --- robustness.py | 25 +++++++++++++------------ test_synth.py | 13 ++++++------- 2 files changed, 19 insertions(+), 19 deletions(-) diff --git a/robustness.py b/robustness.py index 09ddda2..ccbf438 100644 --- a/robustness.py +++ b/robustness.py @@ -69,30 +69,31 @@ def eval_term(x, t): def binsearch(stleval, *, tol=1e-3, lo, hi, polarity): """Only run search if tightest robustness was positive.""" # Only check low since hi taken care of by precondition. - r = stleval(lo) + # TODO: allow for different polarities + rL, rH = stleval(lo), stleval(hi) + # Early termination via bounds checks + posL, posH = rL > 0, rH > 0 + if polarity and posL: + return lo + elif not polarity and posH: + return hi - # TODO: early termination by bounds checks - mid = lo - if abs(r) < tol: - return r, mid - - while abs(r) > tol and hi - lo > tol: + while hi - lo > tol: mid = lo + (hi - lo) / 2 r = stleval(mid) - if polarity: # swap direction + if not polarity: # swap direction r *= -1 if r < 0: lo, hi = mid, hi else: lo, hi = lo, mid - return r, mid + return mid def lex_param_project(stl, x, *, order, polarity, ranges, tol=1e-3): - val = {var: (ranges[var][0] if polarity[var] else ranges[var][1]) for var in order} + val = {var: (ranges[var][0] if not polarity[var] else ranges[var][1]) for var in order} # TODO: evaluate top paramater p_lens = param_lens(stl) - def stleval_fact(var, val): l = lens(val)[var] return lambda p: pointwise_robustness(set_params(stl, l.set(p)))(x, 0) @@ -100,7 +101,7 @@ def lex_param_project(stl, x, *, order, polarity, ranges, tol=1e-3): for var in order: stleval = stleval_fact(var, val) lo, hi = ranges[var] - _, param = binsearch(stleval, lo=lo, hi=hi, tol=tol, polarity=polarity[var]) + param = binsearch(stleval, lo=lo, hi=hi, tol=tol, polarity=polarity[var]) val[var] = param return val diff --git a/test_synth.py b/test_synth.py index a5a7f80..c5055db 100644 --- a/test_synth.py +++ b/test_synth.py @@ -7,24 +7,23 @@ from sympy import Symbol oo = float('inf') -ex1 = ("A > a?", ("a?",), {"a?": (0, 10)}, {"a?": True}, {"a?": 1}) -ex1 = ("F[0, b?](A > a?)", ("a?", "b?"), {"a?": (0, 10), "b?": (0, 10)}, - {"a?": True, "b?": False}, {"a?": 4, "b?": 0.2}) +ex1 = ("A > a?", ("a?",), {"a?": (0, 10)}, {"a?": False}, {"a?": 1}) +ex2 = ("F[0, b?](A > a?)", ("a?", "b?"), {"a?": (0, 10), "b?": (0, 5)}, + {"a?": False, "b?": True}, {"a?": 4, "b?": 0.2}) x = pd.DataFrame([[1,2], [1,4], [4,2]], index=[0,0.1,0.2], columns=["A", "B"]) class TestSTLRobustness(unittest.TestCase): - @params(ex1) + @params(ex1, ex2) def test_lex_synth(self, phi_str, order, ranges, polarity, val): phi = stl.parse(phi_str) val2 = stl.robustness.lex_param_project( phi, x, order=order, ranges=ranges, polarity=polarity) - phi = stl.robustness.set_params(phi, val) + phi = stl.robustness.set_params(phi, val2) stl_eval = stl.robustness.pointwise_robustness(phi) - self.assertAlmostEqual(stl_eval(x, 0), 0) - + self.assertAlmostEqual(stl_eval(x, 0), 0, delta=0.01) for var in order: self.assertAlmostEqual(val2[var], val[var], delta=0.01)