diff --git a/stl/__init__.py b/stl/__init__.py index e0b5059..a3a8c45 100644 --- a/stl/__init__.py +++ b/stl/__init__.py @@ -1,6 +1,6 @@ from stl.utils import terms_lens, lineq_lens, walk, and_or_lens from stl.utils import alw, env, andf, orf -from stl.ast import dt_sym, t_sym +from stl.ast import dt_sym, t_sym, TOP, BOT from stl.ast import LinEq, Interval, NaryOpSTL, Or, And, F, G, ModalOp, Neg, Var, AtomicPred from stl.parser import parse from stl.synth import lex_param_project diff --git a/stl/ast.py b/stl/ast.py index 419ff26..2a521c4 100644 --- a/stl/ast.py +++ b/stl/ast.py @@ -12,23 +12,48 @@ from sympy import Symbol dt_sym = Symbol('dt', positive=True) t_sym = Symbol('t', positive=True) -def flatten_binary(phi): - t = type(phi) - f = lambda x: x.args if isinstance(x, t) else [x] - return t(tuple(fn.mapcat(f, phi.args))) +def flatten_binary(phi, op, dropT, shortT): + f = lambda x: x.args if isinstance(x, op) else [x] + args = [arg for arg in phi.args if not isinstance(arg, type(dropT))] + + if any(isinstance(arg, type(shortT)) for arg in args): + return shortT + elif not args: + return dropT + else: + return op(tuple(fn.mapcat(f, phi.args))) class AST(object): def __or__(self, other): - return flatten_binary(Or((self, other))) + return flatten_binary(Or((self, other)), Or, BOT, TOP) def __and__(self, other): - return flatten_binary(And((self, other))) + return flatten_binary(And((self, other)), And, TOP, BOT) def __invert__(self): return Neg(self) +class _Top(AST): + def __repr__(self): + return "⊤" + + def __invert__(self): + return Bot() + + +class _Bot(AST): + def __repr__(self): + return "⊥" + + def __invert__(self): + return Top() + +TOP = _Top() +BOT = _Bot() + + class AtomicPred(namedtuple("AP", ["id", "time"]), AST): def __repr__(self): return f"{self.id}[{self.time}]" diff --git a/stl/test_ast.py b/stl/test_ast.py new file mode 100644 index 0000000..accbb9d --- /dev/null +++ b/stl/test_ast.py @@ -0,0 +1,13 @@ +import stl +from nose2.tools import params +import unittest + +class TestSTLAST(unittest.TestCase): + def test_and(self): + phi = stl.parse("x") + self.assertEqual(stl.TOP, stl.TOP | phi) + self.assertEqual(stl.BOT, stl.BOT & phi) + self.assertEqual(stl.TOP, stl.TOP & stl.TOP) + self.assertEqual(stl.BOT, stl.BOT | stl.BOT) + self.assertEqual(stl.TOP, stl.TOP | stl.BOT) + self.assertEqual(stl.BOT, stl.TOP & stl.BOT)