fix bug in robustness calculation + move synth for it's on module
This commit is contained in:
parent
9b42d9bb57
commit
da669f088d
4 changed files with 58 additions and 52 deletions
|
|
@ -4,7 +4,6 @@ from operator import sub, add
|
||||||
from lenses import lens
|
from lenses import lens
|
||||||
|
|
||||||
import stl.ast
|
import stl.ast
|
||||||
from stl.utils import set_params, param_lens
|
|
||||||
|
|
||||||
oo = float('inf')
|
oo = float('inf')
|
||||||
|
|
||||||
|
|
@ -45,8 +44,8 @@ def _(stl):
|
||||||
op_lookup = {
|
op_lookup = {
|
||||||
">": sub,
|
">": sub,
|
||||||
">=": sub,
|
">=": sub,
|
||||||
"<": add,
|
"<": lambda x, y: sub(y, x),
|
||||||
"<=": add,
|
"<=": lambda x, y: sub(y, x),
|
||||||
"=": lambda a, b: -abs(a - b),
|
"=": lambda a, b: -abs(a - b),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -64,45 +63,3 @@ def eval_terms(lineq, x, t):
|
||||||
|
|
||||||
def eval_term(x, t):
|
def eval_term(x, t):
|
||||||
return lambda term: term.coeff*x[term.id.name][t]
|
return lambda term: term.coeff*x[term.id.name][t]
|
||||||
|
|
||||||
|
|
||||||
def binsearch(stleval, *, tol=1e-3, lo, hi, polarity):
|
|
||||||
"""Only run search if tightest robustness was positive."""
|
|
||||||
# Only check low since hi taken care of by precondition.
|
|
||||||
# TODO: allow for different polarities
|
|
||||||
rL, rH = stleval(lo), stleval(hi)
|
|
||||||
# Early termination via bounds checks
|
|
||||||
posL, posH = rL > 0, rH > 0
|
|
||||||
if polarity and posL:
|
|
||||||
return lo
|
|
||||||
elif not polarity and posH:
|
|
||||||
return hi
|
|
||||||
|
|
||||||
while hi - lo > tol:
|
|
||||||
mid = lo + (hi - lo) / 2
|
|
||||||
r = stleval(mid)
|
|
||||||
if not polarity: # swap direction
|
|
||||||
r *= -1
|
|
||||||
if r < 0:
|
|
||||||
lo, hi = mid, hi
|
|
||||||
else:
|
|
||||||
lo, hi = lo, mid
|
|
||||||
|
|
||||||
# Want satisifiable formula
|
|
||||||
return hi if polarity else lo
|
|
||||||
|
|
||||||
|
|
||||||
def lex_param_project(stl, x, *, order, polarity, ranges, tol=1e-3):
|
|
||||||
val = {var: (ranges[var][0] if not polarity[var] else ranges[var][1]) for var in order}
|
|
||||||
p_lens = param_lens(stl)
|
|
||||||
def stleval_fact(var, val):
|
|
||||||
l = lens(val)[var]
|
|
||||||
return lambda p: pointwise_robustness(set_params(stl, l.set(p)))(x, 0)
|
|
||||||
|
|
||||||
for var in order:
|
|
||||||
stleval = stleval_fact(var, val)
|
|
||||||
lo, hi = ranges[var]
|
|
||||||
param = binsearch(stleval, lo=lo, hi=hi, tol=tol, polarity=polarity[var])
|
|
||||||
val[var] = param
|
|
||||||
|
|
||||||
return val
|
|
||||||
|
|
|
||||||
45
synth.py
Normal file
45
synth.py
Normal file
|
|
@ -0,0 +1,45 @@
|
||||||
|
from stl.utils import set_params, param_lens
|
||||||
|
from stl.robustness import pointwise_robustness
|
||||||
|
|
||||||
|
from lenses import lens
|
||||||
|
|
||||||
|
def binsearch(stleval, *, tol=1e-3, lo, hi, polarity):
|
||||||
|
"""Only run search if tightest robustness was positive."""
|
||||||
|
# Only check low since hi taken care of by precondition.
|
||||||
|
# TODO: allow for different polarities
|
||||||
|
rL, rH = stleval(lo), stleval(hi)
|
||||||
|
# Early termination via bounds checks
|
||||||
|
posL, posH = rL > 0, rH > 0
|
||||||
|
if polarity and posL:
|
||||||
|
return lo
|
||||||
|
elif not polarity and posH:
|
||||||
|
return hi
|
||||||
|
|
||||||
|
while hi - lo > tol:
|
||||||
|
mid = lo + (hi - lo) / 2
|
||||||
|
r = stleval(mid)
|
||||||
|
if not polarity: # swap direction
|
||||||
|
r *= -1
|
||||||
|
if r < 0:
|
||||||
|
lo, hi = mid, hi
|
||||||
|
else:
|
||||||
|
lo, hi = lo, mid
|
||||||
|
|
||||||
|
# Want satisifiable formula
|
||||||
|
return hi if polarity else lo
|
||||||
|
|
||||||
|
|
||||||
|
def lex_param_project(stl, x, *, order, polarity, ranges, tol=1e-3):
|
||||||
|
val = {var: (ranges[var][0] if not polarity[var] else ranges[var][1]) for var in order}
|
||||||
|
p_lens = param_lens(stl)
|
||||||
|
def stleval_fact(var, val):
|
||||||
|
l = lens(val)[var]
|
||||||
|
return lambda p: pointwise_robustness(set_params(stl, l.set(p)))(x, 0)
|
||||||
|
|
||||||
|
for var in order:
|
||||||
|
stleval = stleval_fact(var, val)
|
||||||
|
lo, hi = ranges[var]
|
||||||
|
param = binsearch(stleval, lo=lo, hi=hi, tol=tol, polarity=polarity[var])
|
||||||
|
val[var] = param
|
||||||
|
|
||||||
|
return val
|
||||||
|
|
@ -11,13 +11,15 @@ ex1 = ("2*A > 3", -1)
|
||||||
ex2 = ("F[0, 1](2*A > 3)", 5)
|
ex2 = ("F[0, 1](2*A > 3)", 5)
|
||||||
ex3 = ("F[1, 0](2*A > 3)", -oo)
|
ex3 = ("F[1, 0](2*A > 3)", -oo)
|
||||||
ex4 = ("G[1, 0](2*A > 3)", oo)
|
ex4 = ("G[1, 0](2*A > 3)", oo)
|
||||||
|
ex5 = ("(A < 0)", -1)
|
||||||
|
ex6 = ("G[0, 0.1](A < 0)", -1)
|
||||||
x = pd.DataFrame([[1,2], [1,4], [4,2]], index=[0,0.1,0.2],
|
x = pd.DataFrame([[1,2], [1,4], [4,2]], index=[0,0.1,0.2],
|
||||||
columns=["A", "B"])
|
columns=["A", "B"])
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class TestSTLRobustness(unittest.TestCase):
|
class TestSTLRobustness(unittest.TestCase):
|
||||||
@params(ex1, ex2, ex3, ex4)
|
@params(ex1, ex2, ex3, ex4, ex5, ex6)
|
||||||
def test_stl(self, phi_str, r):
|
def test_stl(self, phi_str, r):
|
||||||
phi = stl.parse(phi_str)
|
phi = stl.parse(phi_str)
|
||||||
stl_eval = stl.robustness.pointwise_robustness(phi)
|
stl_eval = stl.robustness.pointwise_robustness(phi)
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,6 @@
|
||||||
import stl
|
import stl
|
||||||
import stl.robustness
|
import stl.robustness
|
||||||
|
import stl.synth
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
from nose2.tools import params
|
from nose2.tools import params
|
||||||
import unittest
|
import unittest
|
||||||
|
|
@ -13,9 +14,10 @@ ex2 = ("F[0, b?](A > a?)", ("a?", "b?"), {"a?": (0, 10), "b?": (0, 5)},
|
||||||
ex3 = ("F[0, b?](A < 0)", ("b?",), {"b?": (0, 5)},
|
ex3 = ("F[0, b?](A < 0)", ("b?",), {"b?": (0, 5)},
|
||||||
{"b?": True}, {"b?": 5})
|
{"b?": True}, {"b?": 5})
|
||||||
ex4 = ("G[0, b?](A < 0)", ("b?",), {"b?": (0.1, 5)},
|
ex4 = ("G[0, b?](A < 0)", ("b?",), {"b?": (0.1, 5)},
|
||||||
{"b?": True}, {"b?": 0.1})
|
{"b?": False}, {"b?": 0.1})
|
||||||
ex5 = ("F[0, b?](A > 0)", ("b?",), {"b?": (0.1, 5)},
|
ex5 = ("F[0, b?](A > 0)", ("b?",), {"b?": (0.1, 5)},
|
||||||
{"b?": True}, {"b?": 0.1})
|
{"b?": True}, {"b?": 0.1})
|
||||||
|
|
||||||
x = pd.DataFrame([[1,2], [1,4], [4,2]], index=[0,0.1,0.2],
|
x = pd.DataFrame([[1,2], [1,4], [4,2]], index=[0,0.1,0.2],
|
||||||
columns=["A", "B"])
|
columns=["A", "B"])
|
||||||
|
|
||||||
|
|
@ -24,14 +26,14 @@ class TestSTLRobustness(unittest.TestCase):
|
||||||
@params(ex1, ex2, ex3, ex4, ex5)
|
@params(ex1, ex2, ex3, ex4, ex5)
|
||||||
def test_lex_synth(self, phi_str, order, ranges, polarity, val):
|
def test_lex_synth(self, phi_str, order, ranges, polarity, val):
|
||||||
phi = stl.parse(phi_str)
|
phi = stl.parse(phi_str)
|
||||||
val2 = stl.robustness.lex_param_project(
|
val2 = stl.synth.lex_param_project(
|
||||||
phi, x, order=order, ranges=ranges, polarity=polarity)
|
phi, x, order=order, ranges=ranges, polarity=polarity)
|
||||||
|
|
||||||
phi = stl.robustness.set_params(phi, val2)
|
phi2 = stl.utils.set_params(phi, val2)
|
||||||
phi2 = stl.robustness.set_params(phi, val)
|
phi3 = stl.utils.set_params(phi, val)
|
||||||
|
|
||||||
stl_eval = stl.robustness.pointwise_robustness(phi)
|
stl_eval = stl.robustness.pointwise_robustness(phi2)
|
||||||
stl_eval2 = stl.robustness.pointwise_robustness(phi2)
|
stl_eval2 = stl.robustness.pointwise_robustness(phi3)
|
||||||
|
|
||||||
# check that the robustnesses are almost the same
|
# check that the robustnesses are almost the same
|
||||||
self.assertAlmostEqual(stl_eval(x, 0), stl_eval2(x, 0), delta=0.01)
|
self.assertAlmostEqual(stl_eval(x, 0), stl_eval2(x, 0), delta=0.01)
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue