start switch to traces from pandas

This commit is contained in:
Marcell Vazquez-Chanlatte 2017-04-23 12:42:57 -07:00
parent b28898820e
commit e477392664
6 changed files with 55 additions and 20 deletions

View file

@ -141,7 +141,7 @@ class G(ModalOp):
class Until(namedtuple('ModalOp', ['arg1', 'arg2']), AST):
def __repr__(self):
return f"({self.arg1} U ({self.arg2}))"
return f"({self.arg1}) U ({self.arg2})"
def children(self):
return [self.arg1, self.arg2]

View file

@ -5,10 +5,13 @@ from functools import singledispatch
import operator as op
import numpy as np
import funcy as fn
from lenses import lens
import stl.ast
oo = float('inf')
@singledispatch
def pointwise_sat(stl):
raise NotImplementedError
@ -27,8 +30,17 @@ def _(stl):
def get_times(x, tau, lo=None, hi=None):
indices = x.index if lo is None or hi is None else x[lo:hi].index
return [min(tau + t2, x.index[-1]) for t2 in indices]
if lo is None or lo is -oo:
lo = min(v.first()[0] for v in x.values())
if hi is None or hi is oo:
hi = max(v.last()[0] for v in x.values())
if lo > hi:
return []
elif hi == lo:
return [lo]
all_times = fn.cat(v.slice(lo, hi).items() for v in x.values())
return sorted(set(fn.pluck(0, all_times)))
@pointwise_sat.register(stl.Until)
@ -42,18 +54,26 @@ def _(stl):
return _until
def eval_unary_temporal_op(phi, always=True):
fold = all if always else any
lo, hi = phi.interval
if lo > hi:
retval = True if always else False
return lambda x, t: retval
if hi == lo:
return lambda x, t: f(x, t)
f = pointwise_sat(phi.arg)
return lambda x, t: fold(f(x, tau) for tau in get_times(x, t, lo, hi))
@pointwise_sat.register(stl.F)
def _(stl):
lo, hi = stl.interval
f = pointwise_sat(stl.arg)
return lambda x, t: any(f(x, tau) for tau in get_times(x, t, lo, hi))
def _(phi):
return eval_unary_temporal_op(phi, always=False)
@pointwise_sat.register(stl.G)
def _(stl):
lo, hi = stl.interval
f = pointwise_sat(stl.arg)
return lambda x, t: all(f(x, tau) for tau in get_times(x, t, lo, hi))
def _(phi):
return eval_unary_temporal_op(phi, always=True)
@pointwise_sat.register(stl.Neg)
@ -89,4 +109,4 @@ def eval_terms(lineq, x, t):
def eval_term(x, t):
# TODO(lift interpolation much higher)
return lambda term: term.coeff*np.interp(t, x.index, x[term.id.name])
return lambda term: term.coeff*x[term.id.name][t]

View file

@ -44,7 +44,7 @@ U = "U"
interval = "[" __ const_or_unbound __ "," __ const_or_unbound __ "]"
const_or_unbound = unbound / const
const_or_unbound = unbound / "inf" / const
lineq = terms _ op _ const_or_unbound
term = coeff? var
@ -84,8 +84,10 @@ class STLVisitor(NodeVisitor):
visit_paren_phi = partialmethod(children_getter, i=2)
def visit_interval(self, _, children):
_, _, left, _, _, _, right, _, _ = children
return ast.Interval(left[0], right[0])
_, _, (left,), _, _, _, (right,), _, _ = children
left = left if left != [] else float("inf")
right = right if right != [] else float("inf")
return ast.Interval(left, right)
def get_text(self, node, _):
return node.text

View file

@ -1,13 +1,15 @@
import stl
import stl.boolean_eval
import stl.fastboolean_eval
import pandas as pd
import traces
from nose2.tools import params
import unittest
from sympy import Symbol
ex1 = ("2*A > 3", False)
ex2 = ("F[0, 1](2*A > 3)", True)
ex2 = ("F(2*A > 3)", True)
ex2 = ("F[0, inf](2*A > 3)", True)
ex3 = ("F[1, 0](2*A > 3)", False)
ex4 = ("G[1, 0](2*A > 3)", True)
ex5 = ("(A < 0)", False)
@ -17,8 +19,11 @@ ex8 = ("G[0, 0.2](C)", False)
ex9 = ("(F[0, 0.2](C)) and (F[0, 1](2*A > 3))", True)
ex10 = ("(A = 1) U (A = 4)", True)
ex11 = ("(A < 5) U (A = 4)", False)
x = pd.DataFrame([[1,2, True], [1,4, True], [4,2, False]], index=[0,0.1,0.2],
columns=["A", "B", "C"])
x = {
"A": traces.TimeSeries([(0, 1), (0.1, 1), (0.2, 4)]),
"B": traces.TimeSeries([(0, 2), (0.1, 4), (0.2, 2)]),
"C": traces.TimeSeries([(0, True), (0.1, True), (0.2, False)]),
}
class TestSTLEval(unittest.TestCase):
@params(ex1, ex2, ex3, ex4, ex5, ex6, ex7, ex8, ex9, ex10, ex11)
@ -30,7 +35,7 @@ class TestSTLEval(unittest.TestCase):
self.assertEqual(stl_eval2(x, 0), not r)
"""
@params(ex1, ex2, ex3, ex4, ex5, ex6, ex7, ex8, ex9, ex10, ex11)
def test_fasteval(self, phi_str, _):
phi = stl.parse(phi_str)
@ -43,3 +48,4 @@ class TestSTLEval(unittest.TestCase):
b_fast2 = stl_evalf2(x, 0)
self.assertEqual(b_slow, b_fast)
self.assertEqual(b_fast, not b_fast2)
"""

View file

@ -27,8 +27,13 @@ ex4 = ('(x1 > 2) or ((x1 > 2) or (x1 > 2))',
ex5 = ('G[0, b?](x1 > a?)',
stl.G(i1_, ex1_[1]))
ex6 = ('◇[0,1](x1)', stl.F(i1, ex1__[1]))
ex7 = ('F[0, inf](x)', stl.parse("F(x)"))
class TestSTLParser(unittest.TestCase):
@params(ex1, ex2, ex3, ex4, ex5, ex6)
@params(ex1, ex2, ex3, ex4, ex5, ex6, ex7)
def test_stl(self, phi_str, phi):
self.assertEqual(stl.parse(phi_str), phi)
def test_smoke_test(self):
"""Previously broken parses"""
stl.parse("◇[0,inf]((1*Lane_ID(t) = 1.0) ∧ (◇[0.0,eps?]((◇[eps?,tau1?](¬(1*Lane_ID(t) = 1.0))) ∧ (□[0,tau1?]((1*Lane_ID(t) = 1.0) U (¬(1*Lane_ID(t) = 1.0)))))))")

View file

@ -35,6 +35,8 @@ def _child_lens(psi:STL, focus:Lens) -> STL_Generator:
return
elif psi is stl.TOP or psi is stl.BOT:
return
elif isinstance(psi, stl.ast.Until):
yield from [focus.arg1, focus.arg2]
elif isinstance(psi, NaryOpSTL):
for j, _ in enumerate(psi.args):
yield focus.args[j]