343 lines
7.5 KiB
Python
343 lines
7.5 KiB
Python
# -*- coding: utf-8 -*-
|
||
# TODO: when rendering a + (-b), suppress the "+" so the output reads a - b.
|
||
|
||
from collections import deque, namedtuple
|
||
from functools import lru_cache
|
||
|
||
import funcy as fn
|
||
import lenses
|
||
from lenses import lens
|
||
|
||
|
||
def flatten_binary(phi, op, dropT, shortT):
    """Simplify and flatten an n-ary Boolean node.

    The identity element ``dropT`` is removed from ``phi.args``. If any
    remaining argument is the absorbing element ``shortT``, the whole node
    collapses to it. Arguments that are themselves ``op`` nodes are spliced
    in, so the result is flat. (``AST.__and__`` calls this with
    ``dropT=TOP, shortT=BOT``; ``AST.__or__`` with ``dropT=BOT, shortT=TOP``.)

    Args:
        phi: node of type ``op`` whose ``args`` are simplified.
        op: node class (e.g. ``And`` or ``Or``) built from a tuple of args.
        dropT: identity element, removed from the arguments (compared by
            identity).
        shortT: absorbing element that short-circuits the result (compared
            by identity).

    Returns:
        ``shortT`` if it occurs among the arguments, ``dropT`` if no argument
        remains, the lone argument if exactly one remains, otherwise a new
        flat ``op`` node.
    """
    args = [arg for arg in phi.args if arg is not dropT]

    if any(arg is shortT for arg in args):
        return shortT
    if not args:
        return dropT
    if len(args) == 1:
        return args[0]

    # Splice nested ``op`` nodes into this one. Build from the filtered
    # ``args`` (not ``phi.args``) so identity elements never reach the result.
    flat = tuple(
        sub
        for arg in args
        for sub in (arg.args if isinstance(arg, op) else (arg,))
    )
    return op(flat)
|
||
|
||
|
||
class AST:
    """Base class of every node in the temporal-logic formula tree.

    Supplies the Boolean operator overloads (``|``, ``&``, ``~``), a
    depth-first ``walk``, and lens-backed views of the parameters, linear
    (in)equalities, terms and atomic predicates contained in a formula.
    """

    __slots__ = ()

    def __or__(self, other):
        # BOT is dropped from a disjunction and TOP absorbs it.
        disjunction = Or((self, other))
        return flatten_binary(disjunction, Or, BOT, TOP)

    def __and__(self, other):
        # TOP is dropped from a conjunction and BOT absorbs it.
        conjunction = And((self, other))
        return flatten_binary(conjunction, And, TOP, BOT)

    def __invert__(self):
        return Neg(self)

    @property
    def children(self):
        """Direct sub-formulas of this node; none by default."""
        return set()

    def walk(self):
        """Yield this node and every descendant, depth first.

        The most recently discovered child is visited next. There is no
        visited set, so a sub-formula reachable along several paths is
        yielded once per path.
        """
        stack = deque([self])
        while stack:
            current = stack.pop()
            yield current
            stack.extend(current.children)

    @property
    def params(self):
        """Set of ``Param`` placeholders anywhere in the formula.

        Inspects both interval bounds of every ``ModalOp`` node and the
        ``const`` of every ``LinEq`` node.
        """
        found = set()
        for node in self.walk():
            if isinstance(node, ModalOp):
                bounds = (node.interval[0], node.interval[1])
                found.update(b for b in bounds if isinstance(b, Param))
            elif isinstance(node, LinEq) and isinstance(node.const, Param):
                found.add(node.const)
        return found

    def set_params(self, val):
        """Substitute values for the formula's parameters.

        Every ``Param`` focused by ``param_lens`` is looked up in ``val``,
        first under the ``Param`` itself and then under its string form. The
        value found is converted with ``float``. A parameter missing from
        ``val`` is passed to ``float`` unchanged.

        Returns:
            The result of ``param_lens(self).modify(...)``.
        """
        def substitute(param):
            by_name = val.get(str(param), param)
            return float(val.get(param, by_name))

        return param_lens(self).modify(substitute)

    @property
    def terms(self):
        """Set of items collected from ``terms_lens(self).Each()``."""
        # NOTE(review): terms_lens already ends with .Each(); confirm this
        # second .Each() is intended (it may iterate inside each term).
        focus = terms_lens(self).Each()
        return set(focus.collect())

    @property
    def lineqs(self):
        """Set of ``LinEq`` nodes found via ``lineq_lens``."""
        focus = lineq_lens(self).Each()
        return set(focus.collect())

    @property
    def atomic_predicates(self):
        """Set of ``AtomicPred`` nodes found via ``AP_lens``."""
        focus = AP_lens(self).Each()
        return set(focus.collect())
|
||
|
||
|
||
class _Top(AST):
    """The constant-true formula; use the module-level ``TOP`` instance."""

    __slots__ = ()

    def __repr__(self):
        return "⊤"

    def __invert__(self):
        return BOT

    def __reduce__(self):
        # This constant is tested by identity (``is TOP``) elsewhere in the
        # module, so copy/deepcopy/pickle must hand back the module-level
        # singleton instead of building a fresh, non-identical instance.
        return "TOP"
|
||
|
||
|
||
class _Bot(AST):
    """The constant-false formula; use the module-level ``BOT`` instance."""

    __slots__ = ()

    def __repr__(self):
        return "⊥"

    def __invert__(self):
        return TOP

    def __reduce__(self):
        # This constant is tested by identity (``is BOT``) elsewhere in the
        # module, so copy/deepcopy/pickle must hand back the module-level
        # singleton instead of building a fresh, non-identical instance.
        return "BOT"
|
||
|
||
|
||
# Singleton Boolean constants. Code in this module compares against them by
# identity (``is TOP`` / ``is BOT``), so always use these instances rather
# than constructing new ``_Top``/``_Bot`` objects.
TOP, BOT = _Top(), _Bot()
|
||
|
||
|
||
class AtomicPred(namedtuple("AP", ["id"]), AST):
    """Atomic predicate named by ``id``; a leaf of the formula tree."""

    __slots__ = ()

    def __repr__(self):
        return "{}".format(self.id)

    @property
    def children(self):
        # Leaf node: nothing to descend into.
        return set()
|
||
|
||
|
||
class LinEq(namedtuple("LinEquality", ["terms", "op", "const"]), AST):
    """Linear (in)equality rendered as ``t1 + t2 + ... <op> <const>``.

    A leaf of the formula tree. ``const`` may be a ``Param`` placeholder;
    ``terms`` are presumably ``Var`` objects (confirm with the parser).
    """

    __slots__ = ()

    def __repr__(self):
        lhs = " + ".join(str(term) for term in self.terms)
        return f"{lhs} {self.op} {self.const}"

    @property
    def children(self):
        # Leaf node: terms are not sub-formulas.
        return set()

    def __hash__(self):
        # TODO: compute hash based on contents.
        # For now the textual rendering is hashed.
        rendered = repr(self)
        return hash(rendered)
|
||
|
||
|
||
class Var(namedtuple("Var", ["coeff", "id"])):
    """Variable ``id`` scaled by ``coeff``.

    Rendered without a coefficient when it equals 1, with a bare ``-`` when
    it equals -1, and with the coefficient as a prefix otherwise
    (e.g. ``x``, ``-x``, ``2.5x``).
    """

    __slots__ = ()

    def __repr__(self):
        c = self.coeff
        prefix = "-" if c == -1 else ("" if c == +1 else f"{c}")
        return f"{prefix}{self.id}"
|
||
|
||
|
||
class Interval(namedtuple('I', ['lower', 'upper'])):
    """Interval with bounds ``lower`` and ``upper``, rendered as ``[lower,upper]``.

    Bounds may be numbers or ``Param`` placeholders.
    """

    __slots__ = ()

    def __repr__(self):
        return "[{},{}]".format(self.lower, self.upper)

    @property
    def children(self):
        # The two bounds; identical bounds collapse to a single element.
        return set(self)
|
||
|
||
|
||
class NaryOpSTL(namedtuple('NaryOp', ['args']), AST):
    """Base for n-ary Boolean connectives over the tuple ``args``.

    Subclasses set ``OP`` to the connective symbol; ``repr`` parenthesises
    each operand and joins them with `` OP ``.
    """

    __slots__ = ()

    # Placeholder symbol; concrete subclasses override it.
    OP = "?"

    def __repr__(self):
        separator = f" {self.OP} "
        return separator.join([f"({arg})" for arg in self.args])

    @property
    def children(self):
        # Every operand is a direct sub-formula.
        return set(self.args)
|
||
|
||
|
||
class Or(NaryOpSTL):
    """Disjunction of ``args``, rendered with ∨."""

    __slots__ = ()

    OP = "∨"

    def __hash__(self):
        # TODO: compute hash based on contents.
        # For now the textual rendering is hashed.
        rendered = repr(self)
        return hash(rendered)
|
||
|
||
|
||
class And(NaryOpSTL):
    """Conjunction of ``args``, rendered with ∧."""

    __slots__ = ()

    OP = "∧"

    def __hash__(self):
        # TODO: compute hash based on contents.
        # For now the textual rendering is hashed.
        rendered = repr(self)
        return hash(rendered)
|
||
|
||
|
||
class ModalOp(namedtuple('ModalOp', ['interval', 'arg']), AST):
    """Base for unary temporal operators over ``interval`` applied to ``arg``.

    Subclasses set ``OP``. ``repr`` renders the symbol, then the interval,
    then the parenthesised operand.
    """

    __slots__ = ()

    # Placeholder symbol; concrete subclasses override it.
    OP = '?'

    def __repr__(self):
        return "{}{}({})".format(self.OP, self.interval, self.arg)

    @property
    def children(self):
        # Only the operand is a sub-formula; the interval is not walked.
        return {self.arg}
|
||
|
||
|
||
class F(ModalOp):
    """Temporal 'eventually' operator ``F``, rendered with ◇."""

    __slots__ = ()

    OP = "◇"

    def __hash__(self):
        # TODO: compute hash based on contents.
        # For now the textual rendering is hashed.
        rendered = repr(self)
        return hash(rendered)
|
||
|
||
|
||
class G(ModalOp):
    """Temporal 'always' operator ``G``, rendered with □."""

    __slots__ = ()

    OP = "□"

    def __hash__(self):
        # TODO: compute hash based on contents.
        # For now the textual rendering is hashed.
        rendered = repr(self)
        return hash(rendered)
|
||
|
||
|
||
class Until(namedtuple('Until', ['arg1', 'arg2']), AST):
    """Binary temporal operator, rendered as ``(arg1) U (arg2)``.

    The namedtuple typename is ``'Until'``; it was previously a copy-pasted
    ``'ModalOp'``, which produced misleading constructor error messages and
    MRO names. ``Until`` is not a ``ModalOp`` subclass either way.
    """

    __slots__ = ()

    def __repr__(self):
        return f"({self.arg1}) U ({self.arg2})"

    @property
    def children(self):
        # Both operands are sub-formulas.
        return {self.arg1, self.arg2}

    def __hash__(self):
        # TODO: compute hash based on contents.
        # For now the textual rendering is hashed.
        return hash(repr(self))
|
||
|
||
|
||
class Neg(namedtuple('Neg', ['arg']), AST):
    """Logical negation of ``arg``, rendered as ``¬(arg)``."""

    __slots__ = ()

    def __repr__(self):
        return "¬({})".format(self.arg)

    @property
    def children(self):
        # The negated formula is the only sub-formula.
        return {self.arg}

    def __hash__(self):
        # TODO: compute hash based on contents.
        # For now the textual rendering is hashed.
        rendered = repr(self)
        return hash(rendered)
|
||
|
||
|
||
class Next(namedtuple('Next', ['arg']), AST):
    """Next-step operator, rendered as ``X(arg)``."""

    __slots__ = ()

    def __repr__(self):
        return "X({})".format(self.arg)

    @property
    def children(self):
        # The operand is the only sub-formula.
        return {self.arg}

    def __hash__(self):
        # TODO: compute hash based on contents.
        # For now the textual rendering is hashed.
        rendered = repr(self)
        return hash(rendered)
|
||
|
||
|
||
class Param(namedtuple('Param', ['name']), AST):
    """Symbolic placeholder named ``name``.

    Used for interval bounds and ``LinEq`` constants; ``AST.set_params``
    replaces it with a number.
    """

    __slots__ = ()

    def __repr__(self):
        # The bare name is the rendering (so ``name`` must be a str).
        return self.name

    def __hash__(self):
        # TODO: compute hash based on contents.
        # Hashing repr here amounts to hashing the name.
        rendered = repr(self)
        return hash(rendered)
|
||
|
||
|
||
def ast_lens(phi, bind=True, *, pred=None, focus_lens=None, getter=False):
    """Build one lens over ``phi`` that combines sub-lenses chosen by ``pred``.

    Args:
        phi: formula to traverse.
        bind: if true, build on ``lenses.bind(phi)``; otherwise on the
            unbound ``lens``.
        pred: node predicate selecting where to focus; by default it
            matches nothing.
        focus_lens: maps a selected node to the lenses (relative to that
            node) to focus; by default it focuses the node itself.
        getter: combine the focuses with ``Tuple`` when true, otherwise
            with ``Fork``.
    """
    if pred is None:
        def pred(_):
            return False

    if focus_lens is None:
        def focus_lens(_):
            return [lens]

    base = lenses.bind(phi) if bind else lens
    combine = base.Tuple if getter else base.Fork
    return combine(*_ast_lens(phi, pred=pred, focus_lens=focus_lens))
|
||
|
||
|
||
def _ast_lens(phi, pred, focus_lens):
    """Yield lenses, relative to ``phi``, for the parts selected by ``pred``.

    Walks the formula depth first. Whenever ``pred(node)`` holds, the lenses
    returned by ``focus_lens(node)`` are yielded, each composed (``&``) with
    the path from ``phi`` down to ``node``. Recursion continues into the
    node's sub-formulas either way.
    """
    if pred(phi):
        yield from focus_lens(phi)

    # Leaves end the descent. TOP and BOT land here too, because their
    # ``children`` is empty; the former dedicated TOP/BOT branch below was
    # unreachable, and would have recursed forever if ever reached.
    if phi is None or not phi.children:
        return

    if isinstance(phi, Until):
        child_lenses = [lens.GetAttr('arg1'), lens.GetAttr('arg2')]
    elif isinstance(phi, NaryOpSTL):
        child_lenses = [lens.GetAttr('args')[i] for i in range(len(phi.args))]
    else:
        # Unary nodes (ModalOp subclasses, Neg, Next) keep their operand in
        # ``arg``.
        child_lenses = [lens.GetAttr('arg')]

    for child_lens in child_lenses:
        child = child_lens.get()(phi)
        yield from (
            child_lens & sub for sub in _ast_lens(child, pred, focus_lens)
        )
|
||
|
||
|
||
@lru_cache()
def param_lens(phi, *, getter=False):
    """Build a lens focusing every ``Param`` in ``phi``.

    Parameters are searched for in the ``const`` of ``LinEq`` nodes and in
    both interval bounds of ``F``/``G`` nodes (matched by exact type).
    Results are memoised with ``lru_cache``, so ``phi`` must be hashable.

    Args:
        phi: formula to search.
        getter: forwarded to ``ast_lens``. ``True`` combines the focuses
            with ``Tuple`` (for reading); ``False`` uses ``Fork``, which is
            what ``AST.set_params`` uses with ``modify``.
    """
    def focus_lens(leaf):
        # Explicit GetAttr for every attribute, consistent with _ast_lens.
        if isinstance(leaf, LinEq):
            candidates = [lens.GetAttr('const')]
        else:
            candidates = [
                lens.GetAttr('interval')[0],
                lens.GetAttr('interval')[1],
            ]
        # Focus only bounds/constants that are still symbolic.
        return (c for c in candidates if isinstance(c.get()(leaf), Param))

    return ast_lens(
        phi, pred=type_pred(LinEq, F, G), focus_lens=focus_lens, getter=getter)
|
||
|
||
|
||
def vars_in_phi(phi):
    """Return the set of ``(id, time)`` pairs read from the terms of ``phi``.

    NOTE(review): ``Var`` as defined in this module has only ``coeff`` and
    ``id`` fields. The rest of the module also uses the ``Tuple``/``collect``
    lens API rather than ``tuple_``/``get_all``. This function looks stale
    and may raise ``AttributeError``; confirm the intended result against the
    installed ``lenses`` version.
    """
    pair_lens = terms_lens(phi).tuple_(lens.id, lens.time)
    return set(pair_lens.get_all())
|
||
|
||
|
||
def type_pred(*args):
    """Return a predicate that is true when a node's exact type is in ``args``.

    Subclasses do not match, because ``type(x)`` is compared rather than
    using ``isinstance``.
    """
    allowed = frozenset(args)

    def has_allowed_type(node):
        return type(node) in allowed

    return has_allowed_type
|
||
|
||
|
||
# Getter lens factories (``getter=True``). Each takes a formula (plus the
# optional ``bind`` flag) and focuses, as one tuple, every sub-formula whose
# exact type matches.
lineq_lens = fn.partial(ast_lens, getter=True, pred=type_pred(LinEq))
AP_lens = fn.partial(ast_lens, getter=True, pred=type_pred(AtomicPred))
and_or_lens = fn.partial(ast_lens, getter=True, pred=type_pred(And, Or))
|
||
|
||
|
||
def terms_lens(phi, bind=True):
    """Lens focusing each term of each ``LinEq`` in ``phi``.

    Args:
        phi: formula to search.
        bind: forwarded to ``lineq_lens``; when true the lens is bound to
            ``phi``.
    """
    each_lineq = lineq_lens(phi, bind).Each()
    # Explicit GetAttr rather than attribute shorthand, matching the other
    # lens builders in this module.
    return each_lineq.GetAttr('terms').Each()
|