precompute lineq timeseries during boolean evaluation

This commit is contained in:
Marcell Vazquez-Chanlatte 2017-04-24 20:17:27 -07:00
parent 8f5035a9e3
commit 00ec325589
3 changed files with 72 additions and 24 deletions

View file

@ -8,23 +8,32 @@ import funcy as fn
from lenses import lens from lenses import lens
import stl.ast import stl.ast
import stl
oo = float('inf') oo = float('inf')
def pointwise_sat(phi):
ap_names = [z.id.name for z in stl.utils.AP_lens(phi).get_all()]
def _eval_stl(x, t):
evaluated = stl.utils.eval_lineqs(phi, x)
evaluated.update(fn.project(x, ap_names))
return eval_stl(phi)(evaluated, t)
return _eval_stl
@singledispatch @singledispatch
def pointwise_sat(stl): def eval_stl(stl):
raise NotImplementedError raise NotImplementedError
@pointwise_sat.register(stl.Or) @eval_stl.register(stl.Or)
def _(stl): def _(stl):
fs = [pointwise_sat(arg) for arg in stl.args] fs = [eval_stl(arg) for arg in stl.args]
return lambda x, t: any(f(x, t) for f in fs) return lambda x, t: any(f(x, t) for f in fs)
@pointwise_sat.register(stl.And) @eval_stl.register(stl.And)
def _(stl): def _(stl):
fs = [pointwise_sat(arg) for arg in stl.args] fs = [eval_stl(arg) for arg in stl.args]
return lambda x, t: all(f(x, t) for f in fs) return lambda x, t: all(f(x, t) for f in fs)
@ -33,7 +42,10 @@ def get_times(x, tau, lo=None, hi=None):
lo = min(v.first()[0] for v in x.values()) lo = min(v.first()[0] for v in x.values())
if hi is None or hi is oo: if hi is None or hi is oo:
hi = max(v.last()[0] for v in x.values()) hi = max(v.last()[0] for v in x.values())
end = min(v.domain.end() for v in x.values()) try:
end = min(v.domain.end() for v in x.values())
except:
import pdb; pdb.set_trace()
hi = hi + tau if hi + tau <= end else end hi = hi + tau if hi + tau <= end else end
lo = lo + tau if lo + tau <= end else end lo = lo + tau if lo + tau <= end else end
@ -46,10 +58,10 @@ def get_times(x, tau, lo=None, hi=None):
return sorted(set(fn.pluck(0, all_times))) return sorted(set(fn.pluck(0, all_times)))
@pointwise_sat.register(stl.Until) @eval_stl.register(stl.Until)
def _(stl): def _(stl):
def _until(x, t): def _until(x, t):
f1, f2 = pointwise_sat(stl.arg1), pointwise_sat(stl.arg2) f1, f2 = eval_stl(stl.arg1), eval_stl(stl.arg2)
for tau in get_times(x, t): for tau in get_times(x, t):
if not f1(x, tau): if not f1(x, tau):
return f2(x, tau) return f2(x, tau)
@ -65,23 +77,23 @@ def eval_unary_temporal_op(phi, always=True):
return lambda x, t: retval return lambda x, t: retval
if hi == lo: if hi == lo:
return lambda x, t: f(x, t) return lambda x, t: f(x, t)
f = pointwise_sat(phi.arg) f = eval_stl(phi.arg)
return lambda x, t: fold(f(x, tau) for tau in get_times(x, t, lo, hi)) return lambda x, t: fold(f(x, tau) for tau in get_times(x, t, lo, hi))
@pointwise_sat.register(stl.F) @eval_stl.register(stl.F)
def _(phi): def _(phi):
return eval_unary_temporal_op(phi, always=False) return eval_unary_temporal_op(phi, always=False)
@pointwise_sat.register(stl.G) @eval_stl.register(stl.G)
def _(phi): def _(phi):
return eval_unary_temporal_op(phi, always=True) return eval_unary_temporal_op(phi, always=True)
@pointwise_sat.register(stl.Neg) @eval_stl.register(stl.Neg)
def _(stl): def _(stl):
f = pointwise_sat(stl.arg) f = eval_stl(stl.arg)
return lambda x, t: not f(x, t) return lambda x, t: not f(x, t)
@ -94,22 +106,20 @@ op_lookup = {
} }
@pointwise_sat.register(stl.AtomicPred) @eval_stl.register(stl.AtomicPred)
def _(stl): def _(stl):
return lambda x, t: x[str(stl.id)][t] return lambda x, t: x[str(stl.id)][t]
@pointwise_sat.register(stl.LinEq) @eval_stl.register(stl.LinEq)
def _(stl): def _(lineq):
op = op_lookup[stl.op] return lambda x, t: x[lineq][t]
return lambda x, t: op(eval_terms(stl, x, t), stl.const)
def eval_terms(lineq, x, t): def eval_terms(lineq, x, t):
psi = lens(lineq).terms.each_().modify(eval_term(x, t)) terms = lens(lineq).terms.each_().get_all()
return sum(psi.terms) return sum(eval_term(term, x, t) for term in terms)
def eval_term(x, t): def eval_term(term, x, t):
# TODO(lift interpolation much higher) return float(term.coeff)*x[term.id.name][t]
return lambda term: term.coeff*x[term.id.name][t]

View file

@ -1,7 +1,7 @@
import operator as op import operator as op
from stl.utils import set_params, param_lens from stl.utils import set_params, param_lens
from stl import pointwise_sat from stl.boolean_eval import pointwise_sat
from lenses import lens from lenses import lens

View file

@ -6,6 +6,7 @@ from functools import reduce
from lenses import lens, Lens from lenses import lens, Lens
import funcy as fn import funcy as fn
import sympy import sympy
import traces
import stl.ast import stl.ast
from stl.ast import (LinEq, And, Or, NaryOpSTL, F, G, Interval, Neg, from stl.ast import (LinEq, And, Or, NaryOpSTL, F, G, Interval, Neg,
@ -138,6 +139,43 @@ def inline_context(phi, context):
# TODO: this is hack to flatten the AST. Fix! # TODO: this is hack to flatten the AST. Fix!
return stl.parse(str(phi)) return stl.parse(str(phi))
op_lookup = {
">": op.gt,
">=": op.ge,
"<": op.lt,
"<=": op.le,
"=": op.eq,
}
def get_times(x):
times = set.union(*({t for t, _ in v.items()} for v in x.values()))
return sorted(times)
def eval_lineq(lineq, x, times=None, compact=True):
if times is None:
times = get_times(x)
def eval_term(term, t):
return float(term.coeff)*x[term.id.name][t]
output = traces.TimeSeries(domain=traces.Domain(times[0], times[-1]))
terms = lens(lineq).terms.each_().get_all()
for t in times:
lhs = sum(eval_term(term, t) for term in terms)
output[t] = op_lookup[lineq.op](lhs, lineq.const)
if compact:
output.compact()
return output
def eval_lineqs(phi, x, times=None):
if times is None:
times = get_times(x)
lineqs = set(lineq_lens(phi).get_all())
return {lineq: eval_lineq(lineq, x, times=times) for lineq in lineqs}
# EDSL # EDSL
def alw(phi, *, lo, hi): def alw(phi, *, lo, hi):