Implemented smooth_robustness's stl.F encoding
This commit is contained in:
parent
100f48a0ba
commit
8b267fa2c3
1 changed files with 15 additions and 20 deletions
|
|
@ -1,15 +1,14 @@
|
||||||
# TODO: technically incorrect on 0 robustness since conflates < and >
|
# TODO: technically incorrect on 0 robustness since conflates < and >
|
||||||
|
|
||||||
from functools import singledispatch
|
from functools import singledispatch
|
||||||
from operator import sub, add
|
|
||||||
|
|
||||||
import sympy as sym
|
import sympy as sym
|
||||||
from lenses import lens
|
|
||||||
from numpy import arange
|
from numpy import arange
|
||||||
from funcy import pairwise, autocurry
|
from funcy import pairwise, autocurry
|
||||||
|
|
||||||
import stl.ast
|
import stl.ast
|
||||||
from stl.ast import t_sym
|
from stl.ast import t_sym
|
||||||
|
from stl.robustness import op_lookup
|
||||||
|
|
||||||
@singledispatch
|
@singledispatch
|
||||||
def smooth_robustness(stl, L, h):
|
def smooth_robustness(stl, L, h):
|
||||||
|
|
@ -20,47 +19,43 @@ def smooth_robustness(stl, L, h):
|
||||||
def _(stl, L, H):
|
def _(stl, L, H):
|
||||||
raise NotImplementedError("Call canonicalization function")
|
raise NotImplementedError("Call canonicalization function")
|
||||||
|
|
||||||
def soft_max(rs):
|
def soft_max(rs, eps=0.1):
|
||||||
return sym.log(sum(sym.exp(r) for r in rs))
|
B = 10
|
||||||
|
return sym.log(sum(B**r for r in rs), B)
|
||||||
|
|
||||||
|
|
||||||
def LSE(rs):
|
def LSE(rs, eps=0.1):
|
||||||
return soft_max(rs) - sym.log(len(rs))
|
B = 10
|
||||||
|
return soft_max(rs) - sym.log(len(rs), B)
|
||||||
|
|
||||||
|
|
||||||
@smooth_robustness.register(stl.Or)
|
@smooth_robustness.register(stl.Or)
|
||||||
def _(stl, L, h):
|
def _(stl, L, h):
|
||||||
rl, rh = list(zip(
|
rl, rh = list(zip(
|
||||||
*[smooth_robustness(arg, depth) for arg in stl.args]))
|
*[smooth_robustness(arg, L, h) for arg in stl.args]))
|
||||||
return soft_max(rl), LSE(rh)
|
return soft_max(rl), LSE(rh)
|
||||||
|
|
||||||
|
|
||||||
@autocurry
|
@autocurry
|
||||||
def x_ij(L, h, x_i, x_j):
|
def x_ij(L, h, xi_xj):
|
||||||
|
x_i, x_j = xi_xj
|
||||||
return (L*h + x_i + x_j)/2
|
return (L*h + x_i + x_j)/2
|
||||||
|
|
||||||
|
|
||||||
@smooth_robustness.register(stl.F)
|
@smooth_robustness.register(stl.F)
|
||||||
def _(stl, L, H):
|
def _(stl, L, H):
|
||||||
lo, hi = stl.interval
|
lo, hi = stl.interval
|
||||||
times = arange(lo, hi, H)
|
times = arange(lo, hi, H)
|
||||||
rl, rh = smooth_robustness(stl.arg)
|
rl, rh = smooth_robustness(stl.arg, L, H)
|
||||||
los, his = zip(*[rl.subs({t_sym: t}), rh.subs({t_sym: t}) for t in times])
|
los, his = zip(*[(rl.subs({t_sym: t_sym + t}), rh.subs({t_sym: t_sym + t})) for t in times])
|
||||||
return LSE(rl), soft_max(map(x_ij(L, H), his))
|
return LSE(los), soft_max(map(x_ij(L, H), pairwise(his)))
|
||||||
|
|
||||||
|
|
||||||
@smooth_robustness.register(stl.Neg)
|
@smooth_robustness.register(stl.Neg)
|
||||||
def _(stl, L, H):
|
def _(stl, L, H):
|
||||||
rl, rh = smooth_robustness(arg)
|
rl, rh = smooth_robustness(arg, L, H)
|
||||||
return -rh, -rl
|
return -rh, -rl
|
||||||
|
|
||||||
op_lookup = {
|
|
||||||
">": sub,
|
|
||||||
">=": sub,
|
|
||||||
"<": lambda x, y: sub(y, x),
|
|
||||||
"<=": lambda x, y: sub(y, x),
|
|
||||||
"=": lambda a, b: -abs(a - b),
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
@smooth_robustness.register(stl.LinEq)
|
@smooth_robustness.register(stl.LinEq)
|
||||||
def _(stl, L, H):
|
def _(stl, L, H):
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue