payed off testing technical debt + bug fixes + traces based evaluator

This commit is contained in:
Marcell Vazquez-Chanlatte 2017-11-11 17:35:48 -08:00
parent 72639bc59f
commit cba8a83c8e
12 changed files with 302 additions and 172 deletions

View file

@ -35,6 +35,8 @@ class AST(object):
return flatten_binary(And((self, other)), And, TOP, BOT)
def __invert__(self):
if isinstance(self, Neg):
return self.arg
return Neg(self)
@property
@ -68,10 +70,6 @@ class AST(object):
phi = param_lens(self)
return phi.modify(lambda x: float(val.get(x, val.get(str(x), x))))
@property
def terms(self):
return set(terms_lens(self).Each().collect())
@property
def lineqs(self):
return set(lineq_lens(self).Each().collect())
@ -87,6 +85,10 @@ class _Top(AST):
def __repr__(self):
return ""
def __hash__(self):
# TODO: compute hash based on contents
return hash(repr(self))
def __invert__(self):
return BOT
@ -97,6 +99,10 @@ class _Bot(AST):
def __repr__(self):
return ""
def __hash__(self):
# TODO: compute hash based on contents
return hash(repr(self))
def __invert__(self):
return TOP
@ -111,6 +117,10 @@ class AtomicPred(namedtuple("AP", ["id"]), AST):
def __repr__(self):
return f"{self.id}"
def __hash__(self):
# TODO: compute hash based on contents
return hash(repr(self))
@property
def children(self):
return set()
@ -150,10 +160,6 @@ class Interval(namedtuple('I', ['lower', 'upper'])):
def __repr__(self):
return f"[{self.lower},{self.upper}]"
@property
def children(self):
return {self.lower, self.upper}
class NaryOpSTL(namedtuple('NaryOp', ['args']), AST):
__slots__ = ()
@ -274,17 +280,17 @@ class Param(namedtuple('Param', ['name']), AST):
return hash(repr(self))
def ast_lens(phi, bind=True, *, pred=None, focus_lens=None, getter=False):
def ast_lens(phi,
bind=True,
*,
pred=lambda _: False,
focus_lens=None,
getter=False):
if focus_lens is None:
def focus_lens(_):
return [lens]
if pred is None:
def pred(_):
return False
child_lenses = _ast_lens(phi, pred=pred, focus_lens=focus_lens)
phi = lenses.bind(phi) if bind else lens
return (phi.Tuple if getter else phi.Fork)(*child_lenses)
@ -297,9 +303,7 @@ def _ast_lens(phi, pred, focus_lens):
if phi is None or not phi.children:
return
if phi is TOP or phi is BOT:
child_lenses = [lens]
elif isinstance(phi, Until):
if isinstance(phi, Until):
child_lenses = [lens.GetAttr('arg1'), lens.GetAttr('arg2')]
elif isinstance(phi, NaryOpSTL):
child_lenses = [
@ -324,11 +328,6 @@ def param_lens(phi, *, getter=False):
phi, pred=type_pred(LinEq, F, G), focus_lens=focus_lens, getter=getter)
def vars_in_phi(phi):
focus = terms_lens(phi)
return set(focus.tuple_(lens.id, lens.time).get_all())
def type_pred(*args):
ast_types = set(args)
return lambda x: type(x) in ast_types
@ -337,7 +336,3 @@ def type_pred(*args):
lineq_lens = fn.partial(ast_lens, pred=type_pred(LinEq), getter=True)
AP_lens = fn.partial(ast_lens, pred=type_pred(AtomicPred), getter=True)
and_or_lens = fn.partial(ast_lens, pred=type_pred(And, Or), getter=True)
def terms_lens(phi, bind=True):
return lineq_lens(phi, bind).Each().terms.Each()

View file

@ -5,133 +5,135 @@ import operator as op
from functools import singledispatch
import funcy as fn
import traces
import stl
import stl.ast
from lenses import bind
oo = float('inf')
from stl.utils import const_trace, andf, orf
def pointwise_sat(phi):
TRUE_TRACE = const_trace(True)
FALSE_TRACE = const_trace(False)
def negate_trace(x):
return x.operation(TRUE_TRACE, op.xor)
def pointwise_sat(phi, dt=0.1):
ap_names = [z.id for z in phi.atomic_predicates]
def _eval_stl(x, t):
def _eval_stl(x, t, dt=0.1):
evaluated = stl.utils.eval_lineqs(phi, x)
evaluated.update(fn.project(x, ap_names))
return eval_stl(phi)(evaluated, t)
return bool(eval_stl(phi, dt)(evaluated)[t])
return _eval_stl
@singledispatch
def eval_stl(stl):
def eval_stl(phi, dt):
raise NotImplementedError
@eval_stl.register(stl.Or)
def eval_stl_or(phi):
fs = [eval_stl(arg) for arg in phi.args]
return lambda x, t: any(f(x, t) for f in fs)
def eval_stl_or(phi, dt):
fs = [eval_stl(arg, dt) for arg in phi.args]
def _eval(x):
out = orf(*(f(x) for f in fs))
out.compact()
return out
return _eval
@eval_stl.register(stl.And)
def eval_stl_and(stl):
fs = [eval_stl(arg) for arg in stl.args]
return lambda x, t: all(f(x, t) for f in fs)
def eval_stl_and(phi, dt):
fs = [eval_stl(arg, dt) for arg in phi.args]
def _eval(x):
out = andf(*(f(x) for f in fs))
out.compact()
return out
def get_times(x, tau, lo=None, hi=None):
domain = fn.first(x.values()).domain
if lo is None or lo is -oo:
lo = domain.start()
if hi is None or hi is oo:
hi = domain.end()
end = min(v.domain.end() for v in x.values())
hi = hi + tau if hi + tau <= end else end
lo = lo + tau if lo + tau <= end else end
if lo > hi:
return []
elif hi == lo:
return [lo]
all_times = fn.cat(v.slice(lo, hi).items() for v in x.values())
return sorted(set(fn.pluck(0, all_times)))
return _eval
@eval_stl.register(stl.Until)
def eval_stl_until(stl):
def _until(x, t):
f1, f2 = eval_stl(stl.arg1), eval_stl(stl.arg2)
for tau in get_times(x, t):
if not f1(x, tau):
return f2(x, tau)
return False
return _until
def eval_unary_temporal_op(phi, always=True):
fold = all if always else any
lo, hi = phi.interval
if lo > hi:
retval = True if always else False
return lambda x, t: retval
f = eval_stl(phi.arg)
if hi == lo:
return lambda x, t: f(x, t)
return lambda x, t: fold(f(x, tau) for tau in get_times(x, t, lo, hi))
def eval_stl_until(phi, dt):
raise NotImplementedError
@eval_stl.register(stl.F)
def eval_stl_f(phi):
return eval_unary_temporal_op(phi, always=False)
def eval_stl_f(phi, dt):
phi = ~stl.G(phi.interval, ~phi.arg)
return eval_stl(phi, dt)
@eval_stl.register(stl.G)
def eval_stl_g(phi):
return eval_unary_temporal_op(phi, always=True)
def eval_stl_g(phi, dt):
f = eval_stl(phi.arg, dt)
a, b = phi.interval
def process_intervals(x):
for (start, val), (end, val2) in x.iterintervals():
start2, end2 = start - b, end + a
if end2 > start2:
yield (start2, val)
def _eval(x):
y = f(x)
if len(y) <= 1:
return y
out = traces.TimeSeries(process_intervals(y))
out.compact()
return out
return _eval
@eval_stl.register(stl.Neg)
def eval_stl_neg(stl):
f = eval_stl(stl.arg)
return lambda x, t: not f(x, t)
def eval_stl_neg(phi, dt):
f = eval_stl(phi.arg, dt)
def _eval(x):
out = negate_trace(f(x))
out.compact()
return out
return _eval
op_lookup = {
">": op.gt,
">=": op.ge,
"<": op.lt,
"<=": op.le,
"=": op.eq,
}
@eval_stl.register(stl.ast.Next)
def eval_stl_next(phi, dt):
f = eval_stl(phi.arg, dt)
def _eval(x):
out = traces.TimeSeries((t + dt, v) for t, v in f(x))
out.compact()
return out
return _eval
@eval_stl.register(stl.AtomicPred)
def eval_stl_ap(stl):
return lambda x, t: x[str(stl.id)][t]
def eval_stl_ap(phi, _):
def _eval(x):
out = x[str(phi.id)]
out.compact()
return out
return _eval
@eval_stl.register(type(stl.TOP))
def eval_stl_top(_):
return lambda *_: True
def eval_stl_top(_, _1):
return lambda *_: TRUE_TRACE
@eval_stl.register(type(stl.BOT))
def eval_stl_bot(_):
return lambda *_: False
@eval_stl.register(stl.LinEq)
def eval_stl_lineq(lineq):
return lambda x, t: x[lineq][t]
def eval_terms(lineq, x, t):
terms = bind(lineq).terms.Each().collect()
return sum(eval_term(term, x, t) for term in terms)
def eval_term(term, x, t):
return float(term.coeff) * x[term.id][t]
def eval_stl_bot(_, _1):
return lambda *_: FALSE_TRACE

View file

@ -1,10 +1,49 @@
import operator as op
from functools import reduce, singledispatch
from operator import and_, or_
import funcy as fn
from bitarray import bitarray
from lenses import bind
import stl.ast
from stl.boolean_eval import eval_terms, get_times, op_lookup
oo = float('inf')
op_lookup = {
">": op.gt,
">=": op.ge,
"<": op.lt,
"<=": op.le,
"=": op.eq,
}
def eval_terms(lineq, x, t):
terms = bind(lineq).terms.Each().collect()
return sum(eval_term(term, x, t) for term in terms)
def eval_term(term, x, t):
return float(term.coeff) * x[term.id][t]
def get_times(x, tau, lo=None, hi=None):
domain = fn.first(x.values()).domain
if lo is None or lo is -oo:
lo = domain.start()
if hi is None or hi is oo:
hi = domain.end()
end = min(v.domain.end() for v in x.values())
hi = hi + tau if hi + tau <= end else end
lo = lo + tau if lo + tau <= end else end
if lo > hi:
return []
elif hi == lo:
return [lo]
all_times = fn.cat(v.slice(lo, hi).items() for v in x.values())
return sorted(set(fn.pluck(0, all_times)))
def pointwise_sat(stl):

View file

@ -1,13 +1,14 @@
from stl.fastboolean_eval import pointwise_sat
from stl import pointwise_sat
def featurize_trace(phi, x):
def ordered_evaluator(phi):
params = {ap.name for ap in phi.params}
order = tuple(params)
def vec_to_dict(theta):
return {k: v for k, v in zip(order, theta)}
def eval_phi(theta):
def eval_phi(theta, x):
return pointwise_sat(phi.set_params(vec_to_dict(theta)))(x, 0)
return eval_phi
return eval_phi, order

View file

@ -5,21 +5,18 @@ import stl
GRAMMAR = {
'phi': (('Unary', '(', 'phi', ')'), ('(', 'phi', ')', 'Binary', '(', 'phi',
')'), ('AP', ), ('LINEQ', )),
')'), ('AP', ), ('LINEQ', ), ('', ),
('', )),
'Unary': (('~', ), ('G', 'Interval'), ('F', 'Interval'), ('X', )),
'Interval': (('',), ('[1, 3]',)),
'Binary': ((' | ', ), (' & ', ), (' U ',)),
'Interval': (('', ), ('[1, 3]', )),
'Binary': ((' | ', ), (' & ', ), (' U ', ), (' -> ', ), (' <-> ',),
(' ^ ',)),
'AP': (('AP1', ), ('AP2', ), ('AP3', ), ('AP4', ), ('AP5', )),
'LINEQ': (('x > 4', ), ('y < 2', ), ('y >= 3', ), ('x + y >= 2',)),
'LINEQ': (('x > 4', ), ('y < 2', ), ('y >= 3', ), ('x + 2.0y >= 2', )),
}
def to_stl(term):
return stl.parse(''.join(term))
SignalTemporalLogicStrategy = st.builds(to_stl,
SignalTemporalLogicStrategy = st.builds(lambda term: stl.parse(''.join(term)),
ContextFreeGrammarStrategy(
GRAMMAR,
max_length=27,
max_length=35,
start='phi'))

View file

@ -9,11 +9,11 @@ from functools import partialmethod
from lenses import bind
from parsimonious import Grammar, NodeVisitor
from stl import ast
from stl.utils import alw, env, iff, implies, xor
from stl.utils import iff, implies, xor
STL_GRAMMAR = Grammar(u'''
phi = (timed_until / until / neg / next / g / f / lineq / AP / or / and
/ implies / xor / iff / paren_phi)
/ implies / xor / iff / paren_phi / bot / top)
paren_phi = "(" __ phi __ ")"
@ -45,11 +45,14 @@ terms = (term __ pm __ terms) / term
var = id
AP = ~r"[a-zA-z\d]+"
bot = ""
top = ""
pm = "+" / "-"
dt = "dt"
unbound = id "?"
id = ~r"[a-zA-z\d]+"
const = ~r"[-+]?\d*\.\d+|\d+"
const = ~r"[-+]?(\d*\.\d+|\d+)"
op = ">=" / "<=" / "<" / ">" / "="
_ = ~r"\s"+
__ = ~r"\s"*
@ -73,6 +76,12 @@ class STLVisitor(NodeVisitor):
visit_phi = partialmethod(children_getter, i=0)
visit_paren_phi = partialmethod(children_getter, i=2)
def visit_bot(self, *_):
return ast.BOT
def visit_top(self, *_):
return ast.TOP
def visit_interval(self, _, children):
_, _, (left, ), _, _, _, (right, ), _, _ = children
left = left if left != [] else float("inf")
@ -114,10 +123,6 @@ class STLVisitor(NodeVisitor):
phi1, _, _, _, phi2 = children
return ast.Until(phi1, phi2)
def visit_timed_until(self, _, children):
phi, _, _, (lo, hi), _, psi = children
return env(psi, lo=lo, hi=hi) & alw(ast.Until(phi, psi), lo=0, hi=lo)
def visit_id(self, name, _):
return name.text
@ -148,7 +153,7 @@ class STLVisitor(NodeVisitor):
return ast.AtomicPred(self.visit_id(*args))
def visit_neg(self, _, children):
return ast.Neg(children[1])
return ~children[1]
def visit_next(self, _, children):
return ast.Next(children[1])

View file

@ -1,22 +1,47 @@
import unittest
import stl
from stl.hypothesis import SignalTemporalLogicStrategy
from hypothesis import given
class TestSTLAST(unittest.TestCase):
def test_and(self):
phi = stl.parse("x")
self.assertEqual(stl.TOP, stl.TOP | phi)
self.assertEqual(stl.BOT, stl.BOT & phi)
self.assertEqual(stl.TOP, phi | stl.TOP)
self.assertEqual(stl.BOT, phi & stl.BOT)
self.assertEqual(phi, phi & stl.TOP)
self.assertEqual(phi, phi | stl.BOT)
self.assertEqual(stl.TOP, stl.TOP & stl.TOP)
self.assertEqual(stl.BOT, stl.BOT | stl.BOT)
self.assertEqual(stl.TOP, stl.TOP | stl.BOT)
self.assertEqual(stl.BOT, stl.TOP & stl.BOT)
self.assertEqual(~stl.BOT, stl.TOP)
self.assertEqual(~stl.TOP, stl.BOT)
self.assertEqual(~~stl.BOT, stl.BOT)
self.assertEqual(~~stl.TOP, stl.TOP)
@given(SignalTemporalLogicStrategy)
def test_identities(phi):
assert stl.TOP == stl.TOP | phi
assert stl.BOT == stl.BOT & phi
assert stl.TOP == phi | stl.TOP
assert stl.BOT == phi & stl.BOT
assert phi == phi & stl.TOP
assert phi == phi | stl.BOT
assert stl.TOP == stl.TOP & stl.TOP
assert stl.BOT == stl.BOT | stl.BOT
assert stl.TOP == stl.TOP | stl.BOT
assert stl.BOT == stl.TOP & stl.BOT
assert ~stl.BOT == stl.TOP
assert ~stl.TOP == stl.BOT
assert ~~stl.BOT == stl.BOT
assert ~~stl.TOP == stl.TOP
assert (phi & phi) & phi == phi & (phi & phi)
assert (phi | phi) | phi == phi | (phi | phi)
assert ~~phi == phi
def test_lineqs_unittest():
phi = stl.parse('(G[0, 1](x + y > a?)) & (F[1,2](z - x > 0))')
assert len(phi.lineqs) == 2
assert phi.lineqs == {stl.parse('x + y > a?'), stl.parse('z - x > 0')}
phi = stl.parse('(G[0, 1](x + y > a?)) U (F[1,2](z - x > 0))')
assert len(phi.lineqs) == 2
assert phi.lineqs == {stl.parse('x + y > a?'), stl.parse('z - x > 0')}
phi = stl.parse('G(⊥)')
assert phi.lineqs == set()
phi = stl.parse('F()')
assert phi.lineqs == set()
def test_walk():
phi = stl.parse(
'((G[0, 1](x + y > a?)) & (F[1,2](z - x > 0))) | ((X(AP1)) U (AP2))')
assert len(list((~phi).walk())) == 11

View file

@ -1,11 +1,12 @@
import hypothesis.strategies as st
import traces
from hypothesis import given
from hypothesis import given # , settings, Verbosity, Phase
import stl
import stl.boolean_eval
import stl.fastboolean_eval
# from stl.hypothesis import SignalTemporalLogicStrategy
"""
TODO: property based test that fasteval should be the same as slow
@ -33,7 +34,11 @@ x = {
}
@given(st.just(stl.BOT))
@given(st.just(stl.ast.Next(stl.BOT) | stl.ast.Next(stl.TOP)))
# @given(SignalTemporalLogicStrategy)
# @settings(max_shrinks=0, verbosity=Verbosity.verbose,
# perform_health_check=False,
# phases=[Phase.generate])
def test_boolean_identities(phi):
stl_eval = stl.boolean_eval.pointwise_sat(phi)
stl_eval2 = stl.boolean_eval.pointwise_sat(~phi)
@ -47,6 +52,24 @@ def test_boolean_identities(phi):
stl_eval6 = stl.boolean_eval.pointwise_sat(phi | ~phi)
assert stl_eval6(x, 0)
# phi2 = stl.alw(stl.ast.Next(phi))
# phi3 = stl.ast.Next(stl.alw(phi))
# stl_eval7 = stl.boolean_eval.pointwise_sat(phi2)
# stl_eval8 = stl.boolean_eval.pointwise_sat(phi3)
# assert stl_eval7(x, 0) == stl_eval8(x, 0)
stl_eval9 = stl.boolean_eval.pointwise_sat(stl.ast.Next(phi))
stl_eval10 = stl.boolean_eval.pointwise_sat(~stl.ast.Next(phi))
assert stl_eval9(x, 0) != stl_eval10(x, 0)
phi4 = stl.parse('~(AP4)')
stl_eval11 = stl.boolean_eval.pointwise_sat(phi4)
assert stl_eval11(x, 0)
phi5 = stl.parse('G[0.1, 0.03](~(AP4))')
stl_eval12 = stl.boolean_eval.pointwise_sat(phi5)
assert stl_eval12(x, 0)
@given(st.just(stl.BOT))
def test_temporal_identities(phi):

20
stl/test_load.py Normal file
View file

@ -0,0 +1,20 @@
import pandas as pd
from stl.load import from_pandas
DATA = pd.DataFrame(
data={
'AP1': [True, False, True],
'x': [0, 0, 0.1],
'y': [-1, -1, 0],
'z': [2, 3, 1],
},
index=[0, 1, 2],
)
def test_from_pandas():
x = from_pandas(DATA)
assert x['x'][0] == 0
assert x['x'][0.2] == 0
assert not x['AP1'][1.4]

14
stl/test_params.py Normal file
View file

@ -0,0 +1,14 @@
import stl
import hypothesis.strategies as st
from hypothesis import given
@given(st.integers(), st.integers(), st.integers())
def test_params1(a, b, c):
phi = stl.parse('G[a?, b?](x > c?)')
assert {x.name for x in phi.params} == {'a?', 'b?', 'c?'}
phi2 = phi.set_params({'a?': a, 'b?': b, 'c?': c})
assert phi2.params == set()
assert phi2 == stl.parse(f'G[{a}, {b}](x > {c})')

View file

@ -14,3 +14,9 @@ def test_invertable_repr(phi):
@given(SignalTemporalLogicStrategy)
def test_hash_inheritance(phi):
assert hash(repr(phi)) == hash(phi)
def test_sugar_smoke_test():
stl.parse('(x) <-> (x)')
stl.parse('(x) -> (x)')
stl.parse('(x) ^ (x)')

View file

@ -2,7 +2,7 @@ import operator as op
from functools import reduce
import traces
from lenses import lens, bind
from lenses import bind
import stl.ast
from stl.ast import (And, F, G, Interval, LinEq, Neg, Or, AP_lens)
@ -67,38 +67,37 @@ def get_times(x):
return sorted(times)
def eval_lineq(lineq, x, compact=True):
def eval_term(term, t):
return float(term.coeff) * x[term.id.name][t]
def const_trace(x):
oo = float('inf')
return traces.TimeSeries([(-oo, x)])
terms = lens(lineq).Each().terms.Each().collect()
def f(t):
lhs = sum(eval_term(term, t) for term in terms)
return op_lookup[lineq.op](lhs, lineq.const)
output = traces.TimeSeries(map(f, x), domain=x.domain)
def eval_lineq(lineq, x, domain, compact=True):
lhs = sum(const_trace(term.coeff)*x[term.id] for term in lineq.terms)
compare = op_lookup.get(lineq.op)
output = lhs.operation(const_trace(lineq.const), compare)
if compact:
output.compact()
return output
def eval_lineqs(phi, x, times=None):
if times is None:
times = get_times(x)
def eval_lineqs(phi, x):
lineqs = phi.lineqs
return {lineq: eval_lineq(lineq, x, times=times) for lineq in lineqs}
start = max(y.domain.start() for y in x.values())
end = min(y.domain.end() for y in x.values())
domain = traces.Domain(start, end)
return {lineq: eval_lineq(lineq, x, domain) for lineq in lineqs}
# EDSL
def alw(phi, *, lo, hi):
def alw(phi, *, lo=0, hi=float('inf')):
return G(Interval(lo, hi), phi)
def env(phi, *, lo, hi):
def env(phi, *, lo=0, hi=float('inf')):
return F(Interval(lo, hi), phi)
@ -124,3 +123,7 @@ def xor(x, y):
def iff(x, y):
return (x & y) | (~x & ~y)
def timed_until(phi, psi, lo, hi):
return env(psi, lo=lo, hi=hi) & alw(stl.ast.Until(phi, psi), lo=0, hi=lo)