support overwritting default interval

This commit is contained in:
Marcell Vazquez-Chanlatte 2017-02-25 23:48:34 -08:00
parent 4c5924782d
commit 00eec50f79

View file

@ -67,10 +67,12 @@ _ = ~r"\s"+
__ = ~r"\s"*
EOL = "\\n"
''')
default_interval = ast.Interval(0, float('inf'))
class STLVisitor(NodeVisitor):
def __init__(self, H=float('inf')):
super().__init__()
self.default_interval = ast.Interval(0, H)
def generic_visit(self, _, children):
return children
@ -94,7 +96,7 @@ class STLVisitor(NodeVisitor):
def unary_temp_op_visitor(self, _, children, op):
_, i, phi = children
i = default_interval if not i else i[0]
i = self.default_interval if not i else i[0]
return op(i, phi)
def binop_visitor(self, _, children, op):
@ -117,7 +119,7 @@ class STLVisitor(NodeVisitor):
def visit_until(self, _, children):
_, _, phi1, _, _, i, _, phi2, *_ = children
i = default_interval if not i else i[0]
i = self.default_interval if not i else i[0]
return ast.Until(i, phi1, phi2)
def visit_id(self, name, _):
@ -178,5 +180,5 @@ class STLVisitor(NodeVisitor):
return ast.Neg(children[1])
def parse(stl_str:str, rule:str="phi") -> "STL":
return STLVisitor().visit(STL_GRAMMAR[rule].parse(stl_str))
def parse(stl_str:str, rule:str="phi", H=float('inf')) -> "STL":
return STLVisitor(H).visit(STL_GRAMMAR[rule].parse(stl_str))