make ast objects easier to work with
This commit is contained in:
parent
2640728288
commit
7d8cf78222
5 changed files with 131 additions and 128 deletions
|
|
@ -1,7 +1,6 @@
|
|||
# flake8: noqa
|
||||
from stl.utils import terms_lens, lineq_lens, walk, and_or_lens
|
||||
from stl.utils import alw, env, andf, orf
|
||||
from stl.ast import dt_sym, t_sym, TOP, BOT
|
||||
from stl.ast import TOP, BOT
|
||||
from stl.ast import (LinEq, Interval, NaryOpSTL, Or, And, F, G, ModalOp, Neg,
|
||||
Var, AtomicPred, Until)
|
||||
from stl.parser import parse
|
||||
|
|
|
|||
117
stl/ast.py
117
stl/ast.py
|
|
@ -1,13 +1,12 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# TODO: supress + given a + (-b). i.e. want a - b
|
||||
|
||||
from collections import namedtuple
|
||||
from collections import deque, namedtuple
|
||||
from functools import lru_cache
|
||||
|
||||
import funcy as fn
|
||||
from sympy import Symbol
|
||||
|
||||
dt_sym = Symbol('dt', positive=True)
|
||||
t_sym = Symbol('t', positive=True)
|
||||
import lenses
|
||||
from lenses import lens
|
||||
|
||||
|
||||
def flatten_binary(phi, op, dropT, shortT):
|
||||
|
|
@ -42,6 +41,45 @@ class AST(object):
|
|||
def children(self):
|
||||
return set()
|
||||
|
||||
def walk(self):
|
||||
"""Walk of the AST."""
|
||||
pop = deque.pop
|
||||
children = deque([self])
|
||||
while len(children) > 0:
|
||||
node = pop(children)
|
||||
yield node
|
||||
children.extend(node.children)
|
||||
|
||||
@property
|
||||
def params(self):
|
||||
def get_params(leaf):
|
||||
if isinstance(leaf, ModalOp):
|
||||
if isinstance(leaf.interval[0], Param):
|
||||
yield leaf.interval[0]
|
||||
if isinstance(leaf.interval[1], Param):
|
||||
yield leaf.interval[1]
|
||||
elif isinstance(leaf, LinEq):
|
||||
if isinstance(leaf.const, Param):
|
||||
yield leaf.const
|
||||
|
||||
return set(fn.mapcat(get_params, self.walk()))
|
||||
|
||||
def set_params(self, val):
|
||||
phi = param_lens(self)
|
||||
return phi.modify(lambda x: float(val.get(x, val.get(str(x), x))))
|
||||
|
||||
@property
|
||||
def terms(self):
|
||||
return set(terms_lens(self).Each().collect())
|
||||
|
||||
@property
|
||||
def lineqs(self):
|
||||
return set(lineq_lens(self).Each().collect())
|
||||
|
||||
@property
|
||||
def atomic_predicates(self):
|
||||
return set(AP_lens(self).Each().collect())
|
||||
|
||||
|
||||
class _Top(AST):
|
||||
__slots__ = ()
|
||||
|
|
@ -234,3 +272,72 @@ class Param(namedtuple('Param', ['name']), AST):
|
|||
def __hash__(self):
|
||||
# TODO: compute hash based on contents
|
||||
return hash(repr(self))
|
||||
|
||||
|
||||
def ast_lens(phi, bind=True, *, pred=None, focus_lens=None, getter=False):
|
||||
if focus_lens is None:
|
||||
|
||||
def focus_lens(_):
|
||||
return [lens]
|
||||
|
||||
if pred is None:
|
||||
|
||||
def pred(_):
|
||||
return False
|
||||
|
||||
child_lenses = _ast_lens(phi, pred=pred, focus_lens=focus_lens)
|
||||
phi = lenses.bind(phi) if bind else lens
|
||||
return (phi.Tuple if getter else phi.Fork)(*child_lenses)
|
||||
|
||||
|
||||
def _ast_lens(phi, pred, focus_lens):
|
||||
if pred(phi):
|
||||
yield from focus_lens(phi)
|
||||
|
||||
if phi is None or not phi.children:
|
||||
return
|
||||
|
||||
if phi is TOP or phi is BOT:
|
||||
child_lenses = [lens]
|
||||
elif isinstance(phi, Until):
|
||||
child_lenses = [lens.GetAttr('arg1'), lens.GetAttr('arg2')]
|
||||
elif isinstance(phi, NaryOpSTL):
|
||||
child_lenses = [
|
||||
lens.GetAttr('args')[j] for j, _ in enumerate(phi.args)
|
||||
]
|
||||
else:
|
||||
child_lenses = [lens.GetAttr('arg')]
|
||||
for l in child_lenses:
|
||||
yield from [l & cl for cl in _ast_lens(l.get()(phi), pred, focus_lens)]
|
||||
|
||||
|
||||
@lru_cache()
|
||||
def param_lens(phi, *, getter=False):
|
||||
def focus_lens(leaf):
|
||||
candidates = [lens.const] if isinstance(leaf, LinEq) else [
|
||||
lens.GetAttr('interval')[0],
|
||||
lens.GetAttr('interval')[1]
|
||||
]
|
||||
return (x for x in candidates if isinstance(x.get()(leaf), Param))
|
||||
|
||||
return ast_lens(
|
||||
phi, pred=type_pred(LinEq, F, G), focus_lens=focus_lens, getter=getter)
|
||||
|
||||
|
||||
def vars_in_phi(phi):
|
||||
focus = terms_lens(phi)
|
||||
return set(focus.tuple_(lens.id, lens.time).get_all())
|
||||
|
||||
|
||||
def type_pred(*args):
|
||||
ast_types = set(args)
|
||||
return lambda x: type(x) in ast_types
|
||||
|
||||
|
||||
lineq_lens = fn.partial(ast_lens, pred=type_pred(LinEq), getter=True)
|
||||
AP_lens = fn.partial(ast_lens, pred=type_pred(AtomicPred), getter=True)
|
||||
and_or_lens = fn.partial(ast_lens, pred=type_pred(And, Or), getter=True)
|
||||
|
||||
|
||||
def terms_lens(phi, bind=True):
|
||||
return lineq_lens(phi, bind).Each().terms.Each()
|
||||
|
|
|
|||
|
|
@ -13,7 +13,7 @@ oo = float('inf')
|
|||
|
||||
|
||||
def pointwise_sat(phi):
|
||||
ap_names = [z.id.name for z in stl.utils.AP_lens(phi).Each().collect()]
|
||||
ap_names = [z.id.name for z in stl.ast.AP_lens(phi).Each().collect()]
|
||||
|
||||
def _eval_stl(x, t):
|
||||
evaluated = stl.utils.eval_lineqs(phi, x)
|
||||
|
|
|
|||
|
|
@ -9,3 +9,8 @@ from stl.hypothesis import SignalTemporalLogicStrategy
|
|||
def test_invertable_repr(phi):
|
||||
event(str(phi))
|
||||
assert str(phi) == str(stl.parse(str(phi)))
|
||||
|
||||
|
||||
@given(SignalTemporalLogicStrategy)
|
||||
def test_hash_inheritance(phi):
|
||||
assert hash(repr(phi)) == hash(phi)
|
||||
|
|
|
|||
132
stl/utils.py
132
stl/utils.py
|
|
@ -1,119 +1,12 @@
|
|||
import operator as op
|
||||
from collections import deque
|
||||
from functools import reduce
|
||||
from typing import List, Mapping, Type, TypeVar
|
||||
|
||||
import funcy as fn
|
||||
import traces
|
||||
from lenses import lens, bind
|
||||
|
||||
import lenses
|
||||
import stl.ast
|
||||
from lenses import lens
|
||||
from stl.ast import (AST, And, F, G, Interval, LinEq, NaryOpSTL, Neg, Or,
|
||||
Param, ModalOp)
|
||||
from stl.types import STL, STL_Generator
|
||||
|
||||
Lens = TypeVar('Lens')
|
||||
|
||||
|
||||
def walk(phi: STL) -> STL_Generator:
|
||||
"""Walk of the AST."""
|
||||
pop = deque.pop
|
||||
children = deque([phi])
|
||||
while len(children) > 0:
|
||||
node = pop(children)
|
||||
yield node
|
||||
children.extend(node.children)
|
||||
|
||||
|
||||
def list_params(phi: STL):
|
||||
"""Walk of the AST."""
|
||||
|
||||
def get_params(leaf):
|
||||
if isinstance(leaf, ModalOp):
|
||||
if isinstance(leaf.interval[0], Param):
|
||||
yield leaf.interval[0]
|
||||
if isinstance(leaf.interval[1], Param):
|
||||
yield leaf.interval[1]
|
||||
elif isinstance(leaf, LinEq):
|
||||
if isinstance(leaf.const, Param):
|
||||
yield leaf.const
|
||||
|
||||
return set(fn.mapcat(get_params, walk(phi)))
|
||||
|
||||
|
||||
def vars_in_phi(phi):
|
||||
focus = stl.terms_lens(phi)
|
||||
return set(focus.tuple_(lens.id, lens.time).get_all())
|
||||
|
||||
|
||||
def type_pred(*args: List[Type]) -> Mapping[Type, bool]:
|
||||
ast_types = set(args)
|
||||
return lambda x: type(x) in ast_types
|
||||
|
||||
|
||||
def ast_lens(phi: STL, bind=True, *, pred=None, focus_lens=None,
|
||||
getter=False) -> Lens:
|
||||
if focus_lens is None:
|
||||
|
||||
def focus_lens(_):
|
||||
return [lens]
|
||||
|
||||
if pred is None:
|
||||
|
||||
def pred(_):
|
||||
return False
|
||||
|
||||
child_lenses = _ast_lens(phi, pred=pred, focus_lens=focus_lens)
|
||||
phi = lenses.bind(phi) if bind else lens
|
||||
return (phi.Tuple if getter else phi.Fork)(*child_lenses)
|
||||
|
||||
|
||||
def _ast_lens(phi: STL, pred, focus_lens) -> Lens:
|
||||
if pred(phi):
|
||||
yield from focus_lens(phi)
|
||||
|
||||
if phi is None or not phi.children:
|
||||
return
|
||||
|
||||
if phi is stl.TOP or phi is stl.BOT:
|
||||
child_lenses = [lens]
|
||||
elif isinstance(phi, stl.ast.Until):
|
||||
child_lenses = [lens.GetAttr('arg1'), lens.GetAttr('arg2')]
|
||||
elif isinstance(phi, NaryOpSTL):
|
||||
child_lenses = [
|
||||
lens.GetAttr('args')[j] for j, _ in enumerate(phi.args)
|
||||
]
|
||||
else:
|
||||
child_lenses = [lens.GetAttr('arg')]
|
||||
for l in child_lenses:
|
||||
yield from [l & cl for cl in _ast_lens(l.get()(phi), pred, focus_lens)]
|
||||
|
||||
|
||||
lineq_lens = fn.partial(ast_lens, pred=type_pred(LinEq), getter=True)
|
||||
AP_lens = fn.partial(ast_lens, pred=type_pred(stl.ast.AtomicPred), getter=True)
|
||||
and_or_lens = fn.partial(ast_lens, pred=type_pred(And, Or), getter=True)
|
||||
|
||||
|
||||
def terms_lens(phi: STL, bind: bool = True) -> Lens:
|
||||
return lineq_lens(phi, bind).Each().terms.Each()
|
||||
|
||||
|
||||
def param_lens(phi: STL, *, getter=False) -> Lens:
|
||||
def focus_lens(leaf):
|
||||
candidates = [lens.const] if isinstance(leaf, LinEq) else [
|
||||
lens.GetAttr('interval')[0],
|
||||
lens.GetAttr('interval')[1]
|
||||
]
|
||||
return (x for x in candidates if isinstance(x.get()(leaf), Param))
|
||||
|
||||
return ast_lens(
|
||||
phi, pred=type_pred(LinEq, F, G), focus_lens=focus_lens, getter=getter)
|
||||
|
||||
|
||||
def set_params(phi, val) -> STL:
|
||||
phi = param_lens(phi) if isinstance(phi, AST) else phi
|
||||
return phi.modify(lambda x: float(val.get(x, val.get(str(x), x))))
|
||||
from stl.ast import (And, F, G, Interval, LinEq, Neg, Or, AP_lens)
|
||||
from stl.types import STL
|
||||
|
||||
|
||||
def f_neg_or_canonical_form(phi: STL) -> STL:
|
||||
|
|
@ -140,12 +33,12 @@ def f_neg_or_canonical_form(phi: STL) -> STL:
|
|||
|
||||
|
||||
def _lineq_lipschitz(lineq):
|
||||
return sum(map(abs, lens(lineq).Each().terms.Each().coeff.collect()))
|
||||
return sum(map(abs, bind(lineq).Each().terms.Each().coeff.collect()))
|
||||
|
||||
|
||||
def linear_stl_lipschitz(phi):
|
||||
"""Infinity norm lipschitz bound of linear inequality predicate."""
|
||||
return float(max(map(_lineq_lipschitz, lineq_lens(phi).Each().collect())))
|
||||
return float(max(map(_lineq_lipschitz, phi.lineqs)))
|
||||
|
||||
|
||||
def inline_context(phi, context):
|
||||
|
|
@ -174,18 +67,17 @@ def get_times(x):
|
|||
return sorted(times)
|
||||
|
||||
|
||||
def eval_lineq(lineq, x, times=None, compact=True):
|
||||
if times is None:
|
||||
times = get_times(x)
|
||||
|
||||
def eval_lineq(lineq, x, compact=True):
|
||||
def eval_term(term, t):
|
||||
return float(term.coeff) * x[term.id.name][t]
|
||||
|
||||
output = traces.TimeSeries(domain=traces.Domain(times[0], times[-1]))
|
||||
terms = lens(lineq).Each().terms.Each().collect()
|
||||
for t in times:
|
||||
|
||||
def f(t):
|
||||
lhs = sum(eval_term(term, t) for term in terms)
|
||||
output[t] = op_lookup[lineq.op](lhs, lineq.const)
|
||||
return op_lookup[lineq.op](lhs, lineq.const)
|
||||
|
||||
output = traces.TimeSeries(map(f, x), domain=x.domain)
|
||||
|
||||
if compact:
|
||||
output.compact()
|
||||
|
|
@ -195,7 +87,7 @@ def eval_lineq(lineq, x, times=None, compact=True):
|
|||
def eval_lineqs(phi, x, times=None):
|
||||
if times is None:
|
||||
times = get_times(x)
|
||||
lineqs = set(lineq_lens(phi).Each().collect())
|
||||
lineqs = phi.lineqs
|
||||
return {lineq: eval_lineq(lineq, x, times=times) for lineq in lineqs}
|
||||
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue