diff --git a/stl/ast.py b/stl/ast.py index 08ce551..419ff26 100644 --- a/stl/ast.py +++ b/stl/ast.py @@ -29,7 +29,6 @@ class AST(object): return Neg(self) - class AtomicPred(namedtuple("AP", ["id", "time"]), AST): def __repr__(self): return f"{self.id}[{self.time}]" @@ -45,6 +44,10 @@ class LinEq(namedtuple("LinEquality", ["terms", "op", "const"]), AST): def children(self): return [] + def __hash__(self): + # TODO: compute hash based on contents + return hash(repr(self)) + class Var(namedtuple("Var", ["coeff", "id", "time"])): def __repr__(self): @@ -71,9 +74,17 @@ class NaryOpSTL(namedtuple('NaryOp', ['args']), AST): class Or(NaryOpSTL): OP = "∨" + def __hash__(self): + # TODO: compute hash based on contents + return hash(repr(self)) + class And(NaryOpSTL): OP = "∧" + def __hash__(self): + # TODO: compute hash based on contents + return hash(repr(self)) + class ModalOp(namedtuple('ModalOp', ['interval', 'arg']), AST): def __repr__(self): @@ -86,9 +97,17 @@ class ModalOp(namedtuple('ModalOp', ['interval', 'arg']), AST): class F(ModalOp): OP = "◇" + def __hash__(self): + # TODO: compute hash based on contents + return hash(repr(self)) + class G(ModalOp): OP = "□" + def __hash__(self): + # TODO: compute hash based on contents + return hash(repr(self)) + class Until(namedtuple('ModalOp', ['interval', 'arg1', 'arg2']), AST): def __repr__(self): @@ -97,6 +116,10 @@ class Until(namedtuple('ModalOp', ['interval', 'arg1', 'arg2']), AST): def children(self): return [self.arg1, self.arg2] + def __hash__(self): + # TODO: compute hash based on contents + return hash(repr(self)) + class Neg(namedtuple('Neg', ['arg']), AST): def __repr__(self): @@ -104,3 +127,7 @@ class Neg(namedtuple('Neg', ['arg']), AST): def children(self): return [self.arg] + + def __hash__(self): + # TODO: compute hash based on contents + return hash(repr(self)) diff --git a/stl/parser.py b/stl/parser.py index 7577c24..157af46 100644 --- a/stl/parser.py +++ b/stl/parser.py @@ -22,7 +22,7 @@ from stl import ast from stl.utils import implies, xor, iff STL_GRAMMAR = Grammar(u''' -phi = (g / f / until / lineq / AP / or / and / implies / xor / iff / neg / paren_phi) +phi = (neg / g / f / until / lineq / AP / or / and / implies / xor / iff / paren_phi) paren_phi = "(" __ phi __ ")" @@ -32,11 +32,11 @@ implies = paren_phi _ ("→" / "->") _ (and / paren_phi) iff = paren_phi _ ("⇔" / "<->" / "iff") _ (and / paren_phi) xor = paren_phi _ ("⊕" / "^" / "xor") _ (and / paren_phi) -neg = ("~" / "¬") paren_phi +neg = ("~" / "¬") phi -f = F interval phi -g = G interval phi -until = "(" __ phi _ U interval _ phi __ ")" +f = F interval? phi +g = G interval? phi +until = "(" __ phi _ U interval? _ phi __ ")" F = "F" / "◇" G = "G" / "□" @@ -68,6 +68,8 @@ _ = ~r"\s"+ __ = ~r"\s"* EOL = "\\n" ''') + +default_interval = ast.Interval(0, float('inf')) class STLVisitor(NodeVisitor): def generic_visit(self, _, children): @@ -92,8 +94,9 @@ class STLVisitor(NodeVisitor): visit_op = get_text def unary_temp_op_visitor(self, _, children, op): - _, interval, phi = children - return op(interval, phi) + _, i, phi = children + i = default_interval if not i else i[0] + return op(i, phi) def binop_visitor(self, _, children, op): phi1, _, _, _, (phi2,) = children @@ -115,6 +118,7 @@ class STLVisitor(NodeVisitor): def visit_until(self, _, children): _, _, phi1, _, _, i, _, phi2, *_ = children + i = default_interval if not i else i[0] return ast.Until(i, phi1, phi2) def visit_id(self, name, _): @@ -145,7 +149,6 @@ class STLVisitor(NodeVisitor): c = coeffs[0] if coeffs else Number(1) return ast.Var(coeff=c, id=iden, time=time) - def visit_coeff(self, _, children): dt, coeff, *_ = children dt = dt[0][0] if dt else Number(1)