payed off testing technical debt + bug fixes + traces based evaluator

This commit is contained in:
Marcell Vazquez-Chanlatte 2017-11-11 17:35:48 -08:00
parent 72639bc59f
commit cba8a83c8e
12 changed files with 302 additions and 172 deletions

View file

@ -35,6 +35,8 @@ class AST(object):
return flatten_binary(And((self, other)), And, TOP, BOT) return flatten_binary(And((self, other)), And, TOP, BOT)
def __invert__(self): def __invert__(self):
if isinstance(self, Neg):
return self.arg
return Neg(self) return Neg(self)
@property @property
@ -68,10 +70,6 @@ class AST(object):
phi = param_lens(self) phi = param_lens(self)
return phi.modify(lambda x: float(val.get(x, val.get(str(x), x)))) return phi.modify(lambda x: float(val.get(x, val.get(str(x), x))))
@property
def terms(self):
return set(terms_lens(self).Each().collect())
@property @property
def lineqs(self): def lineqs(self):
return set(lineq_lens(self).Each().collect()) return set(lineq_lens(self).Each().collect())
@ -87,6 +85,10 @@ class _Top(AST):
def __repr__(self): def __repr__(self):
return "" return ""
def __hash__(self):
# TODO: compute hash based on contents
return hash(repr(self))
def __invert__(self): def __invert__(self):
return BOT return BOT
@ -97,6 +99,10 @@ class _Bot(AST):
def __repr__(self): def __repr__(self):
return "" return ""
def __hash__(self):
# TODO: compute hash based on contents
return hash(repr(self))
def __invert__(self): def __invert__(self):
return TOP return TOP
@ -111,6 +117,10 @@ class AtomicPred(namedtuple("AP", ["id"]), AST):
def __repr__(self): def __repr__(self):
return f"{self.id}" return f"{self.id}"
def __hash__(self):
# TODO: compute hash based on contents
return hash(repr(self))
@property @property
def children(self): def children(self):
return set() return set()
@ -150,10 +160,6 @@ class Interval(namedtuple('I', ['lower', 'upper'])):
def __repr__(self): def __repr__(self):
return f"[{self.lower},{self.upper}]" return f"[{self.lower},{self.upper}]"
@property
def children(self):
return {self.lower, self.upper}
class NaryOpSTL(namedtuple('NaryOp', ['args']), AST): class NaryOpSTL(namedtuple('NaryOp', ['args']), AST):
__slots__ = () __slots__ = ()
@ -274,17 +280,17 @@ class Param(namedtuple('Param', ['name']), AST):
return hash(repr(self)) return hash(repr(self))
def ast_lens(phi, bind=True, *, pred=None, focus_lens=None, getter=False): def ast_lens(phi,
bind=True,
*,
pred=lambda _: False,
focus_lens=None,
getter=False):
if focus_lens is None: if focus_lens is None:
def focus_lens(_): def focus_lens(_):
return [lens] return [lens]
if pred is None:
def pred(_):
return False
child_lenses = _ast_lens(phi, pred=pred, focus_lens=focus_lens) child_lenses = _ast_lens(phi, pred=pred, focus_lens=focus_lens)
phi = lenses.bind(phi) if bind else lens phi = lenses.bind(phi) if bind else lens
return (phi.Tuple if getter else phi.Fork)(*child_lenses) return (phi.Tuple if getter else phi.Fork)(*child_lenses)
@ -297,9 +303,7 @@ def _ast_lens(phi, pred, focus_lens):
if phi is None or not phi.children: if phi is None or not phi.children:
return return
if phi is TOP or phi is BOT: if isinstance(phi, Until):
child_lenses = [lens]
elif isinstance(phi, Until):
child_lenses = [lens.GetAttr('arg1'), lens.GetAttr('arg2')] child_lenses = [lens.GetAttr('arg1'), lens.GetAttr('arg2')]
elif isinstance(phi, NaryOpSTL): elif isinstance(phi, NaryOpSTL):
child_lenses = [ child_lenses = [
@ -324,11 +328,6 @@ def param_lens(phi, *, getter=False):
phi, pred=type_pred(LinEq, F, G), focus_lens=focus_lens, getter=getter) phi, pred=type_pred(LinEq, F, G), focus_lens=focus_lens, getter=getter)
def vars_in_phi(phi):
focus = terms_lens(phi)
return set(focus.tuple_(lens.id, lens.time).get_all())
def type_pred(*args): def type_pred(*args):
ast_types = set(args) ast_types = set(args)
return lambda x: type(x) in ast_types return lambda x: type(x) in ast_types
@ -337,7 +336,3 @@ def type_pred(*args):
lineq_lens = fn.partial(ast_lens, pred=type_pred(LinEq), getter=True) lineq_lens = fn.partial(ast_lens, pred=type_pred(LinEq), getter=True)
AP_lens = fn.partial(ast_lens, pred=type_pred(AtomicPred), getter=True) AP_lens = fn.partial(ast_lens, pred=type_pred(AtomicPred), getter=True)
and_or_lens = fn.partial(ast_lens, pred=type_pred(And, Or), getter=True) and_or_lens = fn.partial(ast_lens, pred=type_pred(And, Or), getter=True)
def terms_lens(phi, bind=True):
return lineq_lens(phi, bind).Each().terms.Each()

View file

@ -5,133 +5,135 @@ import operator as op
from functools import singledispatch from functools import singledispatch
import funcy as fn import funcy as fn
import traces
import stl import stl
import stl.ast import stl.ast
from lenses import bind from stl.utils import const_trace, andf, orf
oo = float('inf')
def pointwise_sat(phi): TRUE_TRACE = const_trace(True)
FALSE_TRACE = const_trace(False)
def negate_trace(x):
return x.operation(TRUE_TRACE, op.xor)
def pointwise_sat(phi, dt=0.1):
ap_names = [z.id for z in phi.atomic_predicates] ap_names = [z.id for z in phi.atomic_predicates]
def _eval_stl(x, t): def _eval_stl(x, t, dt=0.1):
evaluated = stl.utils.eval_lineqs(phi, x) evaluated = stl.utils.eval_lineqs(phi, x)
evaluated.update(fn.project(x, ap_names)) evaluated.update(fn.project(x, ap_names))
return eval_stl(phi)(evaluated, t) return bool(eval_stl(phi, dt)(evaluated)[t])
return _eval_stl return _eval_stl
@singledispatch @singledispatch
def eval_stl(stl): def eval_stl(phi, dt):
raise NotImplementedError raise NotImplementedError
@eval_stl.register(stl.Or) @eval_stl.register(stl.Or)
def eval_stl_or(phi): def eval_stl_or(phi, dt):
fs = [eval_stl(arg) for arg in phi.args] fs = [eval_stl(arg, dt) for arg in phi.args]
return lambda x, t: any(f(x, t) for f in fs)
def _eval(x):
out = orf(*(f(x) for f in fs))
out.compact()
return out
return _eval
@eval_stl.register(stl.And) @eval_stl.register(stl.And)
def eval_stl_and(stl): def eval_stl_and(phi, dt):
fs = [eval_stl(arg) for arg in stl.args] fs = [eval_stl(arg, dt) for arg in phi.args]
return lambda x, t: all(f(x, t) for f in fs)
def _eval(x):
out = andf(*(f(x) for f in fs))
out.compact()
return out
def get_times(x, tau, lo=None, hi=None): return _eval
domain = fn.first(x.values()).domain
if lo is None or lo is -oo:
lo = domain.start()
if hi is None or hi is oo:
hi = domain.end()
end = min(v.domain.end() for v in x.values())
hi = hi + tau if hi + tau <= end else end
lo = lo + tau if lo + tau <= end else end
if lo > hi:
return []
elif hi == lo:
return [lo]
all_times = fn.cat(v.slice(lo, hi).items() for v in x.values())
return sorted(set(fn.pluck(0, all_times)))
@eval_stl.register(stl.Until) @eval_stl.register(stl.Until)
def eval_stl_until(stl): def eval_stl_until(phi, dt):
def _until(x, t): raise NotImplementedError
f1, f2 = eval_stl(stl.arg1), eval_stl(stl.arg2)
for tau in get_times(x, t):
if not f1(x, tau):
return f2(x, tau)
return False
return _until
def eval_unary_temporal_op(phi, always=True):
fold = all if always else any
lo, hi = phi.interval
if lo > hi:
retval = True if always else False
return lambda x, t: retval
f = eval_stl(phi.arg)
if hi == lo:
return lambda x, t: f(x, t)
return lambda x, t: fold(f(x, tau) for tau in get_times(x, t, lo, hi))
@eval_stl.register(stl.F) @eval_stl.register(stl.F)
def eval_stl_f(phi): def eval_stl_f(phi, dt):
return eval_unary_temporal_op(phi, always=False) phi = ~stl.G(phi.interval, ~phi.arg)
return eval_stl(phi, dt)
@eval_stl.register(stl.G) @eval_stl.register(stl.G)
def eval_stl_g(phi): def eval_stl_g(phi, dt):
return eval_unary_temporal_op(phi, always=True) f = eval_stl(phi.arg, dt)
a, b = phi.interval
def process_intervals(x):
for (start, val), (end, val2) in x.iterintervals():
start2, end2 = start - b, end + a
if end2 > start2:
yield (start2, val)
def _eval(x):
y = f(x)
if len(y) <= 1:
return y
out = traces.TimeSeries(process_intervals(y))
out.compact()
return out
return _eval
@eval_stl.register(stl.Neg) @eval_stl.register(stl.Neg)
def eval_stl_neg(stl): def eval_stl_neg(phi, dt):
f = eval_stl(stl.arg) f = eval_stl(phi.arg, dt)
return lambda x, t: not f(x, t)
def _eval(x):
out = negate_trace(f(x))
out.compact()
return out
return _eval
op_lookup = { @eval_stl.register(stl.ast.Next)
">": op.gt, def eval_stl_next(phi, dt):
">=": op.ge, f = eval_stl(phi.arg, dt)
"<": op.lt,
"<=": op.le, def _eval(x):
"=": op.eq, out = traces.TimeSeries((t + dt, v) for t, v in f(x))
} out.compact()
return out
return _eval
@eval_stl.register(stl.AtomicPred) @eval_stl.register(stl.AtomicPred)
def eval_stl_ap(stl): def eval_stl_ap(phi, _):
return lambda x, t: x[str(stl.id)][t] def _eval(x):
out = x[str(phi.id)]
out.compact()
return out
return _eval
@eval_stl.register(type(stl.TOP)) @eval_stl.register(type(stl.TOP))
def eval_stl_top(_): def eval_stl_top(_, _1):
return lambda *_: True return lambda *_: TRUE_TRACE
@eval_stl.register(type(stl.BOT)) @eval_stl.register(type(stl.BOT))
def eval_stl_bot(_): def eval_stl_bot(_, _1):
return lambda *_: False return lambda *_: FALSE_TRACE
@eval_stl.register(stl.LinEq)
def eval_stl_lineq(lineq):
return lambda x, t: x[lineq][t]
def eval_terms(lineq, x, t):
terms = bind(lineq).terms.Each().collect()
return sum(eval_term(term, x, t) for term in terms)
def eval_term(term, x, t):
return float(term.coeff) * x[term.id][t]

View file

@ -1,10 +1,49 @@
import operator as op
from functools import reduce, singledispatch from functools import reduce, singledispatch
from operator import and_, or_ from operator import and_, or_
import funcy as fn
from bitarray import bitarray from bitarray import bitarray
from lenses import bind
import stl.ast import stl.ast
from stl.boolean_eval import eval_terms, get_times, op_lookup
oo = float('inf')
op_lookup = {
">": op.gt,
">=": op.ge,
"<": op.lt,
"<=": op.le,
"=": op.eq,
}
def eval_terms(lineq, x, t):
terms = bind(lineq).terms.Each().collect()
return sum(eval_term(term, x, t) for term in terms)
def eval_term(term, x, t):
return float(term.coeff) * x[term.id][t]
def get_times(x, tau, lo=None, hi=None):
domain = fn.first(x.values()).domain
if lo is None or lo is -oo:
lo = domain.start()
if hi is None or hi is oo:
hi = domain.end()
end = min(v.domain.end() for v in x.values())
hi = hi + tau if hi + tau <= end else end
lo = lo + tau if lo + tau <= end else end
if lo > hi:
return []
elif hi == lo:
return [lo]
all_times = fn.cat(v.slice(lo, hi).items() for v in x.values())
return sorted(set(fn.pluck(0, all_times)))
def pointwise_sat(stl): def pointwise_sat(stl):

View file

@ -1,13 +1,14 @@
from stl.fastboolean_eval import pointwise_sat from stl import pointwise_sat
def featurize_trace(phi, x):
def ordered_evaluator(phi):
params = {ap.name for ap in phi.params} params = {ap.name for ap in phi.params}
order = tuple(params) order = tuple(params)
def vec_to_dict(theta): def vec_to_dict(theta):
return {k: v for k, v in zip(order, theta)} return {k: v for k, v in zip(order, theta)}
def eval_phi(theta): def eval_phi(theta, x):
return pointwise_sat(phi.set_params(vec_to_dict(theta)))(x, 0) return pointwise_sat(phi.set_params(vec_to_dict(theta)))(x, 0)
return eval_phi return eval_phi, order

View file

@ -5,21 +5,18 @@ import stl
GRAMMAR = { GRAMMAR = {
'phi': (('Unary', '(', 'phi', ')'), ('(', 'phi', ')', 'Binary', '(', 'phi', 'phi': (('Unary', '(', 'phi', ')'), ('(', 'phi', ')', 'Binary', '(', 'phi',
')'), ('AP', ), ('LINEQ', )), ')'), ('AP', ), ('LINEQ', ), ('', ),
('', )),
'Unary': (('~', ), ('G', 'Interval'), ('F', 'Interval'), ('X', )), 'Unary': (('~', ), ('G', 'Interval'), ('F', 'Interval'), ('X', )),
'Interval': (('', ), ('[1, 3]', )), 'Interval': (('', ), ('[1, 3]', )),
'Binary': ((' | ', ), (' & ', ), (' U ',)), 'Binary': ((' | ', ), (' & ', ), (' U ', ), (' -> ', ), (' <-> ',),
(' ^ ',)),
'AP': (('AP1', ), ('AP2', ), ('AP3', ), ('AP4', ), ('AP5', )), 'AP': (('AP1', ), ('AP2', ), ('AP3', ), ('AP4', ), ('AP5', )),
'LINEQ': (('x > 4', ), ('y < 2', ), ('y >= 3', ), ('x + y >= 2',)), 'LINEQ': (('x > 4', ), ('y < 2', ), ('y >= 3', ), ('x + 2.0y >= 2', )),
} }
SignalTemporalLogicStrategy = st.builds(lambda term: stl.parse(''.join(term)),
def to_stl(term):
return stl.parse(''.join(term))
SignalTemporalLogicStrategy = st.builds(to_stl,
ContextFreeGrammarStrategy( ContextFreeGrammarStrategy(
GRAMMAR, GRAMMAR,
max_length=27, max_length=35,
start='phi')) start='phi'))

View file

@ -9,11 +9,11 @@ from functools import partialmethod
from lenses import bind from lenses import bind
from parsimonious import Grammar, NodeVisitor from parsimonious import Grammar, NodeVisitor
from stl import ast from stl import ast
from stl.utils import alw, env, iff, implies, xor from stl.utils import iff, implies, xor
STL_GRAMMAR = Grammar(u''' STL_GRAMMAR = Grammar(u'''
phi = (timed_until / until / neg / next / g / f / lineq / AP / or / and phi = (timed_until / until / neg / next / g / f / lineq / AP / or / and
/ implies / xor / iff / paren_phi) / implies / xor / iff / paren_phi / bot / top)
paren_phi = "(" __ phi __ ")" paren_phi = "(" __ phi __ ")"
@ -45,11 +45,14 @@ terms = (term __ pm __ terms) / term
var = id var = id
AP = ~r"[a-zA-z\d]+" AP = ~r"[a-zA-z\d]+"
bot = ""
top = ""
pm = "+" / "-" pm = "+" / "-"
dt = "dt" dt = "dt"
unbound = id "?" unbound = id "?"
id = ~r"[a-zA-z\d]+" id = ~r"[a-zA-z\d]+"
const = ~r"[-+]?\d*\.\d+|\d+" const = ~r"[-+]?(\d*\.\d+|\d+)"
op = ">=" / "<=" / "<" / ">" / "=" op = ">=" / "<=" / "<" / ">" / "="
_ = ~r"\s"+ _ = ~r"\s"+
__ = ~r"\s"* __ = ~r"\s"*
@ -73,6 +76,12 @@ class STLVisitor(NodeVisitor):
visit_phi = partialmethod(children_getter, i=0) visit_phi = partialmethod(children_getter, i=0)
visit_paren_phi = partialmethod(children_getter, i=2) visit_paren_phi = partialmethod(children_getter, i=2)
def visit_bot(self, *_):
return ast.BOT
def visit_top(self, *_):
return ast.TOP
def visit_interval(self, _, children): def visit_interval(self, _, children):
_, _, (left, ), _, _, _, (right, ), _, _ = children _, _, (left, ), _, _, _, (right, ), _, _ = children
left = left if left != [] else float("inf") left = left if left != [] else float("inf")
@ -114,10 +123,6 @@ class STLVisitor(NodeVisitor):
phi1, _, _, _, phi2 = children phi1, _, _, _, phi2 = children
return ast.Until(phi1, phi2) return ast.Until(phi1, phi2)
def visit_timed_until(self, _, children):
phi, _, _, (lo, hi), _, psi = children
return env(psi, lo=lo, hi=hi) & alw(ast.Until(phi, psi), lo=0, hi=lo)
def visit_id(self, name, _): def visit_id(self, name, _):
return name.text return name.text
@ -148,7 +153,7 @@ class STLVisitor(NodeVisitor):
return ast.AtomicPred(self.visit_id(*args)) return ast.AtomicPred(self.visit_id(*args))
def visit_neg(self, _, children): def visit_neg(self, _, children):
return ast.Neg(children[1]) return ~children[1]
def visit_next(self, _, children): def visit_next(self, _, children):
return ast.Next(children[1]) return ast.Next(children[1])

View file

@ -1,22 +1,47 @@
import unittest
import stl import stl
from stl.hypothesis import SignalTemporalLogicStrategy
from hypothesis import given
class TestSTLAST(unittest.TestCase): @given(SignalTemporalLogicStrategy)
def test_and(self): def test_identities(phi):
phi = stl.parse("x") assert stl.TOP == stl.TOP | phi
self.assertEqual(stl.TOP, stl.TOP | phi) assert stl.BOT == stl.BOT & phi
self.assertEqual(stl.BOT, stl.BOT & phi) assert stl.TOP == phi | stl.TOP
self.assertEqual(stl.TOP, phi | stl.TOP) assert stl.BOT == phi & stl.BOT
self.assertEqual(stl.BOT, phi & stl.BOT) assert phi == phi & stl.TOP
self.assertEqual(phi, phi & stl.TOP) assert phi == phi | stl.BOT
self.assertEqual(phi, phi | stl.BOT) assert stl.TOP == stl.TOP & stl.TOP
self.assertEqual(stl.TOP, stl.TOP & stl.TOP) assert stl.BOT == stl.BOT | stl.BOT
self.assertEqual(stl.BOT, stl.BOT | stl.BOT) assert stl.TOP == stl.TOP | stl.BOT
self.assertEqual(stl.TOP, stl.TOP | stl.BOT) assert stl.BOT == stl.TOP & stl.BOT
self.assertEqual(stl.BOT, stl.TOP & stl.BOT) assert ~stl.BOT == stl.TOP
self.assertEqual(~stl.BOT, stl.TOP) assert ~stl.TOP == stl.BOT
self.assertEqual(~stl.TOP, stl.BOT) assert ~~stl.BOT == stl.BOT
self.assertEqual(~~stl.BOT, stl.BOT) assert ~~stl.TOP == stl.TOP
self.assertEqual(~~stl.TOP, stl.TOP) assert (phi & phi) & phi == phi & (phi & phi)
assert (phi | phi) | phi == phi | (phi | phi)
assert ~~phi == phi
def test_lineqs_unittest():
phi = stl.parse('(G[0, 1](x + y > a?)) & (F[1,2](z - x > 0))')
assert len(phi.lineqs) == 2
assert phi.lineqs == {stl.parse('x + y > a?'), stl.parse('z - x > 0')}
phi = stl.parse('(G[0, 1](x + y > a?)) U (F[1,2](z - x > 0))')
assert len(phi.lineqs) == 2
assert phi.lineqs == {stl.parse('x + y > a?'), stl.parse('z - x > 0')}
phi = stl.parse('G(⊥)')
assert phi.lineqs == set()
phi = stl.parse('F()')
assert phi.lineqs == set()
def test_walk():
phi = stl.parse(
'((G[0, 1](x + y > a?)) & (F[1,2](z - x > 0))) | ((X(AP1)) U (AP2))')
assert len(list((~phi).walk())) == 11

View file

@ -1,11 +1,12 @@
import hypothesis.strategies as st import hypothesis.strategies as st
import traces import traces
from hypothesis import given from hypothesis import given # , settings, Verbosity, Phase
import stl import stl
import stl.boolean_eval import stl.boolean_eval
import stl.fastboolean_eval import stl.fastboolean_eval
# from stl.hypothesis import SignalTemporalLogicStrategy
""" """
TODO: property based test that fasteval should be the same as slow TODO: property based test that fasteval should be the same as slow
@ -33,7 +34,11 @@ x = {
} }
@given(st.just(stl.BOT)) @given(st.just(stl.ast.Next(stl.BOT) | stl.ast.Next(stl.TOP)))
# @given(SignalTemporalLogicStrategy)
# @settings(max_shrinks=0, verbosity=Verbosity.verbose,
# perform_health_check=False,
# phases=[Phase.generate])
def test_boolean_identities(phi): def test_boolean_identities(phi):
stl_eval = stl.boolean_eval.pointwise_sat(phi) stl_eval = stl.boolean_eval.pointwise_sat(phi)
stl_eval2 = stl.boolean_eval.pointwise_sat(~phi) stl_eval2 = stl.boolean_eval.pointwise_sat(~phi)
@ -47,6 +52,24 @@ def test_boolean_identities(phi):
stl_eval6 = stl.boolean_eval.pointwise_sat(phi | ~phi) stl_eval6 = stl.boolean_eval.pointwise_sat(phi | ~phi)
assert stl_eval6(x, 0) assert stl_eval6(x, 0)
# phi2 = stl.alw(stl.ast.Next(phi))
# phi3 = stl.ast.Next(stl.alw(phi))
# stl_eval7 = stl.boolean_eval.pointwise_sat(phi2)
# stl_eval8 = stl.boolean_eval.pointwise_sat(phi3)
# assert stl_eval7(x, 0) == stl_eval8(x, 0)
stl_eval9 = stl.boolean_eval.pointwise_sat(stl.ast.Next(phi))
stl_eval10 = stl.boolean_eval.pointwise_sat(~stl.ast.Next(phi))
assert stl_eval9(x, 0) != stl_eval10(x, 0)
phi4 = stl.parse('~(AP4)')
stl_eval11 = stl.boolean_eval.pointwise_sat(phi4)
assert stl_eval11(x, 0)
phi5 = stl.parse('G[0.1, 0.03](~(AP4))')
stl_eval12 = stl.boolean_eval.pointwise_sat(phi5)
assert stl_eval12(x, 0)
@given(st.just(stl.BOT)) @given(st.just(stl.BOT))
def test_temporal_identities(phi): def test_temporal_identities(phi):

20
stl/test_load.py Normal file
View file

@ -0,0 +1,20 @@
import pandas as pd
from stl.load import from_pandas
DATA = pd.DataFrame(
data={
'AP1': [True, False, True],
'x': [0, 0, 0.1],
'y': [-1, -1, 0],
'z': [2, 3, 1],
},
index=[0, 1, 2],
)
def test_from_pandas():
x = from_pandas(DATA)
assert x['x'][0] == 0
assert x['x'][0.2] == 0
assert not x['AP1'][1.4]

14
stl/test_params.py Normal file
View file

@ -0,0 +1,14 @@
import stl
import hypothesis.strategies as st
from hypothesis import given
@given(st.integers(), st.integers(), st.integers())
def test_params1(a, b, c):
phi = stl.parse('G[a?, b?](x > c?)')
assert {x.name for x in phi.params} == {'a?', 'b?', 'c?'}
phi2 = phi.set_params({'a?': a, 'b?': b, 'c?': c})
assert phi2.params == set()
assert phi2 == stl.parse(f'G[{a}, {b}](x > {c})')

View file

@ -14,3 +14,9 @@ def test_invertable_repr(phi):
@given(SignalTemporalLogicStrategy) @given(SignalTemporalLogicStrategy)
def test_hash_inheritance(phi): def test_hash_inheritance(phi):
assert hash(repr(phi)) == hash(phi) assert hash(repr(phi)) == hash(phi)
def test_sugar_smoke_test():
stl.parse('(x) <-> (x)')
stl.parse('(x) -> (x)')
stl.parse('(x) ^ (x)')

View file

@ -2,7 +2,7 @@ import operator as op
from functools import reduce from functools import reduce
import traces import traces
from lenses import lens, bind from lenses import bind
import stl.ast import stl.ast
from stl.ast import (And, F, G, Interval, LinEq, Neg, Or, AP_lens) from stl.ast import (And, F, G, Interval, LinEq, Neg, Or, AP_lens)
@ -67,38 +67,37 @@ def get_times(x):
return sorted(times) return sorted(times)
def eval_lineq(lineq, x, compact=True): def const_trace(x):
def eval_term(term, t): oo = float('inf')
return float(term.coeff) * x[term.id.name][t] return traces.TimeSeries([(-oo, x)])
terms = lens(lineq).Each().terms.Each().collect()
def f(t): def eval_lineq(lineq, x, domain, compact=True):
lhs = sum(eval_term(term, t) for term in terms) lhs = sum(const_trace(term.coeff)*x[term.id] for term in lineq.terms)
return op_lookup[lineq.op](lhs, lineq.const) compare = op_lookup.get(lineq.op)
output = lhs.operation(const_trace(lineq.const), compare)
output = traces.TimeSeries(map(f, x), domain=x.domain)
if compact: if compact:
output.compact() output.compact()
return output return output
def eval_lineqs(phi, x, times=None): def eval_lineqs(phi, x):
if times is None:
times = get_times(x)
lineqs = phi.lineqs lineqs = phi.lineqs
return {lineq: eval_lineq(lineq, x, times=times) for lineq in lineqs} start = max(y.domain.start() for y in x.values())
end = min(y.domain.end() for y in x.values())
domain = traces.Domain(start, end)
return {lineq: eval_lineq(lineq, x, domain) for lineq in lineqs}
# EDSL # EDSL
def alw(phi, *, lo, hi): def alw(phi, *, lo=0, hi=float('inf')):
return G(Interval(lo, hi), phi) return G(Interval(lo, hi), phi)
def env(phi, *, lo, hi): def env(phi, *, lo=0, hi=float('inf')):
return F(Interval(lo, hi), phi) return F(Interval(lo, hi), phi)
@ -124,3 +123,7 @@ def xor(x, y):
def iff(x, y): def iff(x, y):
return (x & y) | (~x & ~y) return (x & y) | (~x & ~y)
def timed_until(phi, psi, lo, hi):
return env(psi, lo=lo, hi=hi) & alw(stl.ast.Until(phi, psi), lo=0, hi=lo)