make ast objects easier to work with

This commit is contained in:
Marcell Vazquez-Chanlatte 2017-10-31 18:29:31 -07:00
parent 2640728288
commit 7d8cf78222
5 changed files with 131 additions and 128 deletions

View file

@ -1,7 +1,6 @@
# flake8: noqa # flake8: noqa
from stl.utils import terms_lens, lineq_lens, walk, and_or_lens
from stl.utils import alw, env, andf, orf from stl.utils import alw, env, andf, orf
from stl.ast import dt_sym, t_sym, TOP, BOT from stl.ast import TOP, BOT
from stl.ast import (LinEq, Interval, NaryOpSTL, Or, And, F, G, ModalOp, Neg, from stl.ast import (LinEq, Interval, NaryOpSTL, Or, And, F, G, ModalOp, Neg,
Var, AtomicPred, Until) Var, AtomicPred, Until)
from stl.parser import parse from stl.parser import parse

View file

@ -1,13 +1,12 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# TODO: supress + given a + (-b). i.e. want a - b # TODO: supress + given a + (-b). i.e. want a - b
from collections import namedtuple from collections import deque, namedtuple
from functools import lru_cache
import funcy as fn import funcy as fn
from sympy import Symbol import lenses
from lenses import lens
dt_sym = Symbol('dt', positive=True)
t_sym = Symbol('t', positive=True)
def flatten_binary(phi, op, dropT, shortT): def flatten_binary(phi, op, dropT, shortT):
@ -42,6 +41,45 @@ class AST(object):
def children(self): def children(self):
return set() return set()
def walk(self):
"""Walk of the AST."""
pop = deque.pop
children = deque([self])
while len(children) > 0:
node = pop(children)
yield node
children.extend(node.children)
@property
def params(self):
def get_params(leaf):
if isinstance(leaf, ModalOp):
if isinstance(leaf.interval[0], Param):
yield leaf.interval[0]
if isinstance(leaf.interval[1], Param):
yield leaf.interval[1]
elif isinstance(leaf, LinEq):
if isinstance(leaf.const, Param):
yield leaf.const
return set(fn.mapcat(get_params, self.walk()))
def set_params(self, val):
phi = param_lens(self)
return phi.modify(lambda x: float(val.get(x, val.get(str(x), x))))
@property
def terms(self):
return set(terms_lens(self).Each().collect())
@property
def lineqs(self):
return set(lineq_lens(self).Each().collect())
@property
def atomic_predicates(self):
return set(AP_lens(self).Each().collect())
class _Top(AST): class _Top(AST):
__slots__ = () __slots__ = ()
@ -234,3 +272,72 @@ class Param(namedtuple('Param', ['name']), AST):
def __hash__(self): def __hash__(self):
# TODO: compute hash based on contents # TODO: compute hash based on contents
return hash(repr(self)) return hash(repr(self))
def ast_lens(phi, bind=True, *, pred=None, focus_lens=None, getter=False):
if focus_lens is None:
def focus_lens(_):
return [lens]
if pred is None:
def pred(_):
return False
child_lenses = _ast_lens(phi, pred=pred, focus_lens=focus_lens)
phi = lenses.bind(phi) if bind else lens
return (phi.Tuple if getter else phi.Fork)(*child_lenses)
def _ast_lens(phi, pred, focus_lens):
if pred(phi):
yield from focus_lens(phi)
if phi is None or not phi.children:
return
if phi is TOP or phi is BOT:
child_lenses = [lens]
elif isinstance(phi, Until):
child_lenses = [lens.GetAttr('arg1'), lens.GetAttr('arg2')]
elif isinstance(phi, NaryOpSTL):
child_lenses = [
lens.GetAttr('args')[j] for j, _ in enumerate(phi.args)
]
else:
child_lenses = [lens.GetAttr('arg')]
for l in child_lenses:
yield from [l & cl for cl in _ast_lens(l.get()(phi), pred, focus_lens)]
@lru_cache()
def param_lens(phi, *, getter=False):
def focus_lens(leaf):
candidates = [lens.const] if isinstance(leaf, LinEq) else [
lens.GetAttr('interval')[0],
lens.GetAttr('interval')[1]
]
return (x for x in candidates if isinstance(x.get()(leaf), Param))
return ast_lens(
phi, pred=type_pred(LinEq, F, G), focus_lens=focus_lens, getter=getter)
def vars_in_phi(phi):
focus = terms_lens(phi)
return set(focus.tuple_(lens.id, lens.time).get_all())
def type_pred(*args):
ast_types = set(args)
return lambda x: type(x) in ast_types
lineq_lens = fn.partial(ast_lens, pred=type_pred(LinEq), getter=True)
AP_lens = fn.partial(ast_lens, pred=type_pred(AtomicPred), getter=True)
and_or_lens = fn.partial(ast_lens, pred=type_pred(And, Or), getter=True)
def terms_lens(phi, bind=True):
return lineq_lens(phi, bind).Each().terms.Each()

View file

@ -13,7 +13,7 @@ oo = float('inf')
def pointwise_sat(phi): def pointwise_sat(phi):
ap_names = [z.id.name for z in stl.utils.AP_lens(phi).Each().collect()] ap_names = [z.id.name for z in stl.ast.AP_lens(phi).Each().collect()]
def _eval_stl(x, t): def _eval_stl(x, t):
evaluated = stl.utils.eval_lineqs(phi, x) evaluated = stl.utils.eval_lineqs(phi, x)

View file

@ -9,3 +9,8 @@ from stl.hypothesis import SignalTemporalLogicStrategy
def test_invertable_repr(phi): def test_invertable_repr(phi):
event(str(phi)) event(str(phi))
assert str(phi) == str(stl.parse(str(phi))) assert str(phi) == str(stl.parse(str(phi)))
@given(SignalTemporalLogicStrategy)
def test_hash_inheritance(phi):
assert hash(repr(phi)) == hash(phi)

View file

@ -1,119 +1,12 @@
import operator as op import operator as op
from collections import deque
from functools import reduce from functools import reduce
from typing import List, Mapping, Type, TypeVar
import funcy as fn
import traces import traces
from lenses import lens, bind
import lenses
import stl.ast import stl.ast
from lenses import lens from stl.ast import (And, F, G, Interval, LinEq, Neg, Or, AP_lens)
from stl.ast import (AST, And, F, G, Interval, LinEq, NaryOpSTL, Neg, Or, from stl.types import STL
Param, ModalOp)
from stl.types import STL, STL_Generator
Lens = TypeVar('Lens')
def walk(phi: STL) -> STL_Generator:
"""Walk of the AST."""
pop = deque.pop
children = deque([phi])
while len(children) > 0:
node = pop(children)
yield node
children.extend(node.children)
def list_params(phi: STL):
"""Walk of the AST."""
def get_params(leaf):
if isinstance(leaf, ModalOp):
if isinstance(leaf.interval[0], Param):
yield leaf.interval[0]
if isinstance(leaf.interval[1], Param):
yield leaf.interval[1]
elif isinstance(leaf, LinEq):
if isinstance(leaf.const, Param):
yield leaf.const
return set(fn.mapcat(get_params, walk(phi)))
def vars_in_phi(phi):
focus = stl.terms_lens(phi)
return set(focus.tuple_(lens.id, lens.time).get_all())
def type_pred(*args: List[Type]) -> Mapping[Type, bool]:
ast_types = set(args)
return lambda x: type(x) in ast_types
def ast_lens(phi: STL, bind=True, *, pred=None, focus_lens=None,
getter=False) -> Lens:
if focus_lens is None:
def focus_lens(_):
return [lens]
if pred is None:
def pred(_):
return False
child_lenses = _ast_lens(phi, pred=pred, focus_lens=focus_lens)
phi = lenses.bind(phi) if bind else lens
return (phi.Tuple if getter else phi.Fork)(*child_lenses)
def _ast_lens(phi: STL, pred, focus_lens) -> Lens:
if pred(phi):
yield from focus_lens(phi)
if phi is None or not phi.children:
return
if phi is stl.TOP or phi is stl.BOT:
child_lenses = [lens]
elif isinstance(phi, stl.ast.Until):
child_lenses = [lens.GetAttr('arg1'), lens.GetAttr('arg2')]
elif isinstance(phi, NaryOpSTL):
child_lenses = [
lens.GetAttr('args')[j] for j, _ in enumerate(phi.args)
]
else:
child_lenses = [lens.GetAttr('arg')]
for l in child_lenses:
yield from [l & cl for cl in _ast_lens(l.get()(phi), pred, focus_lens)]
lineq_lens = fn.partial(ast_lens, pred=type_pred(LinEq), getter=True)
AP_lens = fn.partial(ast_lens, pred=type_pred(stl.ast.AtomicPred), getter=True)
and_or_lens = fn.partial(ast_lens, pred=type_pred(And, Or), getter=True)
def terms_lens(phi: STL, bind: bool = True) -> Lens:
return lineq_lens(phi, bind).Each().terms.Each()
def param_lens(phi: STL, *, getter=False) -> Lens:
def focus_lens(leaf):
candidates = [lens.const] if isinstance(leaf, LinEq) else [
lens.GetAttr('interval')[0],
lens.GetAttr('interval')[1]
]
return (x for x in candidates if isinstance(x.get()(leaf), Param))
return ast_lens(
phi, pred=type_pred(LinEq, F, G), focus_lens=focus_lens, getter=getter)
def set_params(phi, val) -> STL:
phi = param_lens(phi) if isinstance(phi, AST) else phi
return phi.modify(lambda x: float(val.get(x, val.get(str(x), x))))
def f_neg_or_canonical_form(phi: STL) -> STL: def f_neg_or_canonical_form(phi: STL) -> STL:
@ -140,12 +33,12 @@ def f_neg_or_canonical_form(phi: STL) -> STL:
def _lineq_lipschitz(lineq): def _lineq_lipschitz(lineq):
return sum(map(abs, lens(lineq).Each().terms.Each().coeff.collect())) return sum(map(abs, bind(lineq).Each().terms.Each().coeff.collect()))
def linear_stl_lipschitz(phi): def linear_stl_lipschitz(phi):
"""Infinity norm lipschitz bound of linear inequality predicate.""" """Infinity norm lipschitz bound of linear inequality predicate."""
return float(max(map(_lineq_lipschitz, lineq_lens(phi).Each().collect()))) return float(max(map(_lineq_lipschitz, phi.lineqs)))
def inline_context(phi, context): def inline_context(phi, context):
@ -174,18 +67,17 @@ def get_times(x):
return sorted(times) return sorted(times)
def eval_lineq(lineq, x, times=None, compact=True): def eval_lineq(lineq, x, compact=True):
if times is None:
times = get_times(x)
def eval_term(term, t): def eval_term(term, t):
return float(term.coeff) * x[term.id.name][t] return float(term.coeff) * x[term.id.name][t]
output = traces.TimeSeries(domain=traces.Domain(times[0], times[-1]))
terms = lens(lineq).Each().terms.Each().collect() terms = lens(lineq).Each().terms.Each().collect()
for t in times:
def f(t):
lhs = sum(eval_term(term, t) for term in terms) lhs = sum(eval_term(term, t) for term in terms)
output[t] = op_lookup[lineq.op](lhs, lineq.const) return op_lookup[lineq.op](lhs, lineq.const)
output = traces.TimeSeries(map(f, x), domain=x.domain)
if compact: if compact:
output.compact() output.compact()
@ -195,7 +87,7 @@ def eval_lineq(lineq, x, times=None, compact=True):
def eval_lineqs(phi, x, times=None): def eval_lineqs(phi, x, times=None):
if times is None: if times is None:
times = get_times(x) times = get_times(x)
lineqs = set(lineq_lens(phi).Each().collect()) lineqs = phi.lineqs
return {lineq: eval_lineq(lineq, x, times=times) for lineq in lineqs} return {lineq: eval_lineq(lineq, x, times=times) for lineq in lineqs}