From cc50cdc49e896fd00f0dcae74d3377764a919778 Mon Sep 17 00:00:00 2001 From: Marcell Vazquez-Chanlatte Date: Mon, 11 Jul 2016 19:23:24 -0700 Subject: [PATCH] bug fixes + generalize ast_lens to any node predicate --- ast.py | 30 +++++++++++++++--------------- parser.py | 22 +++++++++++----------- 2 files changed, 26 insertions(+), 26 deletions(-) diff --git a/ast.py b/ast.py index ebb4068..2e6f660 100644 --- a/ast.py +++ b/ast.py @@ -104,16 +104,9 @@ def tree(stl): return {x:set(x.children()) for x in walk(stl) if x.children()} -def lineq_lens(phi:"STL", bind=True) -> lens: - return ast_lens(phi, bind=bind, types={LinEq}) - - -def terms_lens(phi:"STL", bind=True) -> lens: - return lineq_lens(phi, bind).terms.each_() - - -def and_or_lens(phi:"STL", bind=True) -> lens: - return ast_lens(phi, bind=bind, types={And, Or}) +def type_pred(*args): + ast_types = set(args) + return lambda x: type(x) in ast_types def _child_lens(psi, focus): @@ -126,19 +119,26 @@ def _child_lens(psi, focus): yield focus.arg -def ast_lens(phi:"STL", bind=True, *, types) -> lens: - tls = list(fn.flatten(_ast_lens(phi, types=types))) +def ast_lens(phi:"STL", bind=True, *, pred) -> lens: + tls = list(fn.flatten(_ast_lens(phi, pred=pred))) tl = lens().tuple_(*tls).each_() return tl.bind(phi) if bind else tl -def _ast_lens(phi, *, types, focus=lens()): +def _ast_lens(phi, *, pred, focus=lens()): psi = focus.get(state=phi) - ret_lens = [focus] if type(psi) in types else [] + ret_lens = [focus] if pred(psi) else [] if isinstance(psi, LinEq): return ret_lens child_lenses = list(_child_lens(psi, focus=focus)) - ret_lens += [_ast_lens(phi, types=types, focus=cl) for cl in child_lenses] + ret_lens += [_ast_lens(phi, pred=pred, focus=cl) for cl in child_lenses] return ret_lens + + +lineq_lens = fn.partial(ast_lens, pred=type_pred(LinEq)) +and_or_lens = fn.partial(ast_lens, pred=type_pred(And, Or)) + +def terms_lens(phi:"STL", bind=True) -> lens: + return lineq_lens(phi, bind).terms.each_() diff --git a/parser.py b/parser.py index 79ec1da..9902c18 100644 --- a/parser.py +++ b/parser.py @@ -19,7 +19,7 @@ from lenses import lens from sympy import Symbol, Number -from stl import stl +from stl import ast STL_GRAMMAR = Grammar(u''' phi = (g / f / lineq / or / and / paren_phi) @@ -72,7 +72,7 @@ class STLVisitor(NodeVisitor): def visit_interval(self, _, children): _, _, left, _, _, _, right, _, _ = children - return stl.Interval(left[0], right[0]) + return ast.Interval(left[0], right[0]) def get_text(self, node, _): return node.text @@ -90,10 +90,10 @@ class STLVisitor(NodeVisitor): argR = list(phi2.args) if isinstance(phi2, op) else [phi2] return op(tuple(argL + argR)) - visit_f = partialmethod(unary_temp_op_visitor, op=stl.F) - visit_g = partialmethod(unary_temp_op_visitor, op=stl.G) - visit_or = partialmethod(binop_visitor, op=stl.Or) - visit_and = partialmethod(binop_visitor, op=stl.And) + visit_f = partialmethod(unary_temp_op_visitor, op=ast.F) + visit_g = partialmethod(unary_temp_op_visitor, op=ast.G) + visit_or = partialmethod(binop_visitor, op=ast.Or) + visit_and = partialmethod(binop_visitor, op=ast.And) def visit_id(self, name, _): return Symbol(name.text) @@ -102,7 +102,7 @@ class STLVisitor(NodeVisitor): iden, time_node = children time_node = list(flatten(time_node)) - time = time_node[0] if len(time_node) > 0 else stl.t_sym + time = time_node[0] if len(time_node) > 0 else ast.t_sym return iden, time @@ -110,18 +110,18 @@ class STLVisitor(NodeVisitor): return children[3]* children[5] def visit_prime(self, *_): - return stl.t_sym - stl.dt_sym + return ast.t_sym - ast.dt_sym def visit_const(self, const, children): return float(const.text) def visit_dt(self, *_): - return stl.dt_sym + return ast.dt_sym def visit_term(self, _, children): coeffs, (iden, time) = children c = coeffs[0] if coeffs else Number(1) - return stl.Var(coeff=c, id=iden, time=time) + return ast.Var(coeff=c, id=iden, time=time) def visit_coeff(self, _, children): @@ -139,7 +139,7 @@ class STLVisitor(NodeVisitor): def visit_lineq(self, _, children): terms, _1, op, _2, const = children - return stl.LinEq(tuple(terms), op, const[0]) + return ast.LinEq(tuple(terms), op, const[0]) def visit_pm(self, node, _): return Number(1) if node.text == "+" else Number(-1)