factor out G and F
This commit is contained in:
parent
7064f2b4a5
commit
eecb2d409a
2 changed files with 95 additions and 11 deletions
|
|
@ -29,28 +29,27 @@ def _(stl):
|
||||||
return sat_comp
|
return sat_comp
|
||||||
|
|
||||||
|
|
||||||
@pointwise_satf.register(stl.F)
|
def temporal_op(stl, lo, hi, conjunction=False):
|
||||||
def _(stl):
|
f = bitarray.all if conjunction else bitarray.any
|
||||||
lo, hi = stl.interval
|
|
||||||
def sat_comp(x,t):
|
def sat_comp(x,t):
|
||||||
sat = bitarray()
|
sat = bitarray()
|
||||||
for tau in t:
|
for tau in t:
|
||||||
tau_t = [min(tau + t2, x.index[-1]) for t2 in x[lo:hi].index]
|
tau_t = [min(tau + t2, x.index[-1]) for t2 in x[lo:hi].index]
|
||||||
sat.append((pointwise_satf(stl.arg)(x, tau_t)).count() > 0)
|
sat.append(f(pointwise_satf(stl.arg)(x, tau_t)))
|
||||||
return sat
|
return sat
|
||||||
return sat_comp
|
return sat_comp
|
||||||
|
|
||||||
|
|
||||||
|
@pointwise_satf.register(stl.F)
|
||||||
|
def _(stl):
|
||||||
|
lo, hi = stl.interval
|
||||||
|
return temporal_op(stl, lo, hi, conjunction=False)
|
||||||
|
|
||||||
|
|
||||||
@pointwise_satf.register(stl.G)
|
@pointwise_satf.register(stl.G)
|
||||||
def _(stl):
|
def _(stl):
|
||||||
lo, hi = stl.interval
|
lo, hi = stl.interval
|
||||||
def sat_comp(x,t):
|
return temporal_op(stl, lo, hi, conjunction=True)
|
||||||
sat = bitarray()
|
|
||||||
for tau in t:
|
|
||||||
tau_t = [min(tau + t2, x.index[-1]) for t2 in x[lo:hi].index]
|
|
||||||
sat.append((~(pointwise_satf(stl.arg)(x, tau_t))).count() == 0)
|
|
||||||
return sat
|
|
||||||
return sat_comp
|
|
||||||
|
|
||||||
|
|
||||||
@pointwise_satf.register(stl.Neg)
|
@pointwise_satf.register(stl.Neg)
|
||||||
|
|
|
||||||
85
stl/smooth_robustness.py
Normal file
85
stl/smooth_robustness.py
Normal file
|
|
@ -0,0 +1,85 @@
|
||||||
|
# TODO: technically incorrect on 0 robustness since conflates < and >
|
||||||
|
|
||||||
|
from functools import singledispatch
|
||||||
|
from operator import sub, add
|
||||||
|
|
||||||
|
import sympy as sym
|
||||||
|
from lenses import lens
|
||||||
|
|
||||||
|
import stl.ast
|
||||||
|
from stl.ast import t_sym
|
||||||
|
|
||||||
|
@singledispatch
|
||||||
|
def smooth_robustness(stl):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def f1(rs):
|
||||||
|
return sym.log(sum(sym.exp(r) for r in rs))
|
||||||
|
|
||||||
|
def f2(rs):
|
||||||
|
return sym.log(sum(sym.exp(r) for r in rs)/(len(rs)))
|
||||||
|
|
||||||
|
@smooth_robustness.register(stl.Or)
|
||||||
|
def _(stl, depth=0):
|
||||||
|
rl, rh = list(zip(*[smooth_robustness(arg, depth) for arg in stl.args]))
|
||||||
|
return f1(rl), f2(rh)
|
||||||
|
|
||||||
|
@smooth_robustness.register(stl.And)
|
||||||
|
def _(stl, depth=0):
|
||||||
|
rh, rl = list(zip(*[-smooth_robustness(arg, depth) for arg in stl.args]))
|
||||||
|
return -f2(rh), -f1(rl)
|
||||||
|
|
||||||
|
|
||||||
|
def F1(r, interval, t):
|
||||||
|
lo, hi = interval
|
||||||
|
bounds = (t, lo, hi)
|
||||||
|
return sym.log(sym.Integral(sym.exp(r), bounds))
|
||||||
|
|
||||||
|
def F2(r, interval, t):
|
||||||
|
lo, hi = interval
|
||||||
|
return F1(r, interval, t) - sym.log(hi - lo)
|
||||||
|
|
||||||
|
@smooth_robustness.register(stl.F)
|
||||||
|
def _(stl, depth=0):
|
||||||
|
depth += 1
|
||||||
|
t = sym.Symbol("t{}".format(depth))
|
||||||
|
rl, rh = smooth_robustness(stl.arg)
|
||||||
|
r = (rl.subs({t_sym: t}), rh.subs({t_sym: t}))
|
||||||
|
return F1(r[0], stl.interval, t), F2(rh[1], stl.interval, t)
|
||||||
|
|
||||||
|
@smooth_robustness.register(stl.G)
|
||||||
|
def _(stl, depth=0):
|
||||||
|
depth += 1
|
||||||
|
t = sym.Symbol("t{}".format(depth))
|
||||||
|
rl, rh = smooth_robustness(stl.arg)
|
||||||
|
r = (rl.subs({t_sym: t}), rh.subs({t_sym: t}))
|
||||||
|
return -F2(r[1], stl.interval, t), -F1(r[0], stl.interval, t)
|
||||||
|
|
||||||
|
|
||||||
|
@smooth_robustness.register(stl.Neg)
|
||||||
|
def _(stl, depth=0):
|
||||||
|
rl, rh = smooth_robustness(arg)
|
||||||
|
return -rh, -rl
|
||||||
|
|
||||||
|
op_lookup = {
|
||||||
|
">": sub,
|
||||||
|
">=": sub,
|
||||||
|
"<": lambda x, y: sub(y, x),
|
||||||
|
"<=": lambda x, y: sub(y, x),
|
||||||
|
"=": lambda a, b: -abs(a - b),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@smooth_robustness.register(stl.LinEq)
|
||||||
|
def _(stl, depth=0):
|
||||||
|
op = op_lookup[stl.op]
|
||||||
|
retval = op(eval_terms(stl), stl.const)
|
||||||
|
return retval, retval
|
||||||
|
|
||||||
|
|
||||||
|
def eval_terms(lineq):
|
||||||
|
return sum(map(eval_term, lineq.terms))
|
||||||
|
|
||||||
|
|
||||||
|
def eval_term(term):
|
||||||
|
return term.coeff*sym.Function(term.id.name)(t_sym)
|
||||||
Loading…
Add table
Add a link
Reference in a new issue