generalize lineq lens + remove Nones from game_to_sl

This commit is contained in:
Marcell Vazquez-Chanlatte 2016-07-10 17:19:32 -07:00
parent c4cfda6e19
commit 1f41faef59
2 changed files with 21 additions and 7 deletions

View file

@ -1,4 +1,4 @@
from stl.stl import terms_lens, lineq_lens, walk, tree from stl.stl import terms_lens, lineq_lens, walk, tree, and_or_lens
from stl.stl import dt_sym, t_sym from stl.stl import dt_sym, t_sym
from stl.stl import LinEq, Interval, NaryOpSTL, Or, And, F, G, ModalOp, Neg from stl.stl import LinEq, Interval, NaryOpSTL, Or, And, F, G, ModalOp, Neg
from stl.parser import parse from stl.parser import parse

26
stl.py
View file

@ -105,16 +105,20 @@ def tree(stl):
def lineq_lens(phi:"STL", bind=True) -> lens: def lineq_lens(phi:"STL", bind=True) -> lens:
tls = list(fn.flatten(_lineq_lens(phi))) return ast_lens(phi, bind=bind, types={LinEq})
tl = lens().tuple_(*tls).each_()
return tl.bind(phi) if bind else tl
def terms_lens(phi:"STL", bind=True) -> lens: def terms_lens(phi:"STL", bind=True) -> lens:
return lineq_lens(phi, bind).terms.each_() return lineq_lens(phi, bind).terms.each_()
def and_or_lens(phi:"STL", bind=True) -> lens:
return ast_lens(phi, bind=bind, types={And, Or})
def _child_lens(psi, focus): def _child_lens(psi, focus):
if psi is None:
return
if isinstance(psi, NaryOpSTL): if isinstance(psi, NaryOpSTL):
for j, _ in enumerate(psi.args): for j, _ in enumerate(psi.args):
yield focus.args[j] yield focus.args[j]
@ -122,9 +126,19 @@ def _child_lens(psi, focus):
yield focus.arg yield focus.arg
def _lineq_lens(phi, focus=lens()): def ast_lens(phi:"STL", bind=True, *, types) -> lens:
tls = list(fn.flatten(_ast_lens(phi, types=types)))
tl = lens().tuple_(*tls).each_()
return tl.bind(phi) if bind else tl
def _ast_lens(phi, *, types, focus=lens()):
psi = focus.get(state=phi) psi = focus.get(state=phi)
ret_lens = [focus] if type(psi) in types else []
if isinstance(psi, LinEq): if isinstance(psi, LinEq):
return [focus] return ret_lens
child_lenses = list(_child_lens(psi, focus=focus)) child_lenses = list(_child_lens(psi, focus=focus))
return [_lineq_lens(phi, focus=cl) for cl in child_lenses] ret_lens += [_ast_lens(phi, types=types, focus=cl) for cl in child_lenses]
return ret_lens