diff --git a/stl/__init__.py b/stl/__init__.py index 12f9229..61bd2ca 100644 --- a/stl/__init__.py +++ b/stl/__init__.py @@ -1,7 +1,6 @@ # flake8: noqa -from stl.utils import terms_lens, lineq_lens, walk, and_or_lens from stl.utils import alw, env, andf, orf -from stl.ast import dt_sym, t_sym, TOP, BOT +from stl.ast import TOP, BOT from stl.ast import (LinEq, Interval, NaryOpSTL, Or, And, F, G, ModalOp, Neg, Var, AtomicPred, Until) from stl.parser import parse diff --git a/stl/ast.py b/stl/ast.py index 68da66c..7054827 100644 --- a/stl/ast.py +++ b/stl/ast.py @@ -1,13 +1,12 @@ # -*- coding: utf-8 -*- # TODO: supress + given a + (-b). i.e. want a - b -from collections import namedtuple +from collections import deque, namedtuple +from functools import lru_cache import funcy as fn -from sympy import Symbol - -dt_sym = Symbol('dt', positive=True) -t_sym = Symbol('t', positive=True) +import lenses +from lenses import lens def flatten_binary(phi, op, dropT, shortT): @@ -42,6 +41,45 @@ class AST(object): def children(self): return set() + def walk(self): + """Walk of the AST.""" + pop = deque.pop + children = deque([self]) + while len(children) > 0: + node = pop(children) + yield node + children.extend(node.children) + + @property + def params(self): + def get_params(leaf): + if isinstance(leaf, ModalOp): + if isinstance(leaf.interval[0], Param): + yield leaf.interval[0] + if isinstance(leaf.interval[1], Param): + yield leaf.interval[1] + elif isinstance(leaf, LinEq): + if isinstance(leaf.const, Param): + yield leaf.const + + return set(fn.mapcat(get_params, self.walk())) + + def set_params(self, val): + phi = param_lens(self) + return phi.modify(lambda x: float(val.get(x, val.get(str(x), x)))) + + @property + def terms(self): + return set(terms_lens(self).Each().collect()) + + @property + def lineqs(self): + return set(lineq_lens(self).Each().collect()) + + @property + def atomic_predicates(self): + return set(AP_lens(self).Each().collect()) + class _Top(AST): __slots__ = () @@ -234,3 +272,72 @@ class Param(namedtuple('Param', ['name']), AST): def __hash__(self): # TODO: compute hash based on contents return hash(repr(self)) + + +def ast_lens(phi, bind=True, *, pred=None, focus_lens=None, getter=False): + if focus_lens is None: + + def focus_lens(_): + return [lens] + + if pred is None: + + def pred(_): + return False + + child_lenses = _ast_lens(phi, pred=pred, focus_lens=focus_lens) + phi = lenses.bind(phi) if bind else lens + return (phi.Tuple if getter else phi.Fork)(*child_lenses) + + +def _ast_lens(phi, pred, focus_lens): + if pred(phi): + yield from focus_lens(phi) + + if phi is None or not phi.children: + return + + if phi is TOP or phi is BOT: + child_lenses = [lens] + elif isinstance(phi, Until): + child_lenses = [lens.GetAttr('arg1'), lens.GetAttr('arg2')] + elif isinstance(phi, NaryOpSTL): + child_lenses = [ + lens.GetAttr('args')[j] for j, _ in enumerate(phi.args) + ] + else: + child_lenses = [lens.GetAttr('arg')] + for l in child_lenses: + yield from [l & cl for cl in _ast_lens(l.get()(phi), pred, focus_lens)] + + +@lru_cache() +def param_lens(phi, *, getter=False): + def focus_lens(leaf): + candidates = [lens.const] if isinstance(leaf, LinEq) else [ + lens.GetAttr('interval')[0], + lens.GetAttr('interval')[1] + ] + return (x for x in candidates if isinstance(x.get()(leaf), Param)) + + return ast_lens( + phi, pred=type_pred(LinEq, F, G), focus_lens=focus_lens, getter=getter) + + +def vars_in_phi(phi): + focus = terms_lens(phi) + return set(focus.tuple_(lens.id, lens.time).get_all()) + + +def type_pred(*args): + ast_types = set(args) + return lambda x: type(x) in ast_types + + +lineq_lens = fn.partial(ast_lens, pred=type_pred(LinEq), getter=True) +AP_lens = fn.partial(ast_lens, pred=type_pred(AtomicPred), getter=True) +and_or_lens = fn.partial(ast_lens, pred=type_pred(And, Or), getter=True) + + +def terms_lens(phi, bind=True): + return lineq_lens(phi, bind).Each().terms.Each() diff --git a/stl/boolean_eval.py b/stl/boolean_eval.py index 0ea707b..4d56522 100644 --- a/stl/boolean_eval.py +++ b/stl/boolean_eval.py @@ -13,7 +13,7 @@ oo = float('inf') def pointwise_sat(phi): - ap_names = [z.id.name for z in stl.utils.AP_lens(phi).Each().collect()] + ap_names = [z.id.name for z in stl.ast.AP_lens(phi).Each().collect()] def _eval_stl(x, t): evaluated = stl.utils.eval_lineqs(phi, x) diff --git a/stl/test_parser.py b/stl/test_parser.py index b66c317..0ce2d4c 100644 --- a/stl/test_parser.py +++ b/stl/test_parser.py @@ -9,3 +9,8 @@ from stl.hypothesis import SignalTemporalLogicStrategy def test_invertable_repr(phi): event(str(phi)) assert str(phi) == str(stl.parse(str(phi))) + + +@given(SignalTemporalLogicStrategy) +def test_hash_inheritance(phi): + assert hash(repr(phi)) == hash(phi) diff --git a/stl/utils.py b/stl/utils.py index fd9e863..2ff29cd 100644 --- a/stl/utils.py +++ b/stl/utils.py @@ -1,119 +1,12 @@ import operator as op -from collections import deque from functools import reduce -from typing import List, Mapping, Type, TypeVar -import funcy as fn import traces +from lenses import lens, bind -import lenses import stl.ast -from lenses import lens -from stl.ast import (AST, And, F, G, Interval, LinEq, NaryOpSTL, Neg, Or, - Param, ModalOp) -from stl.types import STL, STL_Generator - -Lens = TypeVar('Lens') - - -def walk(phi: STL) -> STL_Generator: - """Walk of the AST.""" - pop = deque.pop - children = deque([phi]) - while len(children) > 0: - node = pop(children) - yield node - children.extend(node.children) - - -def list_params(phi: STL): - """Walk of the AST.""" - - def get_params(leaf): - if isinstance(leaf, ModalOp): - if isinstance(leaf.interval[0], Param): - yield leaf.interval[0] - if isinstance(leaf.interval[1], Param): - yield leaf.interval[1] - elif isinstance(leaf, LinEq): - if isinstance(leaf.const, Param): - yield leaf.const - - return set(fn.mapcat(get_params, walk(phi))) - - -def vars_in_phi(phi): - focus = stl.terms_lens(phi) - return set(focus.tuple_(lens.id, lens.time).get_all()) - - -def type_pred(*args: List[Type]) -> Mapping[Type, bool]: - ast_types = set(args) - return lambda x: type(x) in ast_types - - -def ast_lens(phi: STL, bind=True, *, pred=None, focus_lens=None, - getter=False) -> Lens: - if focus_lens is None: - - def focus_lens(_): - return [lens] - - if pred is None: - - def pred(_): - return False - - child_lenses = _ast_lens(phi, pred=pred, focus_lens=focus_lens) - phi = lenses.bind(phi) if bind else lens - return (phi.Tuple if getter else phi.Fork)(*child_lenses) - - -def _ast_lens(phi: STL, pred, focus_lens) -> Lens: - if pred(phi): - yield from focus_lens(phi) - - if phi is None or not phi.children: - return - - if phi is stl.TOP or phi is stl.BOT: - child_lenses = [lens] - elif isinstance(phi, stl.ast.Until): - child_lenses = [lens.GetAttr('arg1'), lens.GetAttr('arg2')] - elif isinstance(phi, NaryOpSTL): - child_lenses = [ - lens.GetAttr('args')[j] for j, _ in enumerate(phi.args) - ] - else: - child_lenses = [lens.GetAttr('arg')] - for l in child_lenses: - yield from [l & cl for cl in _ast_lens(l.get()(phi), pred, focus_lens)] - - -lineq_lens = fn.partial(ast_lens, pred=type_pred(LinEq), getter=True) -AP_lens = fn.partial(ast_lens, pred=type_pred(stl.ast.AtomicPred), getter=True) -and_or_lens = fn.partial(ast_lens, pred=type_pred(And, Or), getter=True) - - -def terms_lens(phi: STL, bind: bool = True) -> Lens: - return lineq_lens(phi, bind).Each().terms.Each() - - -def param_lens(phi: STL, *, getter=False) -> Lens: - def focus_lens(leaf): - candidates = [lens.const] if isinstance(leaf, LinEq) else [ - lens.GetAttr('interval')[0], - lens.GetAttr('interval')[1] - ] - return (x for x in candidates if isinstance(x.get()(leaf), Param)) - - return ast_lens( - phi, pred=type_pred(LinEq, F, G), focus_lens=focus_lens, getter=getter) - - -def set_params(phi, val) -> STL: - phi = param_lens(phi) if isinstance(phi, AST) else phi - return phi.modify(lambda x: float(val.get(x, val.get(str(x), x)))) +from stl.ast import (And, F, G, Interval, LinEq, Neg, Or, AP_lens) +from stl.types import STL def f_neg_or_canonical_form(phi: STL) -> STL: @@ -140,12 +33,12 @@ def f_neg_or_canonical_form(phi: STL) -> STL: def _lineq_lipschitz(lineq): - return sum(map(abs, lens(lineq).Each().terms.Each().coeff.collect())) + return sum(map(abs, bind(lineq).Each().terms.Each().coeff.collect())) def linear_stl_lipschitz(phi): """Infinity norm lipschitz bound of linear inequality predicate.""" - return float(max(map(_lineq_lipschitz, lineq_lens(phi).Each().collect()))) + return float(max(map(_lineq_lipschitz, phi.lineqs))) def inline_context(phi, context): @@ -174,18 +67,17 @@ def get_times(x): return sorted(times) -def eval_lineq(lineq, x, times=None, compact=True): - if times is None: - times = get_times(x) - +def eval_lineq(lineq, x, compact=True): def eval_term(term, t): return float(term.coeff) * x[term.id.name][t] - output = traces.TimeSeries(domain=traces.Domain(times[0], times[-1])) terms = lens(lineq).Each().terms.Each().collect() - for t in times: + + def f(t): lhs = sum(eval_term(term, t) for term in terms) - output[t] = op_lookup[lineq.op](lhs, lineq.const) + return op_lookup[lineq.op](lhs, lineq.const) + + output = traces.TimeSeries(map(f, x), domain=x.domain) if compact: output.compact() @@ -195,7 +87,7 @@ def eval_lineq(lineq, x, times=None, compact=True): def eval_lineqs(phi, x, times=None): if times is None: times = get_times(x) - lineqs = set(lineq_lens(phi).Each().collect()) + lineqs = phi.lineqs return {lineq: eval_lineq(lineq, x, times=times) for lineq in lineqs}