start switch to traces from pandas

This commit is contained in:
Marcell Vazquez-Chanlatte 2017-04-23 12:42:57 -07:00
parent b28898820e
commit e477392664
6 changed files with 55 additions and 20 deletions

View file

@ -141,7 +141,7 @@ class G(ModalOp):
class Until(namedtuple('ModalOp', ['arg1', 'arg2']), AST): class Until(namedtuple('ModalOp', ['arg1', 'arg2']), AST):
def __repr__(self): def __repr__(self):
return f"({self.arg1} U ({self.arg2}))" return f"({self.arg1}) U ({self.arg2})"
def children(self): def children(self):
return [self.arg1, self.arg2] return [self.arg1, self.arg2]

View file

@ -5,10 +5,13 @@ from functools import singledispatch
import operator as op import operator as op
import numpy as np import numpy as np
import funcy as fn
from lenses import lens from lenses import lens
import stl.ast import stl.ast
oo = float('inf')
@singledispatch @singledispatch
def pointwise_sat(stl): def pointwise_sat(stl):
raise NotImplementedError raise NotImplementedError
@ -27,8 +30,17 @@ def _(stl):
def get_times(x, tau, lo=None, hi=None): def get_times(x, tau, lo=None, hi=None):
indices = x.index if lo is None or hi is None else x[lo:hi].index if lo is None or lo is -oo:
return [min(tau + t2, x.index[-1]) for t2 in indices] lo = min(v.first()[0] for v in x.values())
if hi is None or hi is oo:
hi = max(v.last()[0] for v in x.values())
if lo > hi:
return []
elif hi == lo:
return [lo]
all_times = fn.cat(v.slice(lo, hi).items() for v in x.values())
return sorted(set(fn.pluck(0, all_times)))
@pointwise_sat.register(stl.Until) @pointwise_sat.register(stl.Until)
@ -42,18 +54,26 @@ def _(stl):
return _until return _until
def eval_unary_temporal_op(phi, always=True):
fold = all if always else any
lo, hi = phi.interval
if lo > hi:
retval = True if always else False
return lambda x, t: retval
if hi == lo:
return lambda x, t: f(x, t)
f = pointwise_sat(phi.arg)
return lambda x, t: fold(f(x, tau) for tau in get_times(x, t, lo, hi))
@pointwise_sat.register(stl.F) @pointwise_sat.register(stl.F)
def _(stl): def _(phi):
lo, hi = stl.interval return eval_unary_temporal_op(phi, always=False)
f = pointwise_sat(stl.arg)
return lambda x, t: any(f(x, tau) for tau in get_times(x, t, lo, hi))
@pointwise_sat.register(stl.G) @pointwise_sat.register(stl.G)
def _(stl): def _(phi):
lo, hi = stl.interval return eval_unary_temporal_op(phi, always=True)
f = pointwise_sat(stl.arg)
return lambda x, t: all(f(x, tau) for tau in get_times(x, t, lo, hi))
@pointwise_sat.register(stl.Neg) @pointwise_sat.register(stl.Neg)
@ -89,4 +109,4 @@ def eval_terms(lineq, x, t):
def eval_term(x, t): def eval_term(x, t):
# TODO(lift interpolation much higher) # TODO(lift interpolation much higher)
return lambda term: term.coeff*np.interp(t, x.index, x[term.id.name]) return lambda term: term.coeff*x[term.id.name][t]

View file

@ -44,7 +44,7 @@ U = "U"
interval = "[" __ const_or_unbound __ "," __ const_or_unbound __ "]" interval = "[" __ const_or_unbound __ "," __ const_or_unbound __ "]"
const_or_unbound = unbound / const const_or_unbound = unbound / "inf" / const
lineq = terms _ op _ const_or_unbound lineq = terms _ op _ const_or_unbound
term = coeff? var term = coeff? var
@ -84,8 +84,10 @@ class STLVisitor(NodeVisitor):
visit_paren_phi = partialmethod(children_getter, i=2) visit_paren_phi = partialmethod(children_getter, i=2)
def visit_interval(self, _, children): def visit_interval(self, _, children):
_, _, left, _, _, _, right, _, _ = children _, _, (left,), _, _, _, (right,), _, _ = children
return ast.Interval(left[0], right[0]) left = left if left != [] else float("inf")
right = right if right != [] else float("inf")
return ast.Interval(left, right)
def get_text(self, node, _): def get_text(self, node, _):
return node.text return node.text

View file

@ -1,13 +1,15 @@
import stl import stl
import stl.boolean_eval import stl.boolean_eval
import stl.fastboolean_eval import stl.fastboolean_eval
import pandas as pd import traces
from nose2.tools import params from nose2.tools import params
import unittest import unittest
from sympy import Symbol from sympy import Symbol
ex1 = ("2*A > 3", False) ex1 = ("2*A > 3", False)
ex2 = ("F[0, 1](2*A > 3)", True) ex2 = ("F[0, 1](2*A > 3)", True)
ex2 = ("F(2*A > 3)", True)
ex2 = ("F[0, inf](2*A > 3)", True)
ex3 = ("F[1, 0](2*A > 3)", False) ex3 = ("F[1, 0](2*A > 3)", False)
ex4 = ("G[1, 0](2*A > 3)", True) ex4 = ("G[1, 0](2*A > 3)", True)
ex5 = ("(A < 0)", False) ex5 = ("(A < 0)", False)
@ -17,8 +19,11 @@ ex8 = ("G[0, 0.2](C)", False)
ex9 = ("(F[0, 0.2](C)) and (F[0, 1](2*A > 3))", True) ex9 = ("(F[0, 0.2](C)) and (F[0, 1](2*A > 3))", True)
ex10 = ("(A = 1) U (A = 4)", True) ex10 = ("(A = 1) U (A = 4)", True)
ex11 = ("(A < 5) U (A = 4)", False) ex11 = ("(A < 5) U (A = 4)", False)
x = pd.DataFrame([[1,2, True], [1,4, True], [4,2, False]], index=[0,0.1,0.2], x = {
columns=["A", "B", "C"]) "A": traces.TimeSeries([(0, 1), (0.1, 1), (0.2, 4)]),
"B": traces.TimeSeries([(0, 2), (0.1, 4), (0.2, 2)]),
"C": traces.TimeSeries([(0, True), (0.1, True), (0.2, False)]),
}
class TestSTLEval(unittest.TestCase): class TestSTLEval(unittest.TestCase):
@params(ex1, ex2, ex3, ex4, ex5, ex6, ex7, ex8, ex9, ex10, ex11) @params(ex1, ex2, ex3, ex4, ex5, ex6, ex7, ex8, ex9, ex10, ex11)
@ -30,7 +35,7 @@ class TestSTLEval(unittest.TestCase):
self.assertEqual(stl_eval2(x, 0), not r) self.assertEqual(stl_eval2(x, 0), not r)
"""
@params(ex1, ex2, ex3, ex4, ex5, ex6, ex7, ex8, ex9, ex10, ex11) @params(ex1, ex2, ex3, ex4, ex5, ex6, ex7, ex8, ex9, ex10, ex11)
def test_fasteval(self, phi_str, _): def test_fasteval(self, phi_str, _):
phi = stl.parse(phi_str) phi = stl.parse(phi_str)
@ -43,3 +48,4 @@ class TestSTLEval(unittest.TestCase):
b_fast2 = stl_evalf2(x, 0) b_fast2 = stl_evalf2(x, 0)
self.assertEqual(b_slow, b_fast) self.assertEqual(b_slow, b_fast)
self.assertEqual(b_fast, not b_fast2) self.assertEqual(b_fast, not b_fast2)
"""

View file

@ -27,8 +27,13 @@ ex4 = ('(x1 > 2) or ((x1 > 2) or (x1 > 2))',
ex5 = ('G[0, b?](x1 > a?)', ex5 = ('G[0, b?](x1 > a?)',
stl.G(i1_, ex1_[1])) stl.G(i1_, ex1_[1]))
ex6 = ('◇[0,1](x1)', stl.F(i1, ex1__[1])) ex6 = ('◇[0,1](x1)', stl.F(i1, ex1__[1]))
ex7 = ('F[0, inf](x)', stl.parse("F(x)"))
class TestSTLParser(unittest.TestCase): class TestSTLParser(unittest.TestCase):
@params(ex1, ex2, ex3, ex4, ex5, ex6) @params(ex1, ex2, ex3, ex4, ex5, ex6, ex7)
def test_stl(self, phi_str, phi): def test_stl(self, phi_str, phi):
self.assertEqual(stl.parse(phi_str), phi) self.assertEqual(stl.parse(phi_str), phi)
def test_smoke_test(self):
"""Previously broken parses"""
stl.parse("◇[0,inf]((1*Lane_ID(t) = 1.0) ∧ (◇[0.0,eps?]((◇[eps?,tau1?](¬(1*Lane_ID(t) = 1.0))) ∧ (□[0,tau1?]((1*Lane_ID(t) = 1.0) U (¬(1*Lane_ID(t) = 1.0)))))))")

View file

@ -35,6 +35,8 @@ def _child_lens(psi:STL, focus:Lens) -> STL_Generator:
return return
elif psi is stl.TOP or psi is stl.BOT: elif psi is stl.TOP or psi is stl.BOT:
return return
elif isinstance(psi, stl.ast.Until):
yield from [focus.arg1, focus.arg2]
elif isinstance(psi, NaryOpSTL): elif isinstance(psi, NaryOpSTL):
for j, _ in enumerate(psi.args): for j, _ in enumerate(psi.args):
yield focus.args[j] yield focus.args[j]