Implemented smooth_robustness's stl.F encoding

This commit is contained in:
Marcell Vazquez-Chanlatte 2016-12-10 13:38:41 -08:00
parent 100f48a0ba
commit 8b267fa2c3

View file

@ -1,15 +1,14 @@
# TODO: technically incorrect on 0 robustness since conflates < and >
from functools import singledispatch
from operator import sub, add
import sympy as sym
from lenses import lens
from numpy import arange
from funcy import pairwise, autocurry
import stl.ast
from stl.ast import t_sym
from stl.robustness import op_lookup
@singledispatch
def smooth_robustness(stl, L, h):
@ -20,47 +19,43 @@ def smooth_robustness(stl, L, h):
def _(stl, L, H):
raise NotImplementedError("Call canonicalization function")
def soft_max(rs):
return sym.log(sum(sym.exp(r) for r in rs))
def soft_max(rs, eps=0.1):
B = 10
return sym.log(sum(B**r for r in rs), B)
def LSE(rs):
return soft_max(rs) - sym.log(len(rs))
def LSE(rs, eps=0.1):
B = 10
return soft_max(rs) - sym.log(len(rs), B)
@smooth_robustness.register(stl.Or)
def _(stl, L, h):
rl, rh = list(zip(
*[smooth_robustness(arg, depth) for arg in stl.args]))
*[smooth_robustness(arg, L, h) for arg in stl.args]))
return soft_max(rl), LSE(rh)
@autocurry
def x_ij(L, h, x_i, x_j):
def x_ij(L, h, xi_xj):
x_i, x_j = xi_xj
return (L*h + x_i + x_j)/2
@smooth_robustness.register(stl.F)
def _(stl, L, H):
lo, hi = stl.interval
times = arange(lo, hi, H)
rl, rh = smooth_robustness(stl.arg)
los, his = zip(*[rl.subs({t_sym: t}), rh.subs({t_sym: t}) for t in times])
return LSE(rl), soft_max(map(x_ij(L, H), his))
rl, rh = smooth_robustness(stl.arg, L, H)
los, his = zip(*[(rl.subs({t_sym: t_sym + t}), rh.subs({t_sym: t_sym + t})) for t in times])
return LSE(los), soft_max(map(x_ij(L, H), pairwise(his)))
@smooth_robustness.register(stl.Neg)
def _(stl, L, H):
rl, rh = smooth_robustness(arg)
rl, rh = smooth_robustness(arg, L, H)
return -rh, -rl
op_lookup = {
">": sub,
">=": sub,
"<": lambda x, y: sub(y, x),
"<=": lambda x, y: sub(y, x),
"=": lambda a, b: -abs(a - b),
}
@smooth_robustness.register(stl.LinEq)
def _(stl, L, H):