# TODO: figure out how to deduplicate this with robustness # - Abstract as working on distributive lattice import operator as op from collections import defaultdict from functools import reduce, singledispatch import funcy as fn from discrete_signals import signal from mtl import ast OO = float('inf') CONST_FALSE = signal([(0, -1)], start=-OO, end=OO, tag=ast.BOT) CONST_TRUE = signal([(0, 1)], start=-OO, end=OO, tag=ast.TOP) def to_signal(ts_mapping): start = min(fn.pluck(0, fn.cat(ts_mapping.values()))) assert start >= 0 signals = (signal(v, start, OO, tag=k) for k, v in ts_mapping.items()) return reduce(op.or_, signals) def interp(sig, t, tag=None): # TODO: return function that interpolates the whole signal. sig = sig.project({tag}) idx = sig.data.bisect_right(t) - 1 if idx < 0: return None else: key = sig.data.keys()[idx] return sig[key][tag] def dense_compose(sig1, sig2, init=None): sig12 = sig1 | sig2 tags = sig12.tags def _dense_compose(): state = {tag: init for tag in tags} for t, val in sig12.items(): state = {k: val.get(k, state[k]) for k in tags} yield t, state data = list(_dense_compose()) return sig12.evolve(data=data) def booleanize_signal(sig): return sig.transform(lambda mapping: defaultdict( lambda: None, {k: 2*int(v) - 1 for k, v in mapping.items()} )) def pointwise_sat(phi, dt=0.1): f = eval_mtl(phi, dt) def _eval_mtl(x, t=0, quantitative=False): sig = to_signal(x) if not quantitative: sig = booleanize_signal(sig) if t is None: res = [(t, v[phi]) for t, v in f(sig).items()] return res if quantitative else [(t, v > 0) for t, v in res] if t is False: t = f(sig).items()[0][0] res = interp(f(sig), t, phi) return res if quantitative else res > 0 return _eval_mtl @singledispatch def eval_mtl(phi, dt): raise NotImplementedError @eval_mtl.register(ast.And) def eval_mtl_and(phi, dt): fs = [eval_mtl(arg, dt) for arg in phi.args] def _eval(x): sigs = [f(x) for f in fs] sig = reduce(lambda x, y: dense_compose(x, y, init=OO), sigs) return sig.map(lambda v: min(v.values()), tag=phi) return _eval def apply_weak_until(left_key, right_key, sig): prev, max_right = OO, -OO for t in reversed(sig.times()): left, right = interp(sig, t, left_key), interp(sig, t, right_key) max_right = max(max_right, right) prev = max(right, min(left, prev), -max_right) yield (t, prev) @eval_mtl.register(ast.WeakUntil) def eval_mtl_until(phi, dt): f1, f2 = eval_mtl(phi.arg1, dt), eval_mtl(phi.arg2, dt) def _eval(x): sig = dense_compose(f1(x), f2(x), init=-OO) data = apply_weak_until(phi.arg1, phi.arg2, sig) return signal(data, x.start, OO, tag=phi) return _eval @eval_mtl.register(ast.G) def eval_mtl_g(phi, dt): f = eval_mtl(phi.arg, dt) a, b = phi.interval if b < a: return lambda x: CONST_TRUE.retag({ast.TOP: phi}) def _min(val): return min(val[phi.arg]) def _eval(x): return f(x).rolling(a, b).map(_min, tag=phi) return _eval @eval_mtl.register(ast.Neg) def eval_mtl_neg(phi, dt): f = eval_mtl(phi.arg, dt) def _eval(x): return f(x).map(lambda v: -v[phi.arg], tag=phi) return _eval @eval_mtl.register(ast.Next) def eval_mtl_next(phi, dt): f = eval_mtl(phi.arg, dt) def _eval(x): return (f(x) << dt).retag({phi.arg: phi}) return _eval @eval_mtl.register(ast.AtomicPred) def eval_mtl_ap(phi, _): def _eval(x): return x.project({phi.id}).retag({phi.id: phi}) return _eval @eval_mtl.register(type(ast.BOT)) def eval_mtl_bot(_, _1): return lambda x: CONST_FALSE