From e3ff61e971723ab741c3edd1f154cca7e1dc1c00 Mon Sep 17 00:00:00 2001 From: Marcell Vazquez-Chanlatte Date: Sat, 8 Oct 2016 20:20:26 -0700 Subject: [PATCH] added lens to access all parameters in an AST --- parser.py | 6 ++++-- robustness.py | 1 - utils.py | 24 ++++++++++++++++++------ 3 files changed, 22 insertions(+), 9 deletions(-) diff --git a/parser.py b/parser.py index f04500f..8d3bc75 100644 --- a/parser.py +++ b/parser.py @@ -49,7 +49,7 @@ prime = "'" pm = "+" / "-" dt = "dt" -unbound = "?" id +unbound = id "?" id = ~r"[a-zA-z\d]+" const = ~r"[\+\-]?\d*(\.\d+)?" op = ">=" / "<=" / "<" / ">" / "=" @@ -76,7 +76,9 @@ class STLVisitor(NodeVisitor): def get_text(self, node, _): return node.text - visit_unbound = get_text + def visit_unbound(self, node, _): + return Symbol(node.text) + visit_op = get_text def unary_temp_op_visitor(self, _, children, op): diff --git a/robustness.py b/robustness.py index aaf5bbe..127395f 100644 --- a/robustness.py +++ b/robustness.py @@ -88,4 +88,3 @@ def lex_param_project(stl, x, *, order, polarity, ranges, tol=1e-3): r, param = binsearch() #TODO # TODO: update val return val - diff --git a/utils.py b/utils.py index 1b0019e..d48cde5 100644 --- a/utils.py +++ b/utils.py @@ -2,8 +2,9 @@ from collections import deque from lenses import lens import funcy as fn +import sympy -from stl.ast import LinEq, And, Or, NaryOpSTL +from stl.ast import LinEq, And, Or, NaryOpSTL, F, G, Interval def walk(stl, bfs=False): """Walks Ast. Defaults to DFS unless BFS flag is set.""" @@ -34,21 +35,24 @@ def _child_lens(psi, focus): yield focus.arg -def ast_lens(phi:"STL", bind=True, *, pred) -> lens: - tls = list(fn.flatten(_ast_lens(phi, pred=pred))) +def ast_lens(phi:"STL", bind=True, *, pred, focus_lens=None) -> lens: + if focus_lens is None: + focus_lens = lambda x: [lens()] + tls = list(fn.flatten(_ast_lens(phi, pred=pred, focus_lens=focus_lens))) tl = lens().tuple_(*tls).each_() return tl.bind(phi) if bind else tl -def _ast_lens(phi, *, pred, focus=lens()): +def _ast_lens(phi, *, pred, focus=lens(), focus_lens): psi = focus.get(state=phi) - ret_lens = [focus] if pred(psi) else [] + ret_lens = [focus.add_lens(l) for l in focus_lens(psi)] if pred(psi) else [] if isinstance(psi, LinEq): return ret_lens child_lenses = list(_child_lens(psi, focus=focus)) - ret_lens += [_ast_lens(phi, pred=pred, focus=cl) for cl in child_lenses] + ret_lens += [_ast_lens(phi, pred=pred, focus=cl, focus_lens=focus_lens) + for cl in child_lenses] return ret_lens @@ -57,3 +61,11 @@ and_or_lens = fn.partial(ast_lens, pred=type_pred(And, Or)) def terms_lens(phi:"STL", bind=True) -> lens: return lineq_lens(phi, bind).terms.each_() + + +def param_lens(phi): + is_sym = lambda x: isinstance(x, sympy.Symbol) + def focus_lens(leaf): + return [lens().const] if isinstance(leaf, LinEq) else [lens().interval[0], lens().interval[1]] + + return ast_lens(phi, pred=type_pred(LinEq, F, G), focus_lens=focus_lens).filter_(is_sym)