mtl-aas/stl/utils.py
2017-10-31 18:29:31 -07:00

126 lines
2.7 KiB
Python

import operator as op
from functools import reduce
import traces
from lenses import lens, bind
import stl.ast
from stl.ast import (And, F, G, Interval, LinEq, Neg, Or, AP_lens)
from stl.types import STL
def f_neg_or_canonical_form(phi: STL) -> STL:
    """Rewrite *phi* so it only uses ``LinEq`` leaves and ``Neg``/``Or``/``F``.

    ``And`` is eliminated via De Morgan (``a & b`` becomes
    ``Neg(Or((Neg(a), Neg(b))))``) and ``G`` via duality with ``F``
    (``G(I, a)`` becomes ``Neg(F(I, Neg(a)))``).

    Raises:
        NotImplementedError: if *phi* (after its children have been
            normalised) is not a LinEq, Or, And, Neg, F or G node.
    """
    if isinstance(phi, LinEq):
        return phi

    # Normalise the sub-formulas first (post-order), exactly as before.
    rewritten = [f_neg_or_canonical_form(sub) for sub in phi.children]
    # And / G are expressed through negated operands, so negate them here.
    if isinstance(phi, (And, G)):
        rewritten = [Neg(sub) for sub in rewritten]
    operands = tuple(rewritten)

    if isinstance(phi, Or):
        return Or(operands)
    if isinstance(phi, And):
        return Neg(Or(operands))
    if isinstance(phi, Neg):
        return Neg(operands[0])
    if isinstance(phi, F):
        return F(phi.interval, operands[0])
    if isinstance(phi, G):
        return Neg(F(phi.interval, operands[0]))
    raise NotImplementedError
def _lineq_lipschitz(lineq):
return sum(map(abs, bind(lineq).Each().terms.Each().coeff.collect()))
def linear_stl_lipschitz(phi):
    """Return an infinity-norm Lipschitz bound for the linear predicates of *phi*.

    The bound is the largest ``_lineq_lipschitz`` value over ``phi.lineqs``,
    converted to ``float``.

    Raises:
        ValueError: if ``phi.lineqs`` is empty (``max`` of nothing).
    """
    per_predicate = [_lineq_lipschitz(lineq) for lineq in phi.lineqs]
    return float(max(per_predicate))
def inline_context(phi, context):
    """Substitute atomic propositions of *phi* from *context* until stable.

    Every atomic proposition reached by ``AP_lens`` is replaced by
    ``context[ap]`` when present. Passes repeat until one leaves the formula
    unchanged; the result is then re-parsed from its string form.

    Note: a cyclic *context* (e.g. ``a -> b`` and ``b -> a``) never reaches
    a fixed point, so this loop may not terminate in that case.
    """
    def substitute(ap):
        return context.get(ap, ap)

    previous, current = None, phi
    while previous != current:
        previous = current
        current = AP_lens(current).modify(substitute)
    # TODO: re-parsing the printed formula is a hack to flatten the AST;
    # replace it with a direct flattening pass.
    return stl.parse(str(current))
# Comparison symbol (as stored in a predicate's ``op`` field) -> binary
# predicate from the ``operator`` module.
op_lookup = dict([
    (">", op.gt),
    (">=", op.ge),
    ("<", op.lt),
    ("<=", op.le),
    ("=", op.eq),
])
def get_times(x):
    """Return the sorted union of sample times across all signals in *x*.

    Args:
        x: mapping from signal name to a series whose ``items()`` yields
            ``(time, value)`` pairs.

    Returns:
        A sorted list of distinct times; ``[]`` when *x* has no signals.
    """
    # ``set.union(*gen)`` raises TypeError when the generator is empty, so
    # accumulate into one set instead.
    times = set()
    for series in x.values():
        times.update(t for t, _ in series.items())
    return sorted(times)
def eval_lineq(lineq, x, compact=True, times=None):
    """Evaluate a linear inequality predicate pointwise over a trace.

    Args:
        lineq: predicate with ``terms`` (each having ``coeff`` and
            ``id.name``), a comparison symbol ``op`` (a key of
            ``op_lookup``) and a right-hand side ``const``.
        x: mapping from signal name to a time series indexable by time.
        compact: if true, call ``compact()`` on the resulting series.
        times: times at which to evaluate; defaults to ``get_times(x)``.
            Added so ``eval_lineqs`` can share one set of times.

    Returns:
        A ``traces.TimeSeries`` mapping each evaluation time to the boolean
        result of ``sum(coeff * x[name][t]) <op> const``.

    Raises:
        KeyError: if ``lineq.op`` is not a key of ``op_lookup``.
    """
    if times is None:
        times = get_times(x)

    terms = tuple(lineq.terms)
    compare = op_lookup[lineq.op]

    def holds_at(t):
        lhs = sum(float(term.coeff) * x[term.id.name][t] for term in terms)
        return compare(lhs, lineq.const)

    # TimeSeries expects (time, value) pairs. Iterating ``x`` directly would
    # walk the signal names rather than the evaluation times.
    samples = [(t, holds_at(t)) for t in times]
    # Only forward a domain when the input trace actually carries one.
    domain = getattr(x, "domain", None)
    if domain is None:
        output = traces.TimeSeries(samples)
    else:
        output = traces.TimeSeries(samples, domain=domain)
    if compact:
        output.compact()
    return output
def eval_lineqs(phi, x, times=None):
    """Evaluate every linear predicate of *phi* over trace *x*.

    Args:
        phi: formula exposing its linear predicates as ``phi.lineqs``.
        x: mapping from signal name to time series.
        times: evaluation times shared by all predicates; defaults to
            ``get_times(x)``.

    Returns:
        A dict mapping each predicate to the result of ``eval_lineq``.
    """
    shared_times = get_times(x) if times is None else times
    results = {}
    for lineq in phi.lineqs:
        results[lineq] = eval_lineq(lineq, x, times=shared_times)
    return results
# EDSL: small constructors and boolean combinators for building STL formulas.
def alw(phi, *, lo, hi):
    """Build ``G(Interval(lo, hi), phi)``: *phi* holds throughout [lo, hi]."""
    window = Interval(lo, hi)
    return G(window, phi)
def env(phi, *, lo, hi):
    """Build ``F(Interval(lo, hi), phi)``: *phi* holds somewhere in [lo, hi]."""
    window = Interval(lo, hi)
    return F(window, phi)
def until(phi1, phi2, *, lo, hi):
    """Build ``stl.ast.Until(Interval(lo, hi), phi1, phi2)``."""
    window = Interval(lo, hi)
    return stl.ast.Until(window, phi1, phi2)
def andf(*args):
    """Conjoin all arguments left-to-right with ``&``.

    With no arguments, return ``stl.TOP`` (the identity of conjunction).
    """
    if not args:
        return stl.TOP
    return reduce(op.and_, args)
def orf(*args):
    """Disjoin all arguments left-to-right with ``|``.

    With no arguments, return the negation of ``stl.TOP`` (i.e. false), the
    identity of disjunction. This mirrors ``andf``, which returns ``stl.TOP``
    for an empty conjunction. Previously this returned ``stl.TOP``, which made
    an empty disjunction vacuously true.
    """
    if not args:
        return ~stl.TOP
    return reduce(op.or_, args)
def implies(x, y):
    """Material implication ``x -> y``, encoded as ``~x | y``."""
    negated_premise = ~x
    return negated_premise | y
def xor(x, y):
    """Exclusive or, encoded as ``(x | y) & ~(x & y)``."""
    either = x | y
    both = x & y
    return either & ~both
def iff(x, y):
    """Biconditional, encoded as ``(x & y) | (~x & ~y)``."""
    both = x & y
    neither = ~x & ~y
    return both | neither