switch to symbolic summation for smooth robustness
This commit is contained in:
parent
f5b23637af
commit
5e9239a4f4
1 changed files with 20 additions and 14 deletions
|
|
@ -4,14 +4,14 @@ from functools import singledispatch
|
|||
|
||||
import sympy as sym
|
||||
from numpy import arange
|
||||
from funcy import pairwise, autocurry
|
||||
from funcy import pairwise
|
||||
|
||||
import stl.ast
|
||||
from stl.ast import t_sym
|
||||
from stl.robustness import op_lookup
|
||||
|
||||
@singledispatch
|
||||
def smooth_robustness(stl, L, h, eps):
|
||||
def smooth_robustness(stl, L, h, eps, depth):
|
||||
raise NotImplementedError
|
||||
|
||||
@smooth_robustness.register(stl.And)
|
||||
|
|
@ -35,37 +35,43 @@ def LSE(rs, eps=0.1):
|
|||
|
||||
|
||||
@smooth_robustness.register(stl.Or)
|
||||
def _(stl, L, h, eps):
|
||||
def _(stl, L, h, eps, depth=0):
|
||||
rl, rh = list(zip(
|
||||
*[smooth_robustness(arg, L, h, eps=eps/2) for arg in stl.args]))
|
||||
*[smooth_robustness(arg, L, h, eps=eps/2, depth=depth+1)
|
||||
for arg in stl.args]))
|
||||
return soft_max(rl, eps=eps/2), LSE(rh, eps=eps/2)
|
||||
|
||||
|
||||
@autocurry
|
||||
def x_ij(L, h, xi_xj):
|
||||
x_i, x_j = xi_xj
|
||||
return (L*h + x_i + x_j)/2
|
||||
|
||||
|
||||
def soft_max2(r, eps, lo, hi, L, H, depth):
|
||||
N = sym.ceiling((hi - lo) / H)
|
||||
B = eps_to_base(eps, N)
|
||||
i = sym.Symbol("i_{}".format(depth))
|
||||
x_ij = (L*H + r.subs({t_sym: t_sym+i}) + r.subs({t_sym: t_sym+i+1}))/2
|
||||
return sym.log(sym.summation(B**x_ij, (i, lo, hi)), B)
|
||||
|
||||
|
||||
@smooth_robustness.register(stl.F)
|
||||
def _(stl, L, H, eps):
|
||||
def _(stl, L, H, eps, depth=0):
|
||||
lo, hi = stl.interval
|
||||
times = arange(lo, hi, H)
|
||||
rl, rh = smooth_robustness(stl.arg, L, H, eps=eps/2)
|
||||
los, his = zip(*[(rl.subs({t_sym: t_sym + t}),
|
||||
rh.subs({t_sym: t_sym + t})) for t in times])
|
||||
x_stars = list(map(x_ij(L, H), pairwise(his)))
|
||||
return LSE(los, eps=eps/2), soft_max(x_stars, eps=eps/2)
|
||||
rl, rh = smooth_robustness(stl.arg, L, H, eps=eps/2, depth=depth+1)
|
||||
return (soft_max2(rl, eps/2, lo, hi, L, H, depth),
|
||||
soft_max2(rh, eps/2, lo, hi, L, H, depth))
|
||||
|
||||
|
||||
@smooth_robustness.register(stl.Neg)
|
||||
def _(stl, L, H, eps):
|
||||
rl, rh = smooth_robustness(arg, L, H, eps)
|
||||
def _(stl, L, H, eps, depth=0):
|
||||
rl, rh = smooth_robustness(arg, L, H, eps, depth=depth+1)
|
||||
return -rh, -rl
|
||||
|
||||
|
||||
@smooth_robustness.register(stl.LinEq)
|
||||
def _(stl, L, H, eps):
|
||||
def _(stl, L, H, eps, depth=0):
|
||||
op = op_lookup[stl.op]
|
||||
retval = op(eval_terms(stl), stl.const)
|
||||
return retval, retval
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue