diff --git a/stl/smooth_robustness.py b/stl/smooth_robustness.py index 933efa0..6fcf669 100644 --- a/stl/smooth_robustness.py +++ b/stl/smooth_robustness.py @@ -1,88 +1,120 @@ # TODO: technically incorrect on 0 robustness since conflates < and > from functools import singledispatch +from collections import namedtuple import sympy as sym from numpy import arange from funcy import pairwise +from lenses import lens import stl.ast from stl.ast import t_sym +from stl.utils import walk from stl.robustness import op_lookup +Param = namedtuple("Param", ["L", "h", "B", "id_map"]) + @singledispatch -def smooth_robustness(stl, L, h, eps, depth): - raise NotImplementedError - -@smooth_robustness.register(stl.And) -@smooth_robustness.register(stl.G) -def _(stl, L, H, eps): - raise NotImplementedError("Call canonicalization function") - -def eps_to_base(eps, N): - return N**(1/eps) - -def soft_max(rs, eps=0.1): - N = len(rs) - B = eps_to_base(eps, N) - return sym.log(sum(B**r for r in rs), B) +def node_base(_, _1, _2): + return sym.e -def LSE(rs, eps=0.1): - N = len(rs) - B = eps_to_base(eps, N) - return soft_max(rs) - sym.log(N, B) +@node_base.register(stl.ast.Or) +def node_base(_, eps, _1): + return len(stl.args)**(1/eps) -@smooth_robustness.register(stl.Or) -def _(stl, L, h, eps, depth=0): - rl, rh = list(zip( - *[smooth_robustness(arg, L, h, eps=eps/2, depth=depth+1) - for arg in stl.args])) - return soft_max(rl, eps=eps/2), LSE(rh, eps=eps/2) - - -def soft_max2(r, eps, lo, hi, L, H, depth): - N = sym.ceiling((hi - lo) / H) - B = eps_to_base(eps, N) - i = sym.Symbol("i_{}".format(depth)) - x_ij = (L*H + r.subs({t_sym: t_sym+i}) + r.subs({t_sym: t_sym+i+1}))/2 - return sym.log(sym.summation(B**x_ij, (i, lo, hi)), B) - -def LSE2(r, eps, lo, hi, H, depth): - N = sym.ceiling((hi - lo) / H) - B = eps_to_base(eps, N) - i = sym.Symbol("i_{}".format(depth)) - x_i = r.subs({t_sym: t_sym+i}) - return sym.log(sym.summation(B**x_i, (i, lo, hi))/N, B) - - - -@smooth_robustness.register(stl.F) -def _(stl, L, H, eps, depth=0): +@node_base.register(stl.ast.F) +def node_base(_, eps, L): lo, hi = stl.interval - times = arange(lo, hi, H) - rl, rh = smooth_robustness(stl.arg, L, H, eps=eps/2, depth=depth+1) - return (LSE2(rl, eps/2, lo, hi, H, depth), - soft_max2(rh, eps/2, lo, hi, L, H, depth)) + return sym.ceil((hi - lo)*L/eps)**(2/eps) -@smooth_robustness.register(stl.Neg) -def _(stl, L, H, eps, depth=0): - rl, rh = smooth_robustness(arg, L, H, eps, depth=depth+1) - return -rh, -rl +def sample_rate(eps, L): + return eps / L -@smooth_robustness.register(stl.LinEq) -def _(stl, L, H, eps, depth=0): - op = op_lookup[stl.op] - retval = op(eval_terms(stl), stl.const) - return retval, retval +def admissible_params(phi, eps, L): + return Param( + L=L, + h=sample_rate(eps, L), + B=max(node_base(n, eps, L) for n in walk(phi)), + id_map={n:i for i, n in enumerate(walk(phi))} + ) +def smooth_robustness(phi, eps, L): + p = admissible_params(phi, eps, L) + lo, hi = beta(phi, p), alpha(phi, p) + return sym.log(lo, B), sym.log(hi, B) + + +# Alpha implementation + +@singledispatch +def alpha(stl, p): + raise NotImplementedError("Call canonicalization function") + def eval_terms(lineq): return sum(map(eval_term, lineq.terms)) def eval_term(term): return term.coeff*sym.Function(term.id.name)(t_sym) + + +@alpha.register(stl.LinEq) +def _(phi, p): + op = op_lookup[phi.op] + B = eps_to_base(eps/depth, N) + x = op(eval_terms(phi), phi.const) + return B**x + + +@alpha.register(stl.Neg) +def _(phi, p): + return 1/beta(phi, p) + + +@alpha.register(stl.Or) +def _(phi, p): + return sum(alpha(psi, p) for psi in psi in phi.args) + + +def F_params(phi, p, r): + hi, lo = phi.interval + N = sym.ceiling((hi - lo) / p.h) + i = sym.Symbol("i_{}".format(p.id_map[phi])) + x = lambda k: r.subs({t_sym: t_sym+k+lo}) + return N, i, x + + +@alpha.register(stl.F) +def _(phi, p): + N, i, x = F_params(phi, p, alpha(phi.arg, p)) + x_ij = sym.sqrt(p.B**(L*h)*x(i)*x(i+1)) + return sym.summation(x_ij, (i, 0, N-1)) + +# Beta implementation + +@singledispatch +def beta(phi, p): + raise NotImplementedError("Call canonicalization function") + +beta.register(stl.LinEq)(alpha) + +@beta.register(stl.Neg) +def _(phi, p): + return 1/alpha(phi, p) + + +@beta.register(stl.Or) +def _(phi, p): + return alpha(phi)/len(phi.args) + + +@beta.register(stl.F) +def _(phi, p): + N, i, x = F_params(phi, p, beta(phi.arg, p)) + return sym.summation(x(i), (i, 0, N))