added AST mixin for binary operators
This commit is contained in:
parent
13cc7abdf2
commit
deaa13f13a
1 changed files with 25 additions and 5 deletions
30
stl/ast.py
30
stl/ast.py
|
|
@ -5,12 +5,32 @@
|
||||||
from collections import namedtuple, deque
|
from collections import namedtuple, deque
|
||||||
from itertools import repeat
|
from itertools import repeat
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
|
|
||||||
|
import funcy as fn
|
||||||
from sympy import Symbol
|
from sympy import Symbol
|
||||||
|
|
||||||
dt_sym = Symbol('dt', positive=True)
|
dt_sym = Symbol('dt', positive=True)
|
||||||
t_sym = Symbol('t', positive=True)
|
t_sym = Symbol('t', positive=True)
|
||||||
|
|
||||||
class AtomicPred(namedtuple("AP", ["id"])):
|
def flatten_binary(phi):
|
||||||
|
t = type(phi)
|
||||||
|
f = lambda x: x.args if isinstance(x, t) else [x]
|
||||||
|
return t(tuple(fn.mapcat(f, phi.args)))
|
||||||
|
|
||||||
|
|
||||||
|
class AST(object):
|
||||||
|
def __or__(self, other):
|
||||||
|
return flatten_binary(Or((self, other)))
|
||||||
|
|
||||||
|
def __and__(self, other):
|
||||||
|
return flatten_binary(And((self, other)))
|
||||||
|
|
||||||
|
def __invert__(self):
|
||||||
|
return Neg(self)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
class AtomicPred(namedtuple("AP", ["id"]), AST):
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return "{}".format(self.id)
|
return "{}".format(self.id)
|
||||||
|
|
||||||
|
|
@ -18,7 +38,7 @@ class AtomicPred(namedtuple("AP", ["id"])):
|
||||||
return []
|
return []
|
||||||
|
|
||||||
|
|
||||||
class LinEq(namedtuple("LinEquality", ["terms", "op", "const"])):
|
class LinEq(namedtuple("LinEquality", ["terms", "op", "const"]), AST):
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return " + ".join(map(str, self.terms)) + f" {self.op} {self.const}"
|
return " + ".join(map(str, self.terms)) + f" {self.op} {self.const}"
|
||||||
|
|
||||||
|
|
@ -39,7 +59,7 @@ class Interval(namedtuple('I', ['lower', 'upper'])):
|
||||||
return [self.lower, self.upper]
|
return [self.lower, self.upper]
|
||||||
|
|
||||||
|
|
||||||
class NaryOpSTL(namedtuple('NaryOp', ['args'])):
|
class NaryOpSTL(namedtuple('NaryOp', ['args']), AST):
|
||||||
OP = "?"
|
OP = "?"
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return f" {self.OP} ".join(f"({x})" for x in self.args)
|
return f" {self.OP} ".join(f"({x})" for x in self.args)
|
||||||
|
|
@ -55,7 +75,7 @@ class And(NaryOpSTL):
|
||||||
OP = "∧"
|
OP = "∧"
|
||||||
|
|
||||||
|
|
||||||
class ModalOp(namedtuple('ModalOp', ['interval', 'arg'])):
|
class ModalOp(namedtuple('ModalOp', ['interval', 'arg']), AST):
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return f"{self.OP}{self.interval}({self.arg})"
|
return f"{self.OP}{self.interval}({self.arg})"
|
||||||
|
|
||||||
|
|
@ -70,7 +90,7 @@ class G(ModalOp):
|
||||||
OP = "□"
|
OP = "□"
|
||||||
|
|
||||||
|
|
||||||
class Neg(namedtuple('Neg', ['arg'])):
|
class Neg(namedtuple('Neg', ['arg']), AST):
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return f"¬({self.arg})"
|
return f"¬({self.arg})"
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue