upgrade to new lens version

This commit is contained in:
Marcell Vazquez-Chanlatte 2017-10-15 00:59:58 -07:00
parent 212a8c195a
commit d78037816b
5 changed files with 86 additions and 61 deletions

View file

@ -1,7 +1,7 @@
-e git://github.com/mvcisback/hypothesis-cfg@master#egg=hypothesis-cfg
bitarray==0.8.1
funcy==1.7.2
lenses==0.1.7
funcy==1.9.1
lenses==0.3.0
pandas==0.19.2
parsimonious==0.7.0
sympy==1.0

View file

@ -1,5 +1,4 @@
# -*- coding: utf-8 -*-
# TODO: create iso lens between sugar and non-sugar
# TODO: supress + given a + (-b). i.e. want a - b
from collections import namedtuple, deque
@ -38,8 +37,9 @@ class AST(object):
def __invert__(self):
return Neg(self)
@property
def children(self):
return []
return set()
class _Top(AST):
@ -70,9 +70,10 @@ class AtomicPred(namedtuple("AP", ["id"]), AST):
def __repr__(self):
return f"{self.id}"
@property
def children(self):
return []
return set()
class LinEq(namedtuple("LinEquality", ["terms", "op", "const"]), AST):
@ -80,9 +81,10 @@ class LinEq(namedtuple("LinEquality", ["terms", "op", "const"]), AST):
def __repr__(self):
return " + ".join(map(str, self.terms)) + f" {self.op} {self.const}"
@property
def children(self):
return []
return set()
def __hash__(self):
# TODO: compute hash based on contents
@ -101,9 +103,10 @@ class Interval(namedtuple('I', ['lower', 'upper'])):
def __repr__(self):
return f"[{self.lower},{self.upper}]"
@property
def children(self):
return [self.lower, self.upper]
return {self.lower, self.upper}
class NaryOpSTL(namedtuple('NaryOp', ['args']), AST):
@ -112,9 +115,10 @@ class NaryOpSTL(namedtuple('NaryOp', ['args']), AST):
OP = "?"
def __repr__(self):
return f" {self.OP} ".join(f"({x})" for x in self.args)
@property
def children(self):
return self.args
return set(self.args)
class Or(NaryOpSTL):
@ -141,8 +145,9 @@ class ModalOp(namedtuple('ModalOp', ['interval', 'arg']), AST):
def __repr__(self):
return f"{self.OP}{self.interval}({self.arg})"
@property
def children(self):
return [self.arg]
return {self.arg}
class F(ModalOp):
@ -168,8 +173,9 @@ class Until(namedtuple('ModalOp', ['arg1', 'arg2']), AST):
def __repr__(self):
return f"({self.arg1}) U ({self.arg2})"
@property
def children(self):
return [self.arg1, self.arg2]
return {self.arg1, self.arg2}
def __hash__(self):
# TODO: compute hash based on contents
@ -181,9 +187,10 @@ class Neg(namedtuple('Neg', ['arg']), AST):
def __repr__(self):
return f"¬({self.arg})"
@property
def children(self):
return [self.arg]
return {self.arg}
def __hash__(self):
# TODO: compute hash based on contents
@ -195,9 +202,10 @@ class Next(namedtuple('Next', ['arg']), AST):
def __repr__(self):
return f"X({self.arg})"
@property
def children(self):
return [self.arg]
return {self.arg}
def __hash__(self):
# TODO: compute hash based on contents

View file

@ -13,7 +13,7 @@ import stl
oo = float('inf')
def pointwise_sat(phi):
ap_names = [z.id.name for z in stl.utils.AP_lens(phi).get_all()]
ap_names = [z.id.name for z in stl.utils.AP_lens(phi).Each().collect()]
def _eval_stl(x, t):
evaluated = stl.utils.eval_lineqs(phi, x)
evaluated.update(fn.project(x, ap_names))

View file

@ -1,7 +1,7 @@
from hypothesis_cfg import ContextFreeGrammarStrategy
import hypothesis.strategies as st
from hypothesis.searchstrategy.strategies import SearchStrategy
from hypothesis.strategies import integers
import stl
@ -15,11 +15,32 @@ GRAMMAR = {
}
def build_lineq(params):
pass
LinEqStrategy = st.builds(
lambda x: stl.ast.Lineq(*x),
st.tuples(
st.lists(
st.tuples(
st.sampled_from(["x", "y", "z","w"]),
st.integers(min_value=-5, max_value=5)),
min_size=1, max_size=4, unique=True),
st.sampled_from([">=", "<=", "<", ">", "="]),
st.integers(min_value=-5, max_value=5)
))
class SignalTemporalLogicStategy(SearchStrategy):
def __init__(self, max_length: int):
super(SearchStrategy, self).__init__()
self.cfg_gen = ContextFreeGrammarStrategy(
GRAMMAR, max_length=max_length, start='phi')
self.ap_gen = st.builds(
lambda i: stl.ast.AtomicPred(f"AP{i}"),
st.integers(min_value=0, max_value=max_length))
def do_draw(self, data):
# TODO: randomly assign all intervals
@ -27,4 +48,6 @@ class SignalTemporalLogicStategy(SearchStrategy):
# TODO: randomly generate boolean predicate
# TODO: randomly generate linear predicate
phi = stl.parse("".join(self.cfg_gen.do_draw(data)))
ap_lens = stl.utils.AP_lens(phi).Each()
phi = ap_lens.modify(lambda _: self.ap_gen.do_draw(data))
return phi

View file

@ -1,9 +1,10 @@
from typing import List, Type, Dict, Mapping, T
from typing import List, Type, Dict, Mapping, T, TypeVar
from collections import deque
import operator as op
from functools import reduce
from lenses import lens, Lens
import lenses
from lenses import lens
import funcy as fn
import sympy
import traces
@ -13,59 +14,52 @@ from stl.ast import (LinEq, And, Or, NaryOpSTL, F, G, Interval, Neg,
AtomicPred)
from stl.types import STL, STL_Generator, MTL
Lens = TypeVar('Lens')
def walk(phi:STL) -> STL_Generator:
"""DSF walk of the AST."""
"""Walk of the AST."""
pop = deque.pop
children = deque([phi])
while len(children) > 0:
node = pop(children)
yield node
children.extend(node.children())
children.extend(node.children)
def vars_in_phi(phi):
focus = stl.terms_lens(phi)
return set(focus.tuple_(lens().id, lens().time).get_all())
return set(focus.tuple_(lens.id, lens.time).get_all())
def type_pred(*args:List[Type]) -> Mapping[Type, bool]:
ast_types = set(args)
return lambda x: type(x) in ast_types
def _child_lens(psi:STL, focus:Lens) -> STL_Generator:
if psi is None:
return
elif psi is stl.TOP or psi is stl.BOT:
return
elif isinstance(psi, stl.ast.Until):
yield from [focus.arg1, focus.arg2]
elif isinstance(psi, NaryOpSTL):
for j, _ in enumerate(psi.args):
yield focus.args[j]
else:
yield focus.arg
def ast_lens(phi:STL, bind:bool=True, *,
pred:Mapping[T, bool], focus_lens:Lens=None) -> Lens:
def ast_lens(phi:STL, bind=True, *, pred=None, focus_lens=None) -> Lens:
if focus_lens is None:
focus_lens = lambda x: [lens()]
tls = list(fn.flatten(_ast_lens(phi, pred=pred, focus_lens=focus_lens)))
tl = lens().tuple_(*tls).each_()
return tl.bind(phi) if bind else tl
focus_lens = lambda _: [lens]
if pred is None:
pred = lambda _: False
l = lenses.bind(phi) if bind else lens
return l.Tuple(*_ast_lens(phi, pred=pred, focus_lens=focus_lens))
def _ast_lens(phi:STL, pred, focus_lens) -> Lens:
if pred(phi):
yield from focus_lens(phi)
if phi is None or not phi.children:
return
def _ast_lens(phi, *, pred, focus=lens(), focus_lens):
psi = focus.get(state=phi)
ret_lens = [focus.add_lens(l) for l in focus_lens(psi)] if pred(psi) else []
if isinstance(psi, (LinEq, stl.ast.AtomicPred)):
return ret_lens
child_lenses = list(_child_lens(psi, focus=focus))
ret_lens += [_ast_lens(phi, pred=pred, focus=cl, focus_lens=focus_lens)
for cl in child_lenses]
return ret_lens
if phi is stl.TOP or phi is stl.BOT:
child_lenses = [lens]
elif isinstance(phi, stl.ast.Until):
child_lenses = [lens.GetAttr('arg1'), lens.GetAttr('arg2')]
elif isinstance(phi, NaryOpSTL):
child_lenses = [lens.GetAttr('args')[j] for j, _ in enumerate(phi.args)]
else:
child_lenses = [lens.GetAttr('arg')]
for l in child_lenses:
yield from [l & cl for cl in _ast_lens(l.get()(phi), pred, focus_lens)]
lineq_lens = fn.partial(ast_lens, pred=type_pred(LinEq))
AP_lens = fn.partial(ast_lens, pred=type_pred(stl.ast.AtomicPred))
@ -78,7 +72,7 @@ def terms_lens(phi:STL, bind:bool=True) -> Lens:
def param_lens(phi:STL) -> Lens:
is_sym = lambda x: isinstance(x, sympy.Symbol)
def focus_lens(leaf):
return [lens().const] if isinstance(leaf, LinEq) else [lens().interval[0], lens().interval[1]]
return [lens.const] if isinstance(leaf, LinEq) else [lens().interval[0], lens().interval[1]]
return ast_lens(phi, pred=type_pred(LinEq, F, G),
focus_lens=focus_lens).filter_(is_sym)
@ -93,7 +87,7 @@ def f_neg_or_canonical_form(phi:STL) -> STL:
if isinstance(phi, LinEq):
return phi
children = [f_neg_or_canonical_form(s) for s in phi.children()]
children = [f_neg_or_canonical_form(s) for s in phi.children]
if isinstance(phi, (And, G)):
children = [Neg(s) for s in children]
children = tuple(children)
@ -160,7 +154,7 @@ def eval_lineq(lineq, x, times=None, compact=True):
return float(term.coeff)*x[term.id.name][t]
output = traces.TimeSeries(domain=traces.Domain(times[0], times[-1]))
terms = lens(lineq).terms.each_().get_all()
terms = lens(lineq).Each().terms.Each().collect()
for t in times:
lhs = sum(eval_term(term, t) for term in terms)
output[t] = op_lookup[lineq.op](lhs, lineq.const)
@ -172,7 +166,7 @@ def eval_lineq(lineq, x, times=None, compact=True):
def eval_lineqs(phi, x, times=None):
if times is None:
times = get_times(x)
lineqs = set(lineq_lens(phi).get_all())
lineqs = set(lineq_lens(phi).Each().collect())
return {lineq: eval_lineq(lineq, x, times=times) for lineq in lineqs}