diff --git a/stl/ast.py b/stl/ast.py index 0d1c40f..e16a934 100644 --- a/stl/ast.py +++ b/stl/ast.py @@ -41,7 +41,7 @@ class AST(object): @property def children(self): - return set() + return tuple() def walk(self): """Walk of the AST.""" @@ -123,7 +123,7 @@ class AtomicPred(namedtuple("AP", ["id"]), AST): @property def children(self): - return set() + return tuple() class LinEq(namedtuple("LinEquality", ["terms", "op", "const"]), AST): @@ -134,7 +134,7 @@ class LinEq(namedtuple("LinEquality", ["terms", "op", "const"]), AST): @property def children(self): - return set() + return tuple() def __hash__(self): # TODO: compute hash based on contents @@ -171,7 +171,7 @@ class NaryOpSTL(namedtuple('NaryOp', ['args']), AST): @property def children(self): - return set(self.args) + return tuple(self.args) class Or(NaryOpSTL): @@ -203,7 +203,7 @@ class ModalOp(namedtuple('ModalOp', ['interval', 'arg']), AST): @property def children(self): - return {self.arg} + return (self.arg,) class F(ModalOp): @@ -232,7 +232,7 @@ class Until(namedtuple('ModalOp', ['arg1', 'arg2']), AST): @property def children(self): - return {self.arg1, self.arg2} + return (self.arg1, self.arg2) def __hash__(self): # TODO: compute hash based on contents @@ -247,7 +247,7 @@ class Neg(namedtuple('Neg', ['arg']), AST): @property def children(self): - return {self.arg} + return (self.arg,) def __hash__(self): # TODO: compute hash based on contents @@ -262,7 +262,7 @@ class Next(namedtuple('Next', ['arg']), AST): @property def children(self): - return {self.arg} + return (self.arg,) def __hash__(self): # TODO: compute hash based on contents diff --git a/stl/test_utils.py b/stl/test_utils.py new file mode 100644 index 0000000..250ad54 --- /dev/null +++ b/stl/test_utils.py @@ -0,0 +1,13 @@ +import stl +from stl.hypothesis import SignalTemporalLogicStrategy + +from hypothesis import given + + +@given(SignalTemporalLogicStrategy) +def test_f_neg_or_canonical_form(phi): + phi2 = stl.utils.f_neg_or_canonical_form(phi) + phi3 = stl.utils.f_neg_or_canonical_form(phi2) + assert phi2 == phi3 + assert not any( + isinstance(x, (stl.ast.G, stl.ast.And)) for x in phi2.walk()) diff --git a/stl/utils.py b/stl/utils.py index f49b960..6e46efd 100644 --- a/stl/utils.py +++ b/stl/utils.py @@ -5,7 +5,9 @@ import traces from lenses import bind import stl.ast -from stl.ast import (And, F, G, Interval, LinEq, Neg, Or, AP_lens) +from stl.ast import (And, F, G, Interval, LinEq, Neg, + Or, AP_lens, Next, Until, AtomicPred, + _Top, _Bot) from stl.types import STL @@ -13,24 +15,28 @@ oo = float('inf') def f_neg_or_canonical_form(phi: STL) -> STL: - if isinstance(phi, LinEq): + if isinstance(phi, (LinEq, AtomicPred, _Top, _Bot)): return phi children = [f_neg_or_canonical_form(s) for s in phi.children] if isinstance(phi, (And, G)): children = [Neg(s) for s in children] - children = tuple(children) + children = tuple(sorted(children, key=str)) if isinstance(phi, Or): return Or(children) elif isinstance(phi, And): return Neg(Or(children)) elif isinstance(phi, Neg): - return Neg(children[0]) + return Neg(*children) + elif isinstance(phi, Next): + return Next(*children) + elif isinstance(phi, Until): + return Until(*children) elif isinstance(phi, F): - return F(phi.interval, children[0]) + return F(phi.interval, *children) elif isinstance(phi, G): - return Neg(F(phi.interval, children[0])) + return Neg(F(phi.interval, *children)) else: raise NotImplementedError