upgrade to new lens version

This commit is contained in:
Marcell Vazquez-Chanlatte 2017-10-15 00:59:58 -07:00
parent 212a8c195a
commit d78037816b
5 changed files with 86 additions and 61 deletions

View file

@ -1,7 +1,7 @@
-e git://github.com/mvcisback/hypothesis-cfg@master#egg=hypothesis-cfg -e git://github.com/mvcisback/hypothesis-cfg@master#egg=hypothesis-cfg
bitarray==0.8.1 bitarray==0.8.1
funcy==1.7.2 funcy==1.9.1
lenses==0.1.7 lenses==0.3.0
pandas==0.19.2 pandas==0.19.2
parsimonious==0.7.0 parsimonious==0.7.0
sympy==1.0 sympy==1.0

View file

@ -1,5 +1,4 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# TODO: create iso lens between sugar and non-sugar
# TODO: supress + given a + (-b). i.e. want a - b # TODO: supress + given a + (-b). i.e. want a - b
from collections import namedtuple, deque from collections import namedtuple, deque
@ -38,8 +37,9 @@ class AST(object):
def __invert__(self): def __invert__(self):
return Neg(self) return Neg(self)
@property
def children(self): def children(self):
return [] return set()
class _Top(AST): class _Top(AST):
@ -71,8 +71,9 @@ class AtomicPred(namedtuple("AP", ["id"]), AST):
def __repr__(self): def __repr__(self):
return f"{self.id}" return f"{self.id}"
@property
def children(self): def children(self):
return [] return set()
class LinEq(namedtuple("LinEquality", ["terms", "op", "const"]), AST): class LinEq(namedtuple("LinEquality", ["terms", "op", "const"]), AST):
@ -81,8 +82,9 @@ class LinEq(namedtuple("LinEquality", ["terms", "op", "const"]), AST):
def __repr__(self): def __repr__(self):
return " + ".join(map(str, self.terms)) + f" {self.op} {self.const}" return " + ".join(map(str, self.terms)) + f" {self.op} {self.const}"
@property
def children(self): def children(self):
return [] return set()
def __hash__(self): def __hash__(self):
# TODO: compute hash based on contents # TODO: compute hash based on contents
@ -102,8 +104,9 @@ class Interval(namedtuple('I', ['lower', 'upper'])):
def __repr__(self): def __repr__(self):
return f"[{self.lower},{self.upper}]" return f"[{self.lower},{self.upper}]"
@property
def children(self): def children(self):
return [self.lower, self.upper] return {self.lower, self.upper}
class NaryOpSTL(namedtuple('NaryOp', ['args']), AST): class NaryOpSTL(namedtuple('NaryOp', ['args']), AST):
@ -113,8 +116,9 @@ class NaryOpSTL(namedtuple('NaryOp', ['args']), AST):
def __repr__(self): def __repr__(self):
return f" {self.OP} ".join(f"({x})" for x in self.args) return f" {self.OP} ".join(f"({x})" for x in self.args)
@property
def children(self): def children(self):
return self.args return set(self.args)
class Or(NaryOpSTL): class Or(NaryOpSTL):
@ -141,8 +145,9 @@ class ModalOp(namedtuple('ModalOp', ['interval', 'arg']), AST):
def __repr__(self): def __repr__(self):
return f"{self.OP}{self.interval}({self.arg})" return f"{self.OP}{self.interval}({self.arg})"
@property
def children(self): def children(self):
return [self.arg] return {self.arg}
class F(ModalOp): class F(ModalOp):
@ -168,8 +173,9 @@ class Until(namedtuple('ModalOp', ['arg1', 'arg2']), AST):
def __repr__(self): def __repr__(self):
return f"({self.arg1}) U ({self.arg2})" return f"({self.arg1}) U ({self.arg2})"
@property
def children(self): def children(self):
return [self.arg1, self.arg2] return {self.arg1, self.arg2}
def __hash__(self): def __hash__(self):
# TODO: compute hash based on contents # TODO: compute hash based on contents
@ -182,8 +188,9 @@ class Neg(namedtuple('Neg', ['arg']), AST):
def __repr__(self): def __repr__(self):
return f"¬({self.arg})" return f"¬({self.arg})"
@property
def children(self): def children(self):
return [self.arg] return {self.arg}
def __hash__(self): def __hash__(self):
# TODO: compute hash based on contents # TODO: compute hash based on contents
@ -196,8 +203,9 @@ class Next(namedtuple('Next', ['arg']), AST):
def __repr__(self): def __repr__(self):
return f"X({self.arg})" return f"X({self.arg})"
@property
def children(self): def children(self):
return [self.arg] return {self.arg}
def __hash__(self): def __hash__(self):
# TODO: compute hash based on contents # TODO: compute hash based on contents

View file

@ -13,7 +13,7 @@ import stl
oo = float('inf') oo = float('inf')
def pointwise_sat(phi): def pointwise_sat(phi):
ap_names = [z.id.name for z in stl.utils.AP_lens(phi).get_all()] ap_names = [z.id.name for z in stl.utils.AP_lens(phi).Each().collect()]
def _eval_stl(x, t): def _eval_stl(x, t):
evaluated = stl.utils.eval_lineqs(phi, x) evaluated = stl.utils.eval_lineqs(phi, x)
evaluated.update(fn.project(x, ap_names)) evaluated.update(fn.project(x, ap_names))

View file

@ -1,7 +1,7 @@
from hypothesis_cfg import ContextFreeGrammarStrategy from hypothesis_cfg import ContextFreeGrammarStrategy
import hypothesis.strategies as st
from hypothesis.searchstrategy.strategies import SearchStrategy from hypothesis.searchstrategy.strategies import SearchStrategy
from hypothesis.strategies import integers
import stl import stl
@ -15,11 +15,32 @@ GRAMMAR = {
} }
def build_lineq(params):
pass
LinEqStrategy = st.builds(
lambda x: stl.ast.Lineq(*x),
st.tuples(
st.lists(
st.tuples(
st.sampled_from(["x", "y", "z","w"]),
st.integers(min_value=-5, max_value=5)),
min_size=1, max_size=4, unique=True),
st.sampled_from([">=", "<=", "<", ">", "="]),
st.integers(min_value=-5, max_value=5)
))
class SignalTemporalLogicStategy(SearchStrategy): class SignalTemporalLogicStategy(SearchStrategy):
def __init__(self, max_length: int): def __init__(self, max_length: int):
super(SearchStrategy, self).__init__() super(SearchStrategy, self).__init__()
self.cfg_gen = ContextFreeGrammarStrategy( self.cfg_gen = ContextFreeGrammarStrategy(
GRAMMAR, max_length=max_length, start='phi') GRAMMAR, max_length=max_length, start='phi')
self.ap_gen = st.builds(
lambda i: stl.ast.AtomicPred(f"AP{i}"),
st.integers(min_value=0, max_value=max_length))
def do_draw(self, data): def do_draw(self, data):
# TODO: randomly assign all intervals # TODO: randomly assign all intervals
@ -27,4 +48,6 @@ class SignalTemporalLogicStategy(SearchStrategy):
# TODO: randomly generate boolean predicate # TODO: randomly generate boolean predicate
# TODO: randomly generate linear predicate # TODO: randomly generate linear predicate
phi = stl.parse("".join(self.cfg_gen.do_draw(data))) phi = stl.parse("".join(self.cfg_gen.do_draw(data)))
ap_lens = stl.utils.AP_lens(phi).Each()
phi = ap_lens.modify(lambda _: self.ap_gen.do_draw(data))
return phi return phi

View file

@ -1,9 +1,10 @@
from typing import List, Type, Dict, Mapping, T from typing import List, Type, Dict, Mapping, T, TypeVar
from collections import deque from collections import deque
import operator as op import operator as op
from functools import reduce from functools import reduce
from lenses import lens, Lens import lenses
from lenses import lens
import funcy as fn import funcy as fn
import sympy import sympy
import traces import traces
@ -13,58 +14,51 @@ from stl.ast import (LinEq, And, Or, NaryOpSTL, F, G, Interval, Neg,
AtomicPred) AtomicPred)
from stl.types import STL, STL_Generator, MTL from stl.types import STL, STL_Generator, MTL
Lens = TypeVar('Lens')
def walk(phi:STL) -> STL_Generator: def walk(phi:STL) -> STL_Generator:
"""DSF walk of the AST.""" """Walk of the AST."""
pop = deque.pop pop = deque.pop
children = deque([phi]) children = deque([phi])
while len(children) > 0: while len(children) > 0:
node = pop(children) node = pop(children)
yield node yield node
children.extend(node.children()) children.extend(node.children)
def vars_in_phi(phi): def vars_in_phi(phi):
focus = stl.terms_lens(phi) focus = stl.terms_lens(phi)
return set(focus.tuple_(lens().id, lens().time).get_all()) return set(focus.tuple_(lens.id, lens.time).get_all())
def type_pred(*args:List[Type]) -> Mapping[Type, bool]: def type_pred(*args:List[Type]) -> Mapping[Type, bool]:
ast_types = set(args) ast_types = set(args)
return lambda x: type(x) in ast_types return lambda x: type(x) in ast_types
def _child_lens(psi:STL, focus:Lens) -> STL_Generator: def ast_lens(phi:STL, bind=True, *, pred=None, focus_lens=None) -> Lens:
if psi is None:
return
elif psi is stl.TOP or psi is stl.BOT:
return
elif isinstance(psi, stl.ast.Until):
yield from [focus.arg1, focus.arg2]
elif isinstance(psi, NaryOpSTL):
for j, _ in enumerate(psi.args):
yield focus.args[j]
else:
yield focus.arg
def ast_lens(phi:STL, bind:bool=True, *,
pred:Mapping[T, bool], focus_lens:Lens=None) -> Lens:
if focus_lens is None: if focus_lens is None:
focus_lens = lambda x: [lens()] focus_lens = lambda _: [lens]
tls = list(fn.flatten(_ast_lens(phi, pred=pred, focus_lens=focus_lens))) if pred is None:
tl = lens().tuple_(*tls).each_() pred = lambda _: False
return tl.bind(phi) if bind else tl l = lenses.bind(phi) if bind else lens
return l.Tuple(*_ast_lens(phi, pred=pred, focus_lens=focus_lens))
def _ast_lens(phi:STL, pred, focus_lens) -> Lens:
if pred(phi):
yield from focus_lens(phi)
def _ast_lens(phi, *, pred, focus=lens(), focus_lens): if phi is None or not phi.children:
psi = focus.get(state=phi) return
ret_lens = [focus.add_lens(l) for l in focus_lens(psi)] if pred(psi) else []
if isinstance(psi, (LinEq, stl.ast.AtomicPred)): if phi is stl.TOP or phi is stl.BOT:
return ret_lens child_lenses = [lens]
elif isinstance(phi, stl.ast.Until):
child_lenses = list(_child_lens(psi, focus=focus)) child_lenses = [lens.GetAttr('arg1'), lens.GetAttr('arg2')]
ret_lens += [_ast_lens(phi, pred=pred, focus=cl, focus_lens=focus_lens) elif isinstance(phi, NaryOpSTL):
for cl in child_lenses] child_lenses = [lens.GetAttr('args')[j] for j, _ in enumerate(phi.args)]
return ret_lens else:
child_lenses = [lens.GetAttr('arg')]
for l in child_lenses:
yield from [l & cl for cl in _ast_lens(l.get()(phi), pred, focus_lens)]
lineq_lens = fn.partial(ast_lens, pred=type_pred(LinEq)) lineq_lens = fn.partial(ast_lens, pred=type_pred(LinEq))
@ -78,7 +72,7 @@ def terms_lens(phi:STL, bind:bool=True) -> Lens:
def param_lens(phi:STL) -> Lens: def param_lens(phi:STL) -> Lens:
is_sym = lambda x: isinstance(x, sympy.Symbol) is_sym = lambda x: isinstance(x, sympy.Symbol)
def focus_lens(leaf): def focus_lens(leaf):
return [lens().const] if isinstance(leaf, LinEq) else [lens().interval[0], lens().interval[1]] return [lens.const] if isinstance(leaf, LinEq) else [lens().interval[0], lens().interval[1]]
return ast_lens(phi, pred=type_pred(LinEq, F, G), return ast_lens(phi, pred=type_pred(LinEq, F, G),
focus_lens=focus_lens).filter_(is_sym) focus_lens=focus_lens).filter_(is_sym)
@ -93,7 +87,7 @@ def f_neg_or_canonical_form(phi:STL) -> STL:
if isinstance(phi, LinEq): if isinstance(phi, LinEq):
return phi return phi
children = [f_neg_or_canonical_form(s) for s in phi.children()] children = [f_neg_or_canonical_form(s) for s in phi.children]
if isinstance(phi, (And, G)): if isinstance(phi, (And, G)):
children = [Neg(s) for s in children] children = [Neg(s) for s in children]
children = tuple(children) children = tuple(children)
@ -160,7 +154,7 @@ def eval_lineq(lineq, x, times=None, compact=True):
return float(term.coeff)*x[term.id.name][t] return float(term.coeff)*x[term.id.name][t]
output = traces.TimeSeries(domain=traces.Domain(times[0], times[-1])) output = traces.TimeSeries(domain=traces.Domain(times[0], times[-1]))
terms = lens(lineq).terms.each_().get_all() terms = lens(lineq).Each().terms.Each().collect()
for t in times: for t in times:
lhs = sum(eval_term(term, t) for term in terms) lhs = sum(eval_term(term, t) for term in terms)
output[t] = op_lookup[lineq.op](lhs, lineq.const) output[t] = op_lookup[lineq.op](lhs, lineq.const)
@ -172,7 +166,7 @@ def eval_lineq(lineq, x, times=None, compact=True):
def eval_lineqs(phi, x, times=None): def eval_lineqs(phi, x, times=None):
if times is None: if times is None:
times = get_times(x) times = get_times(x)
lineqs = set(lineq_lens(phi).get_all()) lineqs = set(lineq_lens(phi).Each().collect())
return {lineq: eval_lineq(lineq, x, times=times) for lineq in lineqs} return {lineq: eval_lineq(lineq, x, times=times) for lineq in lineqs}