import operator as op
from functools import reduce

import traces
from lenses import lens, bind

import stl.ast
from stl.ast import (And, F, G, Interval, LinEq, Neg, Or, AP_lens)
from stl.types import STL


def f_neg_or_canonical_form(phi: STL) -> STL:
    """Rewrite ``phi`` using only ``Or``, ``Neg`` and ``F`` over ``LinEq`` leaves.

    ``And`` is eliminated via De Morgan (``a & b == ~(~a | ~b)``) and ``G``
    via duality (``G phi == ~F ~phi``).

    Raises:
        NotImplementedError: if ``phi`` contains a node type other than
            ``LinEq``, ``And``, ``Or``, ``Neg``, ``F`` or ``G``.
    """
    if isinstance(phi, LinEq):
        return phi

    children = [f_neg_or_canonical_form(s) for s in phi.children]
    if isinstance(phi, (And, G)):
        # And/G are expressed as the negation of their dual (Or/F) applied
        # to negated children.
        children = [Neg(s) for s in children]
    children = tuple(children)

    if isinstance(phi, Or):
        return Or(children)
    elif isinstance(phi, And):
        return Neg(Or(children))
    elif isinstance(phi, Neg):
        return Neg(children[0])
    elif isinstance(phi, F):
        return F(phi.interval, children[0])
    elif isinstance(phi, G):
        return Neg(F(phi.interval, children[0]))
    raise NotImplementedError(
        "f_neg_or_canonical_form does not support {}".format(
            type(phi).__name__))


def _lineq_lipschitz(lineq):
    """Return the sum of absolute term coefficients of ``lineq``."""
    return sum(map(abs, bind(lineq).Each().terms.Each().coeff.collect()))


def linear_stl_lipschitz(phi):
    """Infinity-norm Lipschitz bound over the linear predicates of ``phi``.

    Returns the largest per-predicate bound (see ``_lineq_lipschitz``), or
    0.0 when ``phi`` contains no linear predicates.
    """
    return float(max(map(_lineq_lipschitz, phi.lineqs), default=0))


def inline_context(phi, context):
    """Repeatedly substitute atomic predicates of ``phi`` using ``context``.

    Each atom found in ``context`` is replaced by its mapped formula; atoms
    not in ``context`` are left unchanged. Substitution is repeated until
    ``phi`` stops changing, so ``context`` is assumed to be acyclic —
    a cyclic mapping would never reach a fixed point.
    """
    phi2 = None

    def update(ap):
        return context.get(ap, ap)

    while phi2 != phi:
        phi2, phi = phi, AP_lens(phi).modify(update)

    # TODO: this is a hack to flatten the AST (round-trips through the
    # parser). Fix!
    return stl.parse(str(phi))


# Maps a LinEq's comparison operator string to the Python comparison.
op_lookup = {
    ">": op.gt,
    ">=": op.ge,
    "<": op.lt,
    "<=": op.le,
    "=": op.eq,
}


def get_times(x):
    """Return the sorted union of measurement times of all signals in ``x``.

    ``x`` maps signal names to time series whose ``items()`` yield
    ``(time, value)`` pairs. Returns ``[]`` when ``x`` is empty.
    """
    return sorted({t for series in x.values() for t, _ in series.items()})


def eval_lineq(lineq, x, compact=True, times=None):
    """Evaluate a linear predicate over the signals in ``x``.

    Args:
        lineq: the linear predicate; its ``op`` must be a key of
            ``op_lookup``.
        x: mapping from signal name to a time series indexable by time.
        compact: if true, call ``compact()`` on the resulting series.
        times: times at which to evaluate; defaults to ``get_times(x)``.

    Returns:
        A ``traces.TimeSeries`` of booleans, one sample per evaluated time.
    """
    if times is None:
        times = get_times(x)

    terms = bind(lineq).Each().terms.Each().collect()

    def f(t):
        lhs = sum(float(term.coeff) * x[term.id.name][t] for term in terms)
        return op_lookup[lineq.op](lhs, lineq.const)

    # Evaluate at each measurement time; iterating ``x`` directly would
    # yield signal names, not times.
    output = traces.TimeSeries([(t, f(t)) for t in times])
    if compact:
        output.compact()
    return output


def eval_lineqs(phi, x, times=None):
    """Evaluate every linear predicate of ``phi``; see ``eval_lineq``.

    Returns a dict mapping each predicate to its boolean time series.
    """
    if times is None:
        times = get_times(x)
    return {lineq: eval_lineq(lineq, x, times=times) for lineq in phi.lineqs}


# EDSL


def alw(phi, *, lo, hi):
    """``G[lo, hi] phi``: ``phi`` holds throughout the interval."""
    return G(Interval(lo, hi), phi)


def env(phi, *, lo, hi):
    """``F[lo, hi] phi``: ``phi`` holds at some point in the interval."""
    return F(Interval(lo, hi), phi)


def until(phi1, phi2, *, lo, hi):
    """Build ``stl.ast.Until`` over ``Interval(lo, hi)``."""
    return stl.ast.Until(Interval(lo, hi), phi1, phi2)


def andf(*args):
    """Conjunction of ``args``; the empty conjunction is ``stl.TOP``."""
    return reduce(op.and_, args) if args else stl.TOP


def orf(*args):
    """Disjunction of ``args``; the empty disjunction is false (``~TOP``)."""
    return reduce(op.or_, args) if args else ~stl.TOP


def implies(x, y):
    """Material implication ``x -> y``, encoded as ``~x | y``."""
    return ~x | y


def xor(x, y):
    """Exclusive or: exactly one of ``x`` and ``y`` holds."""
    return (x | y) & ~(x & y)


def iff(x, y):
    """Biconditional: ``x`` and ``y`` hold together or fail together."""
    return (x & y) | (~x & ~y)