diff --git a/stl/parser.py b/stl/parser.py index 20e8d66..4c1d2fa 100644 --- a/stl/parser.py +++ b/stl/parser.py @@ -67,10 +67,12 @@ _ = ~r"\s"+ __ = ~r"\s"* EOL = "\\n" ''') - -default_interval = ast.Interval(0, float('inf')) class STLVisitor(NodeVisitor): + def __init__(self, H=float('inf')): + super().__init__() + self.default_interval = ast.Interval(0, H) + def generic_visit(self, _, children): return children @@ -94,7 +96,7 @@ class STLVisitor(NodeVisitor): def unary_temp_op_visitor(self, _, children, op): _, i, phi = children - i = default_interval if not i else i[0] + i = self.default_interval if not i else i[0] return op(i, phi) def binop_visitor(self, _, children, op): @@ -117,7 +119,7 @@ class STLVisitor(NodeVisitor): def visit_until(self, _, children): _, _, phi1, _, _, i, _, phi2, *_ = children - i = default_interval if not i else i[0] + i = self.default_interval if not i else i[0] return ast.Until(i, phi1, phi2) def visit_id(self, name, _): @@ -178,5 +180,5 @@ class STLVisitor(NodeVisitor): return ast.Neg(children[1]) -def parse(stl_str:str, rule:str="phi") -> "STL": - return STLVisitor().visit(STL_GRAMMAR[rule].parse(stl_str)) +def parse(stl_str:str, rule:str="phi", H=float('inf')) -> "STL": + return STLVisitor(H).visit(STL_GRAMMAR[rule].parse(stl_str))