From d78037816b00d7f10c14719d4a175d7de6cdd110 Mon Sep 17 00:00:00 2001 From: Marcell Vazquez-Chanlatte Date: Sun, 15 Oct 2017 00:59:58 -0700 Subject: [PATCH] upgrade to new lens version --- requirements.txt | 4 +-- stl/ast.py | 40 ++++++++++++++---------- stl/boolean_eval.py | 2 +- stl/hypothesis.py | 25 ++++++++++++++- stl/utils.py | 76 +++++++++++++++++++++------------------------ 5 files changed, 86 insertions(+), 61 deletions(-) diff --git a/requirements.txt b/requirements.txt index 2fd2e4e..dbc7c12 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,7 +1,7 @@ -e git://github.com/mvcisback/hypothesis-cfg@master#egg=hypothesis-cfg bitarray==0.8.1 -funcy==1.7.2 -lenses==0.1.7 +funcy==1.9.1 +lenses==0.3.0 pandas==0.19.2 parsimonious==0.7.0 sympy==1.0 diff --git a/stl/ast.py b/stl/ast.py index 41d9001..4e0d3e4 100644 --- a/stl/ast.py +++ b/stl/ast.py @@ -1,5 +1,4 @@ # -*- coding: utf-8 -*- -# TODO: create iso lens between sugar and non-sugar # TODO: supress + given a + (-b). i.e. want a - b from collections import namedtuple, deque @@ -38,8 +37,9 @@ class AST(object): def __invert__(self): return Neg(self) + @property def children(self): - return [] + return set() class _Top(AST): @@ -70,9 +70,10 @@ class AtomicPred(namedtuple("AP", ["id"]), AST): def __repr__(self): return f"{self.id}" - + + @property def children(self): - return [] + return set() class LinEq(namedtuple("LinEquality", ["terms", "op", "const"]), AST): @@ -80,9 +81,10 @@ class LinEq(namedtuple("LinEquality", ["terms", "op", "const"]), AST): def __repr__(self): return " + ".join(map(str, self.terms)) + f" {self.op} {self.const}" - + + @property def children(self): - return [] + return set() def __hash__(self): # TODO: compute hash based on contents @@ -101,9 +103,10 @@ class Interval(namedtuple('I', ['lower', 'upper'])): def __repr__(self): return f"[{self.lower},{self.upper}]" - + + @property def children(self): - return [self.lower, self.upper] + return {self.lower, self.upper} class NaryOpSTL(namedtuple('NaryOp', ['args']), AST): @@ -112,9 +115,10 @@ class NaryOpSTL(namedtuple('NaryOp', ['args']), AST): OP = "?" def __repr__(self): return f" {self.OP} ".join(f"({x})" for x in self.args) - + + @property def children(self): - return self.args + return set(self.args) class Or(NaryOpSTL): @@ -141,8 +145,9 @@ class ModalOp(namedtuple('ModalOp', ['interval', 'arg']), AST): def __repr__(self): return f"{self.OP}{self.interval}({self.arg})" + @property def children(self): - return [self.arg] + return {self.arg} class F(ModalOp): @@ -168,8 +173,9 @@ class Until(namedtuple('ModalOp', ['arg1', 'arg2']), AST): def __repr__(self): return f"({self.arg1}) U ({self.arg2})" + @property def children(self): - return [self.arg1, self.arg2] + return {self.arg1, self.arg2} def __hash__(self): # TODO: compute hash based on contents @@ -181,9 +187,10 @@ class Neg(namedtuple('Neg', ['arg']), AST): def __repr__(self): return f"¬({self.arg})" - + + @property def children(self): - return [self.arg] + return {self.arg} def __hash__(self): # TODO: compute hash based on contents @@ -195,9 +202,10 @@ class Next(namedtuple('Next', ['arg']), AST): def __repr__(self): return f"X({self.arg})" - + + @property def children(self): - return [self.arg] + return {self.arg} def __hash__(self): # TODO: compute hash based on contents diff --git a/stl/boolean_eval.py b/stl/boolean_eval.py index d03dfb0..9a3d207 100644 --- a/stl/boolean_eval.py +++ b/stl/boolean_eval.py @@ -13,7 +13,7 @@ import stl oo = float('inf') def pointwise_sat(phi): - ap_names = [z.id.name for z in stl.utils.AP_lens(phi).get_all()] + ap_names = [z.id.name for z in stl.utils.AP_lens(phi).Each().collect()] def _eval_stl(x, t): evaluated = stl.utils.eval_lineqs(phi, x) evaluated.update(fn.project(x, ap_names)) diff --git a/stl/hypothesis.py b/stl/hypothesis.py index 86c0cb0..fb653da 100644 --- a/stl/hypothesis.py +++ b/stl/hypothesis.py @@ -1,7 +1,7 @@ from hypothesis_cfg import ContextFreeGrammarStrategy +import hypothesis.strategies as st from hypothesis.searchstrategy.strategies import SearchStrategy -from hypothesis.strategies import integers import stl @@ -15,11 +15,32 @@ GRAMMAR = { } + +def build_lineq(params): + pass + + +LinEqStrategy = st.builds( + lambda x: stl.ast.Lineq(*x), + st.tuples( + st.lists( + st.tuples( + st.sampled_from(["x", "y", "z","w"]), + st.integers(min_value=-5, max_value=5)), + min_size=1, max_size=4, unique=True), + st.sampled_from([">=", "<=", "<", ">", "="]), + st.integers(min_value=-5, max_value=5) +)) + + class SignalTemporalLogicStategy(SearchStrategy): def __init__(self, max_length: int): super(SearchStrategy, self).__init__() self.cfg_gen = ContextFreeGrammarStrategy( GRAMMAR, max_length=max_length, start='phi') + self.ap_gen = st.builds( + lambda i: stl.ast.AtomicPred(f"AP{i}"), + st.integers(min_value=0, max_value=max_length)) def do_draw(self, data): # TODO: randomly assign all intervals @@ -27,4 +48,6 @@ class SignalTemporalLogicStategy(SearchStrategy): # TODO: randomly generate boolean predicate # TODO: randomly generate linear predicate phi = stl.parse("".join(self.cfg_gen.do_draw(data))) + ap_lens = stl.utils.AP_lens(phi).Each() + phi = ap_lens.modify(lambda _: self.ap_gen.do_draw(data)) return phi diff --git a/stl/utils.py b/stl/utils.py index 1e4a235..fda3f22 100644 --- a/stl/utils.py +++ b/stl/utils.py @@ -1,9 +1,10 @@ -from typing import List, Type, Dict, Mapping, T +from typing import List, Type, Dict, Mapping, T, TypeVar from collections import deque import operator as op from functools import reduce -from lenses import lens, Lens +import lenses +from lenses import lens import funcy as fn import sympy import traces @@ -13,59 +14,52 @@ from stl.ast import (LinEq, And, Or, NaryOpSTL, F, G, Interval, Neg, AtomicPred) from stl.types import STL, STL_Generator, MTL +Lens = TypeVar('Lens') + def walk(phi:STL) -> STL_Generator: - """DSF walk of the AST.""" + """Walk of the AST.""" pop = deque.pop children = deque([phi]) while len(children) > 0: node = pop(children) yield node - children.extend(node.children()) + children.extend(node.children) def vars_in_phi(phi): focus = stl.terms_lens(phi) - return set(focus.tuple_(lens().id, lens().time).get_all()) + return set(focus.tuple_(lens.id, lens.time).get_all()) def type_pred(*args:List[Type]) -> Mapping[Type, bool]: ast_types = set(args) return lambda x: type(x) in ast_types -def _child_lens(psi:STL, focus:Lens) -> STL_Generator: - if psi is None: - return - elif psi is stl.TOP or psi is stl.BOT: - return - elif isinstance(psi, stl.ast.Until): - yield from [focus.arg1, focus.arg2] - elif isinstance(psi, NaryOpSTL): - for j, _ in enumerate(psi.args): - yield focus.args[j] - else: - yield focus.arg - - -def ast_lens(phi:STL, bind:bool=True, *, - pred:Mapping[T, bool], focus_lens:Lens=None) -> Lens: +def ast_lens(phi:STL, bind=True, *, pred=None, focus_lens=None) -> Lens: if focus_lens is None: - focus_lens = lambda x: [lens()] - tls = list(fn.flatten(_ast_lens(phi, pred=pred, focus_lens=focus_lens))) - tl = lens().tuple_(*tls).each_() - return tl.bind(phi) if bind else tl + focus_lens = lambda _: [lens] + if pred is None: + pred = lambda _: False + l = lenses.bind(phi) if bind else lens + return l.Tuple(*_ast_lens(phi, pred=pred, focus_lens=focus_lens)) +def _ast_lens(phi:STL, pred, focus_lens) -> Lens: + if pred(phi): + yield from focus_lens(phi) + + if phi is None or not phi.children: + return -def _ast_lens(phi, *, pred, focus=lens(), focus_lens): - psi = focus.get(state=phi) - ret_lens = [focus.add_lens(l) for l in focus_lens(psi)] if pred(psi) else [] - - if isinstance(psi, (LinEq, stl.ast.AtomicPred)): - return ret_lens - - child_lenses = list(_child_lens(psi, focus=focus)) - ret_lens += [_ast_lens(phi, pred=pred, focus=cl, focus_lens=focus_lens) - for cl in child_lenses] - return ret_lens - + if phi is stl.TOP or phi is stl.BOT: + child_lenses = [lens] + elif isinstance(phi, stl.ast.Until): + child_lenses = [lens.GetAttr('arg1'), lens.GetAttr('arg2')] + elif isinstance(phi, NaryOpSTL): + child_lenses = [lens.GetAttr('args')[j] for j, _ in enumerate(phi.args)] + else: + child_lenses = [lens.GetAttr('arg')] + for l in child_lenses: + yield from [l & cl for cl in _ast_lens(l.get()(phi), pred, focus_lens)] + lineq_lens = fn.partial(ast_lens, pred=type_pred(LinEq)) AP_lens = fn.partial(ast_lens, pred=type_pred(stl.ast.AtomicPred)) @@ -78,7 +72,7 @@ def terms_lens(phi:STL, bind:bool=True) -> Lens: def param_lens(phi:STL) -> Lens: is_sym = lambda x: isinstance(x, sympy.Symbol) def focus_lens(leaf): - return [lens().const] if isinstance(leaf, LinEq) else [lens().interval[0], lens().interval[1]] + return [lens.const] if isinstance(leaf, LinEq) else [lens().interval[0], lens().interval[1]] return ast_lens(phi, pred=type_pred(LinEq, F, G), focus_lens=focus_lens).filter_(is_sym) @@ -93,7 +87,7 @@ def f_neg_or_canonical_form(phi:STL) -> STL: if isinstance(phi, LinEq): return phi - children = [f_neg_or_canonical_form(s) for s in phi.children()] + children = [f_neg_or_canonical_form(s) for s in phi.children] if isinstance(phi, (And, G)): children = [Neg(s) for s in children] children = tuple(children) @@ -160,7 +154,7 @@ def eval_lineq(lineq, x, times=None, compact=True): return float(term.coeff)*x[term.id.name][t] output = traces.TimeSeries(domain=traces.Domain(times[0], times[-1])) - terms = lens(lineq).terms.each_().get_all() + terms = lens(lineq).Each().terms.Each().collect() for t in times: lhs = sum(eval_term(term, t) for term in terms) output[t] = op_lookup[lineq.op](lhs, lineq.const) @@ -172,7 +166,7 @@ def eval_lineq(lineq, x, times=None, compact=True): def eval_lineqs(phi, x, times=None): if times is None: times = get_times(x) - lineqs = set(lineq_lens(phi).get_all()) + lineqs = set(lineq_lens(phi).Each().collect()) return {lineq: eval_lineq(lineq, x, times=times) for lineq in lineqs}