fix bugs in binary search (still broken)

This commit is contained in:
Marcell Vazquez-Chanlatte 2016-10-09 22:34:28 -07:00
parent 28f755edc5
commit c6810d8d75
2 changed files with 19 additions and 19 deletions

View file

@ -69,30 +69,31 @@ def eval_term(x, t):
def binsearch(stleval, *, tol=1e-3, lo, hi, polarity): def binsearch(stleval, *, tol=1e-3, lo, hi, polarity):
"""Only run search if tightest robustness was positive.""" """Only run search if tightest robustness was positive."""
# Only check low since hi taken care of by precondition. # Only check low since hi taken care of by precondition.
r = stleval(lo) # TODO: allow for different polarities
rL, rH = stleval(lo), stleval(hi)
# Early termination via bounds checks
posL, posH = rL > 0, rH > 0
if polarity and posL:
return lo
elif not polarity and posH:
return hi
# TODO: early termination by bounds checks while hi - lo > tol:
mid = lo
if abs(r) < tol:
return r, mid
while abs(r) > tol and hi - lo > tol:
mid = lo + (hi - lo) / 2 mid = lo + (hi - lo) / 2
r = stleval(mid) r = stleval(mid)
if polarity: # swap direction if not polarity: # swap direction
r *= -1 r *= -1
if r < 0: if r < 0:
lo, hi = mid, hi lo, hi = mid, hi
else: else:
lo, hi = lo, mid lo, hi = lo, mid
return r, mid return mid
def lex_param_project(stl, x, *, order, polarity, ranges, tol=1e-3): def lex_param_project(stl, x, *, order, polarity, ranges, tol=1e-3):
val = {var: (ranges[var][0] if polarity[var] else ranges[var][1]) for var in order} val = {var: (ranges[var][0] if not polarity[var] else ranges[var][1]) for var in order}
# TODO: evaluate top paramater # TODO: evaluate top paramater
p_lens = param_lens(stl) p_lens = param_lens(stl)
def stleval_fact(var, val): def stleval_fact(var, val):
l = lens(val)[var] l = lens(val)[var]
return lambda p: pointwise_robustness(set_params(stl, l.set(p)))(x, 0) return lambda p: pointwise_robustness(set_params(stl, l.set(p)))(x, 0)
@ -100,7 +101,7 @@ def lex_param_project(stl, x, *, order, polarity, ranges, tol=1e-3):
for var in order: for var in order:
stleval = stleval_fact(var, val) stleval = stleval_fact(var, val)
lo, hi = ranges[var] lo, hi = ranges[var]
_, param = binsearch(stleval, lo=lo, hi=hi, tol=tol, polarity=polarity[var]) param = binsearch(stleval, lo=lo, hi=hi, tol=tol, polarity=polarity[var])
val[var] = param val[var] = param
return val return val

View file

@ -7,24 +7,23 @@ from sympy import Symbol
oo = float('inf') oo = float('inf')
ex1 = ("A > a?", ("a?",), {"a?": (0, 10)}, {"a?": True}, {"a?": 1}) ex1 = ("A > a?", ("a?",), {"a?": (0, 10)}, {"a?": False}, {"a?": 1})
ex1 = ("F[0, b?](A > a?)", ("a?", "b?"), {"a?": (0, 10), "b?": (0, 10)}, ex2 = ("F[0, b?](A > a?)", ("a?", "b?"), {"a?": (0, 10), "b?": (0, 5)},
{"a?": True, "b?": False}, {"a?": 4, "b?": 0.2}) {"a?": False, "b?": True}, {"a?": 4, "b?": 0.2})
x = pd.DataFrame([[1,2], [1,4], [4,2]], index=[0,0.1,0.2], x = pd.DataFrame([[1,2], [1,4], [4,2]], index=[0,0.1,0.2],
columns=["A", "B"]) columns=["A", "B"])
class TestSTLRobustness(unittest.TestCase): class TestSTLRobustness(unittest.TestCase):
@params(ex1) @params(ex1, ex2)
def test_lex_synth(self, phi_str, order, ranges, polarity, val): def test_lex_synth(self, phi_str, order, ranges, polarity, val):
phi = stl.parse(phi_str) phi = stl.parse(phi_str)
val2 = stl.robustness.lex_param_project( val2 = stl.robustness.lex_param_project(
phi, x, order=order, ranges=ranges, polarity=polarity) phi, x, order=order, ranges=ranges, polarity=polarity)
phi = stl.robustness.set_params(phi, val) phi = stl.robustness.set_params(phi, val2)
stl_eval = stl.robustness.pointwise_robustness(phi) stl_eval = stl.robustness.pointwise_robustness(phi)
self.assertAlmostEqual(stl_eval(x, 0), 0) self.assertAlmostEqual(stl_eval(x, 0), 0, delta=0.01)
for var in order: for var in order:
self.assertAlmostEqual(val2[var], val[var], delta=0.01) self.assertAlmostEqual(val2[var], val[var], delta=0.01)