upgrade to new lens version

This commit is contained in:
Marcell Vazquez-Chanlatte 2017-10-15 00:59:58 -07:00
parent 212a8c195a
commit d78037816b
5 changed files with 86 additions and 61 deletions

View file

@ -1,9 +1,10 @@
from typing import List, Type, Dict, Mapping, T
from typing import List, Type, Dict, Mapping, T, TypeVar
from collections import deque
import operator as op
from functools import reduce
from lenses import lens, Lens
import lenses
from lenses import lens
import funcy as fn
import sympy
import traces
@ -13,59 +14,52 @@ from stl.ast import (LinEq, And, Or, NaryOpSTL, F, G, Interval, Neg,
AtomicPred)
from stl.types import STL, STL_Generator, MTL
Lens = TypeVar('Lens')
def walk(phi:STL) -> STL_Generator:
"""DSF walk of the AST."""
"""Walk of the AST."""
pop = deque.pop
children = deque([phi])
while len(children) > 0:
node = pop(children)
yield node
children.extend(node.children())
children.extend(node.children)
def vars_in_phi(phi):
focus = stl.terms_lens(phi)
return set(focus.tuple_(lens().id, lens().time).get_all())
return set(focus.tuple_(lens.id, lens.time).get_all())
def type_pred(*args:List[Type]) -> Mapping[Type, bool]:
ast_types = set(args)
return lambda x: type(x) in ast_types
def _child_lens(psi:STL, focus:Lens) -> STL_Generator:
if psi is None:
return
elif psi is stl.TOP or psi is stl.BOT:
return
elif isinstance(psi, stl.ast.Until):
yield from [focus.arg1, focus.arg2]
elif isinstance(psi, NaryOpSTL):
for j, _ in enumerate(psi.args):
yield focus.args[j]
else:
yield focus.arg
def ast_lens(phi:STL, bind:bool=True, *,
pred:Mapping[T, bool], focus_lens:Lens=None) -> Lens:
def ast_lens(phi:STL, bind=True, *, pred=None, focus_lens=None) -> Lens:
if focus_lens is None:
focus_lens = lambda x: [lens()]
tls = list(fn.flatten(_ast_lens(phi, pred=pred, focus_lens=focus_lens)))
tl = lens().tuple_(*tls).each_()
return tl.bind(phi) if bind else tl
focus_lens = lambda _: [lens]
if pred is None:
pred = lambda _: False
l = lenses.bind(phi) if bind else lens
return l.Tuple(*_ast_lens(phi, pred=pred, focus_lens=focus_lens))
def _ast_lens(phi:STL, pred, focus_lens) -> Lens:
if pred(phi):
yield from focus_lens(phi)
if phi is None or not phi.children:
return
def _ast_lens(phi, *, pred, focus=lens(), focus_lens):
psi = focus.get(state=phi)
ret_lens = [focus.add_lens(l) for l in focus_lens(psi)] if pred(psi) else []
if isinstance(psi, (LinEq, stl.ast.AtomicPred)):
return ret_lens
child_lenses = list(_child_lens(psi, focus=focus))
ret_lens += [_ast_lens(phi, pred=pred, focus=cl, focus_lens=focus_lens)
for cl in child_lenses]
return ret_lens
if phi is stl.TOP or phi is stl.BOT:
child_lenses = [lens]
elif isinstance(phi, stl.ast.Until):
child_lenses = [lens.GetAttr('arg1'), lens.GetAttr('arg2')]
elif isinstance(phi, NaryOpSTL):
child_lenses = [lens.GetAttr('args')[j] for j, _ in enumerate(phi.args)]
else:
child_lenses = [lens.GetAttr('arg')]
for l in child_lenses:
yield from [l & cl for cl in _ast_lens(l.get()(phi), pred, focus_lens)]
lineq_lens = fn.partial(ast_lens, pred=type_pred(LinEq))
AP_lens = fn.partial(ast_lens, pred=type_pred(stl.ast.AtomicPred))
@ -78,7 +72,7 @@ def terms_lens(phi:STL, bind:bool=True) -> Lens:
def param_lens(phi:STL) -> Lens:
is_sym = lambda x: isinstance(x, sympy.Symbol)
def focus_lens(leaf):
return [lens().const] if isinstance(leaf, LinEq) else [lens().interval[0], lens().interval[1]]
return [lens.const] if isinstance(leaf, LinEq) else [lens().interval[0], lens().interval[1]]
return ast_lens(phi, pred=type_pred(LinEq, F, G),
focus_lens=focus_lens).filter_(is_sym)
@ -93,7 +87,7 @@ def f_neg_or_canonical_form(phi:STL) -> STL:
if isinstance(phi, LinEq):
return phi
children = [f_neg_or_canonical_form(s) for s in phi.children()]
children = [f_neg_or_canonical_form(s) for s in phi.children]
if isinstance(phi, (And, G)):
children = [Neg(s) for s in children]
children = tuple(children)
@ -160,7 +154,7 @@ def eval_lineq(lineq, x, times=None, compact=True):
return float(term.coeff)*x[term.id.name][t]
output = traces.TimeSeries(domain=traces.Domain(times[0], times[-1]))
terms = lens(lineq).terms.each_().get_all()
terms = lens(lineq).Each().terms.Each().collect()
for t in times:
lhs = sum(eval_term(term, t) for term in terms)
output[t] = op_lookup[lineq.op](lhs, lineq.const)
@ -172,7 +166,7 @@ def eval_lineq(lineq, x, times=None, compact=True):
def eval_lineqs(phi, x, times=None):
if times is None:
times = get_times(x)
lineqs = set(lineq_lens(phi).get_all())
lineqs = set(lineq_lens(phi).Each().collect())
return {lineq: eval_lineq(lineq, x, times=times) for lineq in lineqs}