126 lines
2.7 KiB
Python
126 lines
2.7 KiB
Python
import operator as op
|
|
from functools import reduce
|
|
|
|
import traces
|
|
from lenses import lens, bind
|
|
|
|
import stl.ast
|
|
from stl.ast import (And, F, G, Interval, LinEq, Neg, Or, AP_lens)
|
|
from stl.types import STL
|
|
|
|
|
|
def f_neg_or_canonical_form(phi: STL) -> STL:
    """Rewrite ``phi`` so that only Or, Neg and F connectives remain.

    Subformulas are canonicalized bottom-up. And is eliminated through
    De Morgan (``a & b  ==  ~(~a | ~b)``) and G through its dual
    (``G[I] p  ==  ~F[I] ~p``). LinEq leaves are returned unchanged.

    Raises:
        NotImplementedError: for any node that is not a LinEq, Or, And,
            Neg, F or G (raised after its children have been rewritten).
    """
    if isinstance(phi, LinEq):
        return phi

    subs = [f_neg_or_canonical_form(child) for child in phi.children]

    # And / G are expressed via their duals, which take negated operands
    # and whose result is negated again below.
    dualize = isinstance(phi, (And, G))
    if dualize:
        subs = [Neg(sub) for sub in subs]
    subs = tuple(subs)

    if isinstance(phi, (Or, And)):
        disjunction = Or(subs)
        return Neg(disjunction) if dualize else disjunction
    if isinstance(phi, Neg):
        return Neg(subs[0])
    if isinstance(phi, (F, G)):
        eventually = F(phi.interval, subs[0])
        return Neg(eventually) if dualize else eventually
    raise NotImplementedError
|
|
|
|
|
|
def _lineq_lipschitz(lineq):
    """Return the sum of absolute coefficients of a single linear inequality.

    This is the Lipschitz constant of ``sum(coeff_i * x_i)`` with respect to
    the infinity norm on the signal values.

    Args:
        lineq: a LinEq whose ``terms`` each expose a numeric ``coeff``.

    Returns:
        ``sum(abs(term.coeff) for term in lineq.terms)`` (0 for no terms).
    """
    # Read the terms directly: a traversal over the LinEq itself would visit
    # its fields (terms, op, const) rather than its terms.
    return sum(abs(term.coeff) for term in lineq.terms)
|
|
|
|
|
|
def linear_stl_lipschitz(phi):
    """Infinity-norm Lipschitz bound over the linear predicates of ``phi``.

    Each linear inequality contributes the sum of its absolute coefficients
    (see ``_lineq_lipschitz``); the bound is the largest such value.

    Args:
        phi: an STL formula exposing ``lineqs`` (its linear inequalities).

    Returns:
        The bound as a float, or ``0.0`` when ``phi`` has no linear
        inequality predicates (previously this raised ValueError from an
        empty ``max``).
    """
    return float(max(map(_lineq_lipschitz, phi.lineqs), default=0))
|
|
|
|
|
|
def inline_context(phi, context):
    """Expand ``phi`` by substituting definitions from ``context`` to a fixed point.

    Each leaf focused by ``AP_lens`` (presumably the atomic propositions) that
    is a key of ``context`` is replaced by its mapped value. Passes repeat
    until one leaves the formula unchanged, so definitions may refer to other
    definitions. NOTE: this never terminates if substitution has no fixed
    point (e.g. a cyclic context such as ``{a: b, b: a}``).

    Returns:
        The expanded formula, re-parsed from its string form.
    """
    def lookup(ap):
        # Leaves without an entry in ``context`` are kept as-is.
        return context.get(ap, ap)

    previous = None
    while previous != phi:
        previous = phi
        phi = AP_lens(phi).modify(lookup)

    # TODO: this is a hack to flatten the AST. Fix!
    # Round-tripping through the parser re-flattens the substituted tree.
    return stl.parse(str(phi))
|
|
|
|
|
|
# Maps the comparison symbol stored in ``LinEq.op`` to the corresponding
# binary predicate from the ``operator`` module.
op_lookup = dict([
    (">", op.gt),
    (">=", op.ge),
    ("<", op.lt),
    ("<=", op.le),
    ("=", op.eq),
])
|
|
|
|
|
|
def get_times(x):
    """Return the sorted union of sample times across all signals in ``x``.

    Args:
        x: mapping from signal name to a series whose ``items()`` yields
            ``(time, value)`` pairs (e.g. a ``traces.TimeSeries``).

    Returns:
        A sorted list of distinct times; empty when ``x`` has no signals.
    """
    # A flat set comprehension replaces ``set.union(*...)``, which raises
    # TypeError when it receives zero sets (i.e. when ``x`` is empty).
    return sorted({t for series in x.values() for t, _ in series.items()})
|
|
|
|
|
|
def eval_lineq(lineq, x, compact=True, *, times=None):
    """Evaluate a linear inequality pointwise against the signals in ``x``.

    Args:
        lineq: a LinEq whose ``terms`` each have a numeric ``coeff`` and an
            ``id`` with a ``name``; ``lineq.op`` must be a key of
            ``op_lookup``.
        x: mapping from signal name to a series indexable by time.
        compact: if true, call ``output.compact()`` before returning, which
            drops repeated consecutive values.
        times: the times at which to evaluate. Defaults to ``get_times(x)``,
            the union of the sample times of every signal.

    Returns:
        A ``traces.TimeSeries`` mapping each time ``t`` in ``times`` to the
        boolean ``sum(coeff * x[name][t]) <op> const``.

    Raises:
        KeyError: if ``lineq.op`` is not in ``op_lookup`` or a term names a
            signal missing from ``x``.
    """
    if times is None:
        times = get_times(x)

    # Materialize once: the terms are re-walked for every sample time.
    terms = tuple(lineq.terms)
    compare = op_lookup[lineq.op]

    def f(t):
        lhs = sum(float(term.coeff) * x[term.id.name][t] for term in terms)
        return compare(lhs, lineq.const)

    # Pair each evaluation time with its truth value. Iterating ``x`` itself
    # would yield signal names, not times.
    output = traces.TimeSeries(zip(times, map(f, times)))

    if compact:
        output.compact()
    return output
|
|
|
|
|
|
def eval_lineqs(phi, x, times=None):
    """Evaluate every linear inequality in ``phi`` against the signals ``x``.

    ``times`` defaults to ``get_times(x)``. The result maps each element of
    ``phi.lineqs`` to ``eval_lineq(lineq, x, times=times)``.
    """
    sample_times = get_times(x) if times is None else times
    result = {}
    for lineq in phi.lineqs:
        result[lineq] = eval_lineq(lineq, x, times=sample_times)
    return result
|
|
|
|
|
|
# EDSL
|
|
|
|
|
|
def alw(phi, *, lo, hi):
    """Build the "always" formula ``G[lo, hi] phi``."""
    window = Interval(lo, hi)
    return G(window, phi)
|
|
|
|
|
|
def env(phi, *, lo, hi):
    """Build the "eventually" formula ``F[lo, hi] phi``."""
    window = Interval(lo, hi)
    return F(window, phi)
|
|
|
|
|
|
def until(phi1, phi2, *, lo, hi):
    """Build ``stl.ast.Until`` over the window [lo, hi] with operands phi1, phi2."""
    window = Interval(lo, hi)
    return stl.ast.Until(window, phi1, phi2)
|
|
|
|
|
|
def andf(*args):
    """Conjoin all arguments left-to-right with ``&``.

    With no arguments, returns ``stl.TOP`` (the empty conjunction is true).
    """
    if not args:
        return stl.TOP
    return reduce(op.and_, args)
|
|
|
|
|
|
def orf(*args):
    """Disjoin all arguments left-to-right with ``|``.

    With no arguments, returns ``stl.ast.BOT``. The empty disjunction is
    false (the identity of ``|``), just as ``andf()`` returns ``stl.TOP`` for
    the empty conjunction. Previously this returned ``stl.TOP``, making
    ``orf()`` vacuously true.
    """
    return reduce(op.or_, args) if args else stl.ast.BOT
|
|
|
|
|
|
def implies(x, y):
    """Material implication ``x -> y``, expressed as ``~x | y``."""
    not_x = ~x
    return not_x | y
|
|
|
|
|
|
def xor(x, y):
    """Exclusive or: at least one of ``x``, ``y`` holds, but not both."""
    either = x | y
    both = x & y
    return either & ~both
|
|
|
|
|
|
def iff(x, y):
    """Biconditional: ``x`` and ``y`` both hold, or neither does."""
    both = x & y
    neither = ~x & ~y
    return both | neither
|