diff --git a/stl/test_utils.py b/stl/test_utils.py index a2c78fc..ec96ec9 100644 --- a/stl/test_utils.py +++ b/stl/test_utils.py @@ -29,13 +29,16 @@ class TestSTLUtils(unittest.TestCase): self.assertEqual(set(map(str, stl.utils.param_lens(phi).get_all())), set()) self.assertEqual(phi, phi2) - def test_walk(self): - raise NotImplementedError - - - def test_type_pred(self): - raise NotImplementedError + @params(("x > 5", 1), ("~(x)", 2), ("(F[0,1](x)) & (~(G[0, 2](y)))", 6)) + def test_walk(self, phi_str, l): + self.assertEqual(l, len(list(stl.walk(stl.parse(phi_str))))) + @params(([], False, False),([int], True, False), ([int, bool], True, True)) + def test_type_pred(self, types, b1, b2): + pred = stl.utils.type_pred(*types) + self.assertFalse(pred(None)) + self.assertEqual(pred(1), b1) + self.assertEqual(pred(True), b2) def test_ast_lens(self): raise NotImplementedError diff --git a/stl/utils.py b/stl/utils.py index 0209c0a..6dbe234 100644 --- a/stl/utils.py +++ b/stl/utils.py @@ -12,11 +12,11 @@ from stl.ast import (LinEq, And, Or, NaryOpSTL, F, G, Interval, Neg, AtomicPred) from stl.types import STL, STL_Generator, MTL -def walk(phi:STL, bfs:bool=False) -> STL_Generator: - """Walks Ast. Defaults to DFS unless BFS flag is set.""" - pop = deque.popleft if bfs else deque.pop +def walk(phi:STL) -> STL_Generator: + """DSF walk of the AST.""" + pop = deque.pop children = deque([phi]) - while len(children) != 0: + while len(children) > 0: node = pop(children) yield node children.extend(node.children())