implemented smooth_robustness for constant base
This commit is contained in:
parent
5fde483116
commit
f3d118f01e
1 changed files with 90 additions and 58 deletions
|
|
@ -1,88 +1,120 @@
|
||||||
# TODO: technically incorrect on 0 robustness since conflates < and >
|
# TODO: technically incorrect on 0 robustness since conflates < and >
|
||||||
|
|
||||||
from functools import singledispatch
|
from functools import singledispatch
|
||||||
|
from collections import namedtuple
|
||||||
|
|
||||||
import sympy as sym
|
import sympy as sym
|
||||||
from numpy import arange
|
from numpy import arange
|
||||||
from funcy import pairwise
|
from funcy import pairwise
|
||||||
|
from lenses import lens
|
||||||
|
|
||||||
import stl.ast
|
import stl.ast
|
||||||
from stl.ast import t_sym
|
from stl.ast import t_sym
|
||||||
|
from stl.utils import walk
|
||||||
from stl.robustness import op_lookup
|
from stl.robustness import op_lookup
|
||||||
|
|
||||||
|
Param = namedtuple("Param", ["L", "h", "B", "id_map"])
|
||||||
|
|
||||||
@singledispatch
|
@singledispatch
|
||||||
def smooth_robustness(stl, L, h, eps, depth):
|
def node_base(_, _1, _2):
|
||||||
raise NotImplementedError
|
return sym.e
|
||||||
|
|
||||||
@smooth_robustness.register(stl.And)
|
|
||||||
@smooth_robustness.register(stl.G)
|
|
||||||
def _(stl, L, H, eps):
|
|
||||||
raise NotImplementedError("Call canonicalization function")
|
|
||||||
|
|
||||||
def eps_to_base(eps, N):
|
|
||||||
return N**(1/eps)
|
|
||||||
|
|
||||||
def soft_max(rs, eps=0.1):
|
|
||||||
N = len(rs)
|
|
||||||
B = eps_to_base(eps, N)
|
|
||||||
return sym.log(sum(B**r for r in rs), B)
|
|
||||||
|
|
||||||
|
|
||||||
def LSE(rs, eps=0.1):
|
@node_base.register(stl.ast.Or)
|
||||||
N = len(rs)
|
def node_base(_, eps, _1):
|
||||||
B = eps_to_base(eps, N)
|
return len(stl.args)**(1/eps)
|
||||||
return soft_max(rs) - sym.log(N, B)
|
|
||||||
|
|
||||||
|
|
||||||
@smooth_robustness.register(stl.Or)
|
@node_base.register(stl.ast.F)
|
||||||
def _(stl, L, h, eps, depth=0):
|
def node_base(_, eps, L):
|
||||||
rl, rh = list(zip(
|
|
||||||
*[smooth_robustness(arg, L, h, eps=eps/2, depth=depth+1)
|
|
||||||
for arg in stl.args]))
|
|
||||||
return soft_max(rl, eps=eps/2), LSE(rh, eps=eps/2)
|
|
||||||
|
|
||||||
|
|
||||||
def soft_max2(r, eps, lo, hi, L, H, depth):
|
|
||||||
N = sym.ceiling((hi - lo) / H)
|
|
||||||
B = eps_to_base(eps, N)
|
|
||||||
i = sym.Symbol("i_{}".format(depth))
|
|
||||||
x_ij = (L*H + r.subs({t_sym: t_sym+i}) + r.subs({t_sym: t_sym+i+1}))/2
|
|
||||||
return sym.log(sym.summation(B**x_ij, (i, lo, hi)), B)
|
|
||||||
|
|
||||||
def LSE2(r, eps, lo, hi, H, depth):
|
|
||||||
N = sym.ceiling((hi - lo) / H)
|
|
||||||
B = eps_to_base(eps, N)
|
|
||||||
i = sym.Symbol("i_{}".format(depth))
|
|
||||||
x_i = r.subs({t_sym: t_sym+i})
|
|
||||||
return sym.log(sym.summation(B**x_i, (i, lo, hi))/N, B)
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@smooth_robustness.register(stl.F)
|
|
||||||
def _(stl, L, H, eps, depth=0):
|
|
||||||
lo, hi = stl.interval
|
lo, hi = stl.interval
|
||||||
times = arange(lo, hi, H)
|
return sym.ceil((hi - lo)*L/eps)**(2/eps)
|
||||||
rl, rh = smooth_robustness(stl.arg, L, H, eps=eps/2, depth=depth+1)
|
|
||||||
return (LSE2(rl, eps/2, lo, hi, H, depth),
|
|
||||||
soft_max2(rh, eps/2, lo, hi, L, H, depth))
|
|
||||||
|
|
||||||
|
|
||||||
@smooth_robustness.register(stl.Neg)
|
def sample_rate(eps, L):
|
||||||
def _(stl, L, H, eps, depth=0):
|
return eps / L
|
||||||
rl, rh = smooth_robustness(arg, L, H, eps, depth=depth+1)
|
|
||||||
return -rh, -rl
|
|
||||||
|
|
||||||
|
|
||||||
@smooth_robustness.register(stl.LinEq)
|
def admissible_params(phi, eps, L):
|
||||||
def _(stl, L, H, eps, depth=0):
|
return Param(
|
||||||
op = op_lookup[stl.op]
|
L=L,
|
||||||
retval = op(eval_terms(stl), stl.const)
|
h=sample_rate(eps, L),
|
||||||
return retval, retval
|
B=max(node_base(n, eps, L) for n in walk(phi)),
|
||||||
|
id_map={n:i for i, n in enumerate(walk(phi))}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def smooth_robustness(phi, eps, L):
|
||||||
|
p = admissible_params(phi, eps, L)
|
||||||
|
lo, hi = beta(phi, p), alpha(phi, p)
|
||||||
|
return sym.log(lo, B), sym.log(hi, B)
|
||||||
|
|
||||||
|
|
||||||
|
# Alpha implementation
|
||||||
|
|
||||||
|
@singledispatch
|
||||||
|
def alpha(stl, p):
|
||||||
|
raise NotImplementedError("Call canonicalization function")
|
||||||
|
|
||||||
def eval_terms(lineq):
|
def eval_terms(lineq):
|
||||||
return sum(map(eval_term, lineq.terms))
|
return sum(map(eval_term, lineq.terms))
|
||||||
|
|
||||||
|
|
||||||
def eval_term(term):
|
def eval_term(term):
|
||||||
return term.coeff*sym.Function(term.id.name)(t_sym)
|
return term.coeff*sym.Function(term.id.name)(t_sym)
|
||||||
|
|
||||||
|
|
||||||
|
@alpha.register(stl.LinEq)
|
||||||
|
def _(phi, p):
|
||||||
|
op = op_lookup[phi.op]
|
||||||
|
B = eps_to_base(eps/depth, N)
|
||||||
|
x = op(eval_terms(phi), phi.const)
|
||||||
|
return B**x
|
||||||
|
|
||||||
|
|
||||||
|
@alpha.register(stl.Neg)
|
||||||
|
def _(phi, p):
|
||||||
|
return 1/beta(phi, p)
|
||||||
|
|
||||||
|
|
||||||
|
@alpha.register(stl.Or)
|
||||||
|
def _(phi, p):
|
||||||
|
return sum(alpha(psi, p) for psi in psi in phi.args)
|
||||||
|
|
||||||
|
|
||||||
|
def F_params(phi, p, r):
|
||||||
|
hi, lo = phi.interval
|
||||||
|
N = sym.ceiling((hi - lo) / p.h)
|
||||||
|
i = sym.Symbol("i_{}".format(p.id_map[phi]))
|
||||||
|
x = lambda k: r.subs({t_sym: t_sym+k+lo})
|
||||||
|
return N, i, x
|
||||||
|
|
||||||
|
|
||||||
|
@alpha.register(stl.F)
|
||||||
|
def _(phi, p):
|
||||||
|
N, i, x = F_params(phi, p, alpha(phi.arg, p))
|
||||||
|
x_ij = sym.sqrt(p.B**(L*h)*x(i)*x(i+1))
|
||||||
|
return sym.summation(x_ij, (i, 0, N-1))
|
||||||
|
|
||||||
|
# Beta implementation
|
||||||
|
|
||||||
|
@singledispatch
|
||||||
|
def beta(phi, p):
|
||||||
|
raise NotImplementedError("Call canonicalization function")
|
||||||
|
|
||||||
|
beta.register(stl.LinEq)(alpha)
|
||||||
|
|
||||||
|
@beta.register(stl.Neg)
|
||||||
|
def _(phi, p):
|
||||||
|
return 1/alpha(phi, p)
|
||||||
|
|
||||||
|
|
||||||
|
@beta.register(stl.Or)
|
||||||
|
def _(phi, p):
|
||||||
|
return alpha(phi)/len(phi.args)
|
||||||
|
|
||||||
|
|
||||||
|
@beta.register(stl.F)
|
||||||
|
def _(phi, p):
|
||||||
|
N, i, x = F_params(phi, p, beta(phi.arg, p))
|
||||||
|
return sym.summation(x(i), (i, 0, N))
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue