diff --git a/stl/ast.py b/stl/ast.py index 37d6a41..6ed256a 100644 --- a/stl/ast.py +++ b/stl/ast.py @@ -15,10 +15,13 @@ t_sym = Symbol('t', positive=True) def flatten_binary(phi, op, dropT, shortT): f = lambda x: x.args if isinstance(x, op) else [x] args = [arg for arg in phi.args if arg is not dropT] + if any(arg is shortT for arg in args): return shortT elif not args: return dropT + elif len(args) == 1: + return args[0] else: return op(tuple(fn.mapcat(f, phi.args))) diff --git a/stl/test_ast.py b/stl/test_ast.py index 86d7cd3..9beb3c7 100644 --- a/stl/test_ast.py +++ b/stl/test_ast.py @@ -7,6 +7,10 @@ class TestSTLAST(unittest.TestCase): phi = stl.parse("x") self.assertEqual(stl.TOP, stl.TOP | phi) self.assertEqual(stl.BOT, stl.BOT & phi) + self.assertEqual(stl.TOP, phi | stl.TOP) + self.assertEqual(stl.BOT, phi & stl.BOT) + self.assertEqual(phi, phi & stl.TOP) + self.assertEqual(phi, phi | stl.BOT) self.assertEqual(stl.TOP, stl.TOP & stl.TOP) self.assertEqual(stl.BOT, stl.BOT | stl.BOT) self.assertEqual(stl.TOP, stl.TOP | stl.BOT)