make ast objects easier to work with

This commit is contained in:
Marcell Vazquez-Chanlatte 2017-10-31 18:29:31 -07:00
parent 2640728288
commit 7d8cf78222
5 changed files with 131 additions and 128 deletions

View file

@ -1,13 +1,12 @@
# -*- coding: utf-8 -*-
# TODO: supress + given a + (-b). i.e. want a - b
from collections import namedtuple
from collections import deque, namedtuple
from functools import lru_cache
import funcy as fn
from sympy import Symbol
dt_sym = Symbol('dt', positive=True)
t_sym = Symbol('t', positive=True)
import lenses
from lenses import lens
def flatten_binary(phi, op, dropT, shortT):
@ -42,6 +41,45 @@ class AST(object):
def children(self):
return set()
def walk(self):
"""Walk of the AST."""
pop = deque.pop
children = deque([self])
while len(children) > 0:
node = pop(children)
yield node
children.extend(node.children)
@property
def params(self):
def get_params(leaf):
if isinstance(leaf, ModalOp):
if isinstance(leaf.interval[0], Param):
yield leaf.interval[0]
if isinstance(leaf.interval[1], Param):
yield leaf.interval[1]
elif isinstance(leaf, LinEq):
if isinstance(leaf.const, Param):
yield leaf.const
return set(fn.mapcat(get_params, self.walk()))
def set_params(self, val):
phi = param_lens(self)
return phi.modify(lambda x: float(val.get(x, val.get(str(x), x))))
@property
def terms(self):
return set(terms_lens(self).Each().collect())
@property
def lineqs(self):
return set(lineq_lens(self).Each().collect())
@property
def atomic_predicates(self):
return set(AP_lens(self).Each().collect())
class _Top(AST):
__slots__ = ()
@ -234,3 +272,72 @@ class Param(namedtuple('Param', ['name']), AST):
def __hash__(self):
# TODO: compute hash based on contents
return hash(repr(self))
def ast_lens(phi, bind=True, *, pred=None, focus_lens=None, getter=False):
if focus_lens is None:
def focus_lens(_):
return [lens]
if pred is None:
def pred(_):
return False
child_lenses = _ast_lens(phi, pred=pred, focus_lens=focus_lens)
phi = lenses.bind(phi) if bind else lens
return (phi.Tuple if getter else phi.Fork)(*child_lenses)
def _ast_lens(phi, pred, focus_lens):
if pred(phi):
yield from focus_lens(phi)
if phi is None or not phi.children:
return
if phi is TOP or phi is BOT:
child_lenses = [lens]
elif isinstance(phi, Until):
child_lenses = [lens.GetAttr('arg1'), lens.GetAttr('arg2')]
elif isinstance(phi, NaryOpSTL):
child_lenses = [
lens.GetAttr('args')[j] for j, _ in enumerate(phi.args)
]
else:
child_lenses = [lens.GetAttr('arg')]
for l in child_lenses:
yield from [l & cl for cl in _ast_lens(l.get()(phi), pred, focus_lens)]
@lru_cache()
def param_lens(phi, *, getter=False):
def focus_lens(leaf):
candidates = [lens.const] if isinstance(leaf, LinEq) else [
lens.GetAttr('interval')[0],
lens.GetAttr('interval')[1]
]
return (x for x in candidates if isinstance(x.get()(leaf), Param))
return ast_lens(
phi, pred=type_pred(LinEq, F, G), focus_lens=focus_lens, getter=getter)
def vars_in_phi(phi):
focus = terms_lens(phi)
return set(focus.tuple_(lens.id, lens.time).get_all())
def type_pred(*args):
ast_types = set(args)
return lambda x: type(x) in ast_types
lineq_lens = fn.partial(ast_lens, pred=type_pred(LinEq), getter=True)
AP_lens = fn.partial(ast_lens, pred=type_pred(AtomicPred), getter=True)
and_or_lens = fn.partial(ast_lens, pred=type_pred(And, Or), getter=True)
def terms_lens(phi, bind=True):
return lineq_lens(phi, bind).Each().terms.Each()