From eda63fd6f032b35cfba3e9af497021bb3afc38ca Mon Sep 17 00:00:00 2001 From: Marcell Vazquez-Chanlatte Date: Thu, 26 Oct 2017 16:14:45 -0700 Subject: [PATCH] Fixed PSTL construction --- stl/ast.py | 20 +++++++++++++++++++- stl/parser.py | 32 ++++++++++++-------------------- stl/utils.py | 16 +++++++--------- 3 files changed, 38 insertions(+), 30 deletions(-) diff --git a/stl/ast.py b/stl/ast.py index a91f175..70d25d0 100644 --- a/stl/ast.py +++ b/stl/ast.py @@ -97,7 +97,13 @@ class Var(namedtuple("Var", ["coeff", "id"])): __slots__ = () def __repr__(self): - return f"{self.coeff}*{self.id}" + if self.coeff == -1: + coeff_str = "-" + elif self.coeff == +1: + coeff_str = "" + else: + coeff_str = f"{self.coeff}*" + return f"{coeff_str}{self.id}" class Interval(namedtuple('I', ['lower', 'upper'])): @@ -217,3 +223,15 @@ class Next(namedtuple('Next', ['arg']), AST): def __hash__(self): # TODO: compute hash based on contents return hash(repr(self)) + + +class Param(namedtuple('Param', ['name']), AST): + __slots__ = () + + def __repr__(self): + return self.name + + + def __hash__(self): + # TODO: compute hash based on contents + return hash(repr(self)) diff --git a/stl/parser.py b/stl/parser.py index 3689973..7a5dada 100644 --- a/stl/parser.py +++ b/stl/parser.py @@ -12,8 +12,6 @@ from parsimonious import Grammar, NodeVisitor from funcy import flatten from lenses import lens -from sympy import Symbol, Number - from stl import ast from stl.utils import implies, xor, iff, env, alw @@ -45,8 +43,7 @@ interval = "[" __ const_or_unbound __ "," __ const_or_unbound __ "]" const_or_unbound = unbound / "inf" / const lineq = terms _ op _ const_or_unbound -term = coeff? var -coeff = ((dt __ "*" __)? const __ "*" __) / (dt __ "*") +term = const? var terms = (term __ pm __ terms) / term var = id @@ -56,7 +53,7 @@ pm = "+" / "-" dt = "dt" unbound = id "?" id = ~r"[a-zA-z\d]+" -const = ~r"[\+\-]?\d*(\.\d+)?" +const = ~r"[-+]?\d*\.\d+|\d+" op = ">=" / "<=" / "<" / ">" / "=" _ = ~r"\s"+ __ = ~r"\s"* @@ -64,8 +61,10 @@ EOL = "\\n" ''') +oo = float('inf') + class STLVisitor(NodeVisitor): - def __init__(self, H=float('inf')): + def __init__(self, H=oo): super().__init__() self.default_interval = ast.Interval(0.0, H) @@ -92,7 +91,7 @@ class STLVisitor(NodeVisitor): return node.text def visit_unbound(self, node, _): - return Symbol(node.text) + return ast.Param(node.text) visit_op = get_text @@ -128,23 +127,16 @@ class STLVisitor(NodeVisitor): return env(psi, lo=lo, hi=hi) & alw(ast.Until(phi, psi), lo=0, hi=lo) def visit_id(self, name, _): - return Symbol(name.text) + return name.text def visit_const(self, const, children): return float(const.text) def visit_term(self, _, children): - coeffs, (iden, time) = children - c = coeffs[0] if coeffs else Number(1) - return ast.Var(coeff=c, id=iden, time=time) + coeffs, iden = children + c = coeffs[0] if coeffs else 1 + return ast.Var(coeff=c, id=iden) - def visit_coeff(self, _, children): - dt, coeff, *_ = children[0] - if not isinstance(dt, Symbol): - dt = dt[0][0] if dt else Number(1) - return dt * coeff - else: - return dt def visit_terms(self, _, children): if isinstance(children[0], list): @@ -159,7 +151,7 @@ class STLVisitor(NodeVisitor): return ast.LinEq(tuple(terms), op, const[0]) def visit_pm(self, node, _): - return Number(1) if node.text == "+" else Number(-1) + return 1 if node.text == "+" else -1 def visit_AP(self, *args): return ast.AtomicPred(self.visit_id(*args)) @@ -171,5 +163,5 @@ class STLVisitor(NodeVisitor): return ast.Next(children[1]) -def parse(stl_str: str, rule: str = "phi", H=float('inf')) -> "STL": +def parse(stl_str: str, rule: str = "phi", H=oo) -> "STL": return STLVisitor(H).visit(STL_GRAMMAR[rule].parse(stl_str)) diff --git a/stl/utils.py b/stl/utils.py index 41d1ce7..52979a6 100644 --- a/stl/utils.py +++ b/stl/utils.py @@ -11,7 +11,7 @@ import traces import stl.ast from stl.ast import (LinEq, And, Or, NaryOpSTL, F, G, Interval, Neg, - AtomicPred) + AtomicPred, Param, AST) from stl.types import STL, STL_Generator, MTL Lens = TypeVar('Lens') @@ -43,7 +43,7 @@ def ast_lens(phi: STL, bind=True, *, pred=None, focus_lens=None) -> Lens: if pred is None: pred = lambda _: False l = lenses.bind(phi) if bind else lens - return l.Tuple(*_ast_lens(phi, pred=pred, focus_lens=focus_lens)) + return l.Fork(*_ast_lens(phi, pred=pred, focus_lens=focus_lens)) def _ast_lens(phi: STL, pred, focus_lens) -> Lens: @@ -77,22 +77,20 @@ def terms_lens(phi: STL, bind: bool = True) -> Lens: def param_lens(phi: STL) -> Lens: - is_sym = lambda x: isinstance(x, sympy.Symbol) - def focus_lens(leaf): - return [lens.const] if isinstance(leaf, LinEq) else [ + candidates = [lens.const] if isinstance(leaf, LinEq) else [ lens.GetAttr('interval')[0], lens.GetAttr('interval')[1] ] + return (x for x in candidates if isinstance(x.get()(leaf), Param)) return ast_lens( phi, pred=type_pred(LinEq, F, G), - focus_lens=focus_lens).filter_(is_sym) + focus_lens=focus_lens) -def set_params(stl_or_lens, val) -> STL: - l = stl_or_lens if isinstance(stl_or_lens, - Lens) else param_lens(stl_or_lens) +def set_params(phi, val) -> STL: + l = param_lens(phi) if isinstance(phi, AST) else phi return l.modify(lambda x: float(val.get(x, val.get(str(x), x))))