use lenses.Recur instead of ast_lens

This commit is contained in:
Marcell Vazquez-Chanlatte 2017-11-30 22:46:12 -08:00
parent 799a9df006
commit 8f4e511326
2 changed files with 9 additions and 49 deletions

View file

@ -2,7 +2,7 @@
-e git://github.com/mvcisback/multidim-threshold@master#egg=multidim-threshold -e git://github.com/mvcisback/multidim-threshold@master#egg=multidim-threshold
bitarray==0.8.1 bitarray==0.8.1
funcy==1.9.1 funcy==1.9.1
lenses==0.3.0 lenses==0.4.0
pandas==0.19.2 pandas==0.19.2
parsimonious==0.7.0 parsimonious==0.7.0
sympy==1.0 sympy==1.0

View file

@ -4,7 +4,6 @@ from collections import deque, namedtuple
from functools import lru_cache from functools import lru_cache
import funcy as fn import funcy as fn
import lenses
from lenses import lens, bind from lenses import lens, bind
@ -71,11 +70,11 @@ class AST(object):
@property @property
def lineqs(self): def lineqs(self):
return set(lineq_lens(self).Each().collect()) return set(lineq_lens.collect()(self))
@property @property
def atomic_predicates(self): def atomic_predicates(self):
return set(AP_lens(self).Each().collect()) return set(AP_lens.collect()(self))
@property @property
def var_names(self): def var_names(self):
@ -86,11 +85,11 @@ class AST(object):
def inline_context(self, context): def inline_context(self, context):
phi, phi2 = self, None phi, phi2 = self, None
def update(aps): def update(ap):
return tuple(context.get(ap, ap) for ap in aps) return context.get(ap, ap)
while phi2 != phi: while phi2 != phi:
phi2, phi = phi, AP_lens(phi).modify(update) phi2, phi = phi, AP_lens.modify(update)(phi)
return phi return phi
@ -292,47 +291,9 @@ class Param(namedtuple('Param', ['name']), AST):
return hash(repr(self)) return hash(repr(self))
def ast_lens(phi,
bind=True,
*,
pred=lambda _: False,
focus_lens=lambda _: [lens],
getter=False):
child_lenses = _ast_lens(phi, pred=pred, focus_lens=focus_lens)
phi = lenses.bind(phi) if bind else lens
return (phi.Tuple if getter else phi.Fork)(*child_lenses)
def _ast_lens(phi, pred, focus_lens):
if pred(phi):
yield from focus_lens(phi)
if phi is None or not phi.children:
return
if isinstance(phi, Until):
child_lenses = [lens.GetAttr('arg1'), lens.GetAttr('arg2')]
elif isinstance(phi, NaryOpSTL):
child_lenses = [
lens.GetAttr('args')[j] for j, _ in enumerate(phi.args)
]
else:
child_lenses = [lens.GetAttr('arg')]
for l in child_lenses:
yield from [l & cl for cl in _ast_lens(l.get()(phi), pred, focus_lens)]
@lru_cache() @lru_cache()
def param_lens(phi, *, getter=False): def param_lens(phi, *, getter=False):
def focus_lens(leaf): return bind(phi).Recur(Param)
candidates = [lens.const] if isinstance(leaf, LinEq) else [
lens.GetAttr('interval')[0],
lens.GetAttr('interval')[1]
]
return (x for x in candidates if isinstance(x.get()(leaf), Param))
return ast_lens(
phi, pred=type_pred(LinEq, F, G), focus_lens=focus_lens, getter=getter)
def type_pred(*args): def type_pred(*args):
@ -340,6 +301,5 @@ def type_pred(*args):
return lambda x: type(x) in ast_types return lambda x: type(x) in ast_types
lineq_lens = fn.partial(ast_lens, pred=type_pred(LinEq), getter=True) lineq_lens = lens.Recur(LinEq)
AP_lens = fn.partial(ast_lens, pred=type_pred(AtomicPred), getter=True) AP_lens = lens.Recur(AtomicPred)
and_or_lens = fn.partial(ast_lens, pred=type_pred(And, Or), getter=True)