added AST mixin for binary operators

This commit is contained in:
Marcell Vazquez-Chanlatte 2017-01-03 15:54:20 -08:00
parent 13cc7abdf2
commit deaa13f13a

View file

@ -5,12 +5,32 @@
from collections import namedtuple, deque from collections import namedtuple, deque
from itertools import repeat from itertools import repeat
from enum import Enum from enum import Enum
import funcy as fn
from sympy import Symbol from sympy import Symbol
dt_sym = Symbol('dt', positive=True) dt_sym = Symbol('dt', positive=True)
t_sym = Symbol('t', positive=True) t_sym = Symbol('t', positive=True)
class AtomicPred(namedtuple("AP", ["id"])): def flatten_binary(phi):
t = type(phi)
f = lambda x: x.args if isinstance(x, t) else [x]
return t(tuple(fn.mapcat(f, phi.args)))
class AST(object):
def __or__(self, other):
return flatten_binary(Or((self, other)))
def __and__(self, other):
return flatten_binary(And((self, other)))
def __invert__(self):
return Neg(self)
class AtomicPred(namedtuple("AP", ["id"]), AST):
def __repr__(self): def __repr__(self):
return "{}".format(self.id) return "{}".format(self.id)
@ -18,7 +38,7 @@ class AtomicPred(namedtuple("AP", ["id"])):
return [] return []
class LinEq(namedtuple("LinEquality", ["terms", "op", "const"])): class LinEq(namedtuple("LinEquality", ["terms", "op", "const"]), AST):
def __repr__(self): def __repr__(self):
return " + ".join(map(str, self.terms)) + f" {self.op} {self.const}" return " + ".join(map(str, self.terms)) + f" {self.op} {self.const}"
@ -39,7 +59,7 @@ class Interval(namedtuple('I', ['lower', 'upper'])):
return [self.lower, self.upper] return [self.lower, self.upper]
class NaryOpSTL(namedtuple('NaryOp', ['args'])): class NaryOpSTL(namedtuple('NaryOp', ['args']), AST):
OP = "?" OP = "?"
def __repr__(self): def __repr__(self):
return f" {self.OP} ".join(f"({x})" for x in self.args) return f" {self.OP} ".join(f"({x})" for x in self.args)
@ -55,7 +75,7 @@ class And(NaryOpSTL):
OP = "" OP = ""
class ModalOp(namedtuple('ModalOp', ['interval', 'arg'])): class ModalOp(namedtuple('ModalOp', ['interval', 'arg']), AST):
def __repr__(self): def __repr__(self):
return f"{self.OP}{self.interval}({self.arg})" return f"{self.OP}{self.interval}({self.arg})"
@ -70,7 +90,7 @@ class G(ModalOp):
OP = "" OP = ""
class Neg(namedtuple('Neg', ['arg'])): class Neg(namedtuple('Neg', ['arg']), AST):
def __repr__(self): def __repr__(self):
return f"¬({self.arg})" return f"¬({self.arg})"