diff --git a/stl/fastboolean_eval.py b/stl/fastboolean_eval.py index 89729ed..4ec1c2e 100644 --- a/stl/fastboolean_eval.py +++ b/stl/fastboolean_eval.py @@ -29,28 +29,27 @@ def _(stl): return sat_comp -@pointwise_satf.register(stl.F) -def _(stl): - lo, hi = stl.interval +def temporal_op(stl, lo, hi, conjunction=False): + f = bitarray.all if conjunction else bitarray.any def sat_comp(x,t): sat = bitarray() for tau in t: tau_t = [min(tau + t2, x.index[-1]) for t2 in x[lo:hi].index] - sat.append((pointwise_satf(stl.arg)(x, tau_t)).count() > 0) + sat.append(f(pointwise_satf(stl.arg)(x, tau_t))) return sat return sat_comp +@pointwise_satf.register(stl.F) +def _(stl): + lo, hi = stl.interval + return temporal_op(stl, lo, hi, conjunction=False) + + @pointwise_satf.register(stl.G) def _(stl): lo, hi = stl.interval - def sat_comp(x,t): - sat = bitarray() - for tau in t: - tau_t = [min(tau + t2, x.index[-1]) for t2 in x[lo:hi].index] - sat.append((~(pointwise_satf(stl.arg)(x, tau_t))).count() == 0) - return sat - return sat_comp + return temporal_op(stl, lo, hi, conjunction=True) @pointwise_satf.register(stl.Neg) diff --git a/stl/smooth_robustness.py b/stl/smooth_robustness.py new file mode 100644 index 0000000..a5656a6 --- /dev/null +++ b/stl/smooth_robustness.py @@ -0,0 +1,85 @@ +# TODO: technically incorrect on 0 robustness since conflates < and > + +from functools import singledispatch +from operator import sub, add + +import sympy as sym +from lenses import lens + +import stl.ast +from stl.ast import t_sym + +@singledispatch +def smooth_robustness(stl): + raise NotImplementedError + +def f1(rs): + return sym.log(sum(sym.exp(r) for r in rs)) + +def f2(rs): + return sym.log(sum(sym.exp(r) for r in rs)/(len(rs))) + +@smooth_robustness.register(stl.Or) +def _(stl, depth=0): + rl, rh = list(zip(*[smooth_robustness(arg, depth) for arg in stl.args])) + return f1(rl), f2(rh) + +@smooth_robustness.register(stl.And) +def _(stl, depth=0): + rh, rl = list(zip(*[-smooth_robustness(arg, depth) for arg in stl.args])) + return -f2(rh), -f1(rl) + + +def F1(r, interval, t): + lo, hi = interval + bounds = (t, lo, hi) + return sym.log(sym.Integral(sym.exp(r), bounds)) + +def F2(r, interval, t): + lo, hi = interval + return F1(r, interval, t) - sym.log(hi - lo) + +@smooth_robustness.register(stl.F) +def _(stl, depth=0): + depth += 1 + t = sym.Symbol("t{}".format(depth)) + rl, rh = smooth_robustness(stl.arg) + r = (rl.subs({t_sym: t}), rh.subs({t_sym: t})) + return F1(r[0], stl.interval, t), F2(rh[1], stl.interval, t) + +@smooth_robustness.register(stl.G) +def _(stl, depth=0): + depth += 1 + t = sym.Symbol("t{}".format(depth)) + rl, rh = smooth_robustness(stl.arg) + r = (rl.subs({t_sym: t}), rh.subs({t_sym: t})) + return -F2(r[1], stl.interval, t), -F1(r[0], stl.interval, t) + + +@smooth_robustness.register(stl.Neg) +def _(stl, depth=0): + rl, rh = smooth_robustness(arg) + return -rh, -rl + +op_lookup = { + ">": sub, + ">=": sub, + "<": lambda x, y: sub(y, x), + "<=": lambda x, y: sub(y, x), + "=": lambda a, b: -abs(a - b), +} + + +@smooth_robustness.register(stl.LinEq) +def _(stl, depth=0): + op = op_lookup[stl.op] + retval = op(eval_terms(stl), stl.const) + return retval, retval + + +def eval_terms(lineq): + return sum(map(eval_term, lineq.terms)) + + +def eval_term(term): + return term.coeff*sym.Function(term.id.name)(t_sym)