mtl-aas/stl/smooth_robustness.py
Marcell Vazquez-Chanlatte eecb2d409a factor out G and F
2016-12-09 10:57:36 -08:00

85 lines
2 KiB
Python

# TODO: technically incorrect on 0 robustness since conflates < and >
from functools import singledispatch
from operator import sub, add
import sympy as sym
from lenses import lens
import stl.ast
from stl.ast import t_sym
@singledispatch
def smooth_robustness(stl):
raise NotImplementedError
def f1(rs):
return sym.log(sum(sym.exp(r) for r in rs))
def f2(rs):
return sym.log(sum(sym.exp(r) for r in rs)/(len(rs)))
@smooth_robustness.register(stl.Or)
def _(stl, depth=0):
rl, rh = list(zip(*[smooth_robustness(arg, depth) for arg in stl.args]))
return f1(rl), f2(rh)
@smooth_robustness.register(stl.And)
def _(stl, depth=0):
rh, rl = list(zip(*[-smooth_robustness(arg, depth) for arg in stl.args]))
return -f2(rh), -f1(rl)
def F1(r, interval, t):
lo, hi = interval
bounds = (t, lo, hi)
return sym.log(sym.Integral(sym.exp(r), bounds))
def F2(r, interval, t):
lo, hi = interval
return F1(r, interval, t) - sym.log(hi - lo)
@smooth_robustness.register(stl.F)
def _(stl, depth=0):
depth += 1
t = sym.Symbol("t{}".format(depth))
rl, rh = smooth_robustness(stl.arg)
r = (rl.subs({t_sym: t}), rh.subs({t_sym: t}))
return F1(r[0], stl.interval, t), F2(rh[1], stl.interval, t)
@smooth_robustness.register(stl.G)
def _(stl, depth=0):
depth += 1
t = sym.Symbol("t{}".format(depth))
rl, rh = smooth_robustness(stl.arg)
r = (rl.subs({t_sym: t}), rh.subs({t_sym: t}))
return -F2(r[1], stl.interval, t), -F1(r[0], stl.interval, t)
@smooth_robustness.register(stl.Neg)
def _(stl, depth=0):
rl, rh = smooth_robustness(arg)
return -rh, -rl
op_lookup = {
">": sub,
">=": sub,
"<": lambda x, y: sub(y, x),
"<=": lambda x, y: sub(y, x),
"=": lambda a, b: -abs(a - b),
}
@smooth_robustness.register(stl.LinEq)
def _(stl, depth=0):
op = op_lookup[stl.op]
retval = op(eval_terms(stl), stl.const)
return retval, retval
def eval_terms(lineq):
return sum(map(eval_term, lineq.terms))
def eval_term(term):
return term.coeff*sym.Function(term.id.name)(t_sym)