yapf + pylint + add style checks to tests

This commit is contained in:
Marcell Vazquez-Chanlatte 2017-10-26 22:00:03 -07:00
parent d52fffe826
commit a1ca4c6579
18 changed files with 130 additions and 299 deletions

4
.gitignore vendored
View file

@ -1,3 +1,5 @@
dist/* dist/*
.cache/* .cache/*
.hypothesis/* .hypothesis/*
.coverage*
htmlcov/*

View file

@ -2,4 +2,4 @@
python: python:
- "3.6" - "3.6"
install: "pip install -r requirements.txt" install: "pip install -r requirements.txt"
script: pytest script: pytest --flake8 --cov=stl -x --cov-report=html --isort

View file

@ -1,4 +1,5 @@
-e git://github.com/mvcisback/hypothesis-cfg@master#egg=hypothesis-cfg -e git://github.com/mvcisback/hypothesis-cfg@master#egg=hypothesis-cfg
-e git://github.com/mvcisback/multidim-threshold@master#egg=multidim-threshold
bitarray==0.8.1 bitarray==0.8.1
funcy==1.9.1 funcy==1.9.1
lenses==0.3.0 lenses==0.3.0
@ -7,4 +8,16 @@ parsimonious==0.7.0
sympy==1.0 sympy==1.0
traces==0.3.1 traces==0.3.1
pytest==3.2.2 pytest==3.2.2
hypothesis==3.30.3 hypothesis==3.32.1
pytest==3.2.3
pytest-bpdb==0.1.4
pytest-cache==1.0
pytest-colordots==1.1
pytest-cov==2.5.1
pytest-flake8==0.9
pytest-forked==0.2
pytest-isort==0.1.0
pytest-pep8==1.0.6
pytest-xdist==1.20.1
python-dateutil==2.6.1

View file

@ -1,19 +1,20 @@
from setuptools import setup, find_packages from setuptools import find_packages, setup
setup(name='py-stl', setup(
version='0.2', name='py-stl',
description='TODO', version='0.2',
url='http://github.com/mvcisback/py-stl', description='TODO',
author='Marcell Vazquez-Chanlatte', url='http://github.com/mvcisback/py-stl',
author_email='marcell.vc@eecs.berkeley.edu', author='Marcell Vazquez-Chanlatte',
license='MIT', author_email='marcell.vc@eecs.berkeley.edu',
install_requires=[ license='MIT',
'funcy', install_requires=[
'parsimonious', 'funcy',
'lenses', 'parsimonious',
'sympy', 'lenses',
'bitarray', 'sympy',
'traces', 'bitarray',
], 'traces',
packages=find_packages(), ],
packages=find_packages(),
) )

View file

@ -1,9 +1,9 @@
# flake8: noqa
from stl.utils import terms_lens, lineq_lens, walk, and_or_lens from stl.utils import terms_lens, lineq_lens, walk, and_or_lens
from stl.utils import alw, env, andf, orf from stl.utils import alw, env, andf, orf
from stl.ast import dt_sym, t_sym, TOP, BOT from stl.ast import dt_sym, t_sym, TOP, BOT
from stl.ast import (LinEq, Interval, NaryOpSTL, Or, And, F, G, from stl.ast import (LinEq, Interval, NaryOpSTL, Or, And, F, G, ModalOp, Neg,
ModalOp, Neg, Var, AtomicPred, Until) Var, AtomicPred, Until)
from stl.parser import parse from stl.parser import parse
from stl.fastboolean_eval import pointwise_sat from stl.fastboolean_eval import pointwise_sat
from stl.synth import lex_param_project
from stl.types import STL from stl.types import STL

View file

@ -1,9 +1,7 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# TODO: supress + given a + (-b). i.e. want a - b # TODO: supress + given a + (-b). i.e. want a - b
from collections import namedtuple, deque from collections import namedtuple
from itertools import repeat
from enum import Enum
import funcy as fn import funcy as fn
from sympy import Symbol from sympy import Symbol
@ -13,7 +11,9 @@ t_sym = Symbol('t', positive=True)
def flatten_binary(phi, op, dropT, shortT): def flatten_binary(phi, op, dropT, shortT):
f = lambda x: x.args if isinstance(x, op) else [x] def f(x):
return x.args if isinstance(x, op) else [x]
args = [arg for arg in phi.args if arg is not dropT] args = [arg for arg in phi.args if arg is not dropT]
if any(arg is shortT for arg in args): if any(arg is shortT for arg in args):
@ -153,7 +153,7 @@ class And(NaryOpSTL):
class ModalOp(namedtuple('ModalOp', ['interval', 'arg']), AST): class ModalOp(namedtuple('ModalOp', ['interval', 'arg']), AST):
__slots__ = () __slots__ = ()
OP = '?' OP = '?'
def __repr__(self): def __repr__(self):
return f"{self.OP}{self.interval}({self.arg})" return f"{self.OP}{self.interval}({self.arg})"
@ -231,7 +231,6 @@ class Param(namedtuple('Param', ['name']), AST):
def __repr__(self): def __repr__(self):
return self.name return self.name
def __hash__(self): def __hash__(self):
# TODO: compute hash based on contents # TODO: compute hash based on contents
return hash(repr(self)) return hash(repr(self))

View file

@ -1,38 +1,42 @@
# TODO: figure out how to deduplicate this with robustness # TODO: figure out how to deduplicate this with robustness
# - Abstract as working on distributive lattice # - Abstract as working on distributive lattice
from functools import singledispatch
import operator as op import operator as op
from functools import singledispatch
import funcy as fn import funcy as fn
from lenses import lens
import stl.ast
import stl import stl
import stl.ast
from lenses import lens
oo = float('inf') oo = float('inf')
def pointwise_sat(phi): def pointwise_sat(phi):
ap_names = [z.id.name for z in stl.utils.AP_lens(phi).Each().collect()] ap_names = [z.id.name for z in stl.utils.AP_lens(phi).Each().collect()]
def _eval_stl(x, t): def _eval_stl(x, t):
evaluated = stl.utils.eval_lineqs(phi, x) evaluated = stl.utils.eval_lineqs(phi, x)
evaluated.update(fn.project(x, ap_names)) evaluated.update(fn.project(x, ap_names))
return eval_stl(phi)(evaluated, t) return eval_stl(phi)(evaluated, t)
return _eval_stl return _eval_stl
@singledispatch @singledispatch
def eval_stl(stl): def eval_stl(stl):
raise NotImplementedError raise NotImplementedError
@eval_stl.register(stl.Or) @eval_stl.register(stl.Or)
def _(phi): def eval_stl_or(phi):
fs = [eval_stl(arg) for arg in phi.args] fs = [eval_stl(arg) for arg in phi.args]
return lambda x, t: any(f(x,t) for f in fs) return lambda x, t: any(f(x, t) for f in fs)
@eval_stl.register(stl.And) @eval_stl.register(stl.And)
def _(stl): def eval_stl_and(stl):
fs = [eval_stl(arg) for arg in stl.args] fs = [eval_stl(arg) for arg in stl.args]
return lambda x, t: all(f(x, t) for f in fs) return lambda x, t: all(f(x, t) for f in fs)
@ -57,13 +61,14 @@ def get_times(x, tau, lo=None, hi=None):
@eval_stl.register(stl.Until) @eval_stl.register(stl.Until)
def _(stl): def eval_stl_until(stl):
def _until(x, t): def _until(x, t):
f1, f2 = eval_stl(stl.arg1), eval_stl(stl.arg2) f1, f2 = eval_stl(stl.arg1), eval_stl(stl.arg2)
for tau in get_times(x, t): for tau in get_times(x, t):
if not f1(x, tau): if not f1(x, tau):
return f2(x, tau) return f2(x, tau)
return False return False
return _until return _until
@ -73,24 +78,24 @@ def eval_unary_temporal_op(phi, always=True):
if lo > hi: if lo > hi:
retval = True if always else False retval = True if always else False
return lambda x, t: retval return lambda x, t: retval
f = eval_stl(phi.arg) f = eval_stl(phi.arg)
if hi == lo: if hi == lo:
return lambda x, t: f(x, t) return lambda x, t: f(x, t)
return lambda x, t: fold(f(x, tau) for tau in get_times(x, t, lo, hi)) return lambda x, t: fold(f(x, tau) for tau in get_times(x, t, lo, hi))
@eval_stl.register(stl.F) @eval_stl.register(stl.F)
def _(phi): def eval_stl_f(phi):
return eval_unary_temporal_op(phi, always=False) return eval_unary_temporal_op(phi, always=False)
@eval_stl.register(stl.G) @eval_stl.register(stl.G)
def _(phi): def eval_stl_g(phi):
return eval_unary_temporal_op(phi, always=True) return eval_unary_temporal_op(phi, always=True)
@eval_stl.register(stl.Neg) @eval_stl.register(stl.Neg)
def _(stl): def eval_stl_neg(stl):
f = eval_stl(stl.arg) f = eval_stl(stl.arg)
return lambda x, t: not f(x, t) return lambda x, t: not f(x, t)
@ -105,22 +110,22 @@ op_lookup = {
@eval_stl.register(stl.AtomicPred) @eval_stl.register(stl.AtomicPred)
def _(stl): def eval_stl_ap(stl):
return lambda x, t: x[str(stl.id)][t] return lambda x, t: x[str(stl.id)][t]
@eval_stl.register(type(stl.TOP)) @eval_stl.register(type(stl.TOP))
def _(_): def eval_stl_top(_):
return lambda *_: True return lambda *_: True
@eval_stl.register(type(stl.BOT)) @eval_stl.register(type(stl.BOT))
def _(_): def eval_stl_bot(_):
return lambda *_: False return lambda *_: False
@eval_stl.register(stl.LinEq) @eval_stl.register(stl.LinEq)
def _(lineq): def eval_stl_lineq(lineq):
return lambda x, t: x[lineq][t] return lambda x, t: x[lineq][t]
@ -130,4 +135,4 @@ def eval_terms(lineq, x, t):
def eval_term(term, x, t): def eval_term(term, x, t):
return float(term.coeff)*x[term.id.name][t] return float(term.coeff) * x[term.id][t]

View file

@ -1,19 +1,22 @@
from functools import singledispatch, reduce from functools import reduce, singledispatch
from operator import and_, or_ from operator import and_, or_
from bitarray import bitarray from bitarray import bitarray
import stl.ast import stl.ast
from stl.boolean_eval import eval_terms, op_lookup, get_times from stl.boolean_eval import eval_terms, get_times, op_lookup
def pointwise_sat(stl): def pointwise_sat(stl):
f = pointwise_satf(stl) f = pointwise_satf(stl)
return lambda x, t: bool(int(f(x, [t]).to01())) return lambda x, t: bool(int(f(x, [t]).to01()))
@singledispatch @singledispatch
def pointwise_satf(stl): def pointwise_satf(stl):
raise NotImplementedError raise NotImplementedError
def bool_op(stl, conjunction=False): def bool_op(stl, conjunction=False):
binop = and_ if conjunction else or_ binop = and_ if conjunction else or_
fs = [pointwise_satf(arg) for arg in stl.args] fs = [pointwise_satf(arg) for arg in stl.args]
@ -21,78 +24,59 @@ def bool_op(stl, conjunction=False):
@pointwise_satf.register(stl.Or) @pointwise_satf.register(stl.Or)
def _(stl): def pointwise_satf_or(stl):
return bool_op(stl, conjunction=False) return bool_op(stl, conjunction=False)
@pointwise_satf.register(stl.And) @pointwise_satf.register(stl.And)
def _(stl): def pointwise_satf_and(stl):
return bool_op(stl, conjunction=True) return bool_op(stl, conjunction=True)
def temporal_op(stl, lo, hi, conjunction=False): def temporal_op(stl, lo, hi, conjunction=False):
fold = bitarray.all if conjunction else bitarray.any fold = bitarray.all if conjunction else bitarray.any
f = pointwise_satf(stl.arg) f = pointwise_satf(stl.arg)
def sat_comp(x,t):
def sat_comp(x, t):
return bitarray(fold(f(x, get_times(x, tau, lo, hi))) for tau in t) return bitarray(fold(f(x, get_times(x, tau, lo, hi))) for tau in t)
return sat_comp return sat_comp
@pointwise_satf.register(stl.Until)
def _(stl):
f1, f2 = pointwise_satf(stl.arg1), pointwise_satf(stl.arg2)
def __until(x, t):
f1, f2 = pointwise_satf(stl.arg1), pointwise_satf(stl.arg2)
state = False
times = get_times(x, t[0])
for phi, tau in zip(reversed(f1(x, times)), reversed(times)):
if not phi:
state = f2(x, [tau])
if tau in t:
yield state
def _until(x, t):
retval = bitarray(__until(x, t))
retval.reverse()
return retval
return _until
@pointwise_satf.register(stl.F) @pointwise_satf.register(stl.F)
def _(stl): def pointwise_satf_f(stl):
lo, hi = stl.interval lo, hi = stl.interval
return temporal_op(stl, lo, hi, conjunction=False) return temporal_op(stl, lo, hi, conjunction=False)
@pointwise_satf.register(stl.G) @pointwise_satf.register(stl.G)
def _(stl): def pointwise_satf_g(stl):
lo, hi = stl.interval lo, hi = stl.interval
return temporal_op(stl, lo, hi, conjunction=True) return temporal_op(stl, lo, hi, conjunction=True)
@pointwise_satf.register(stl.Neg) @pointwise_satf.register(stl.Neg)
def _(stl): def pointwise_satf_neg(stl):
return lambda x,t: ~pointwise_satf(stl.arg)(x, t) return lambda x, t: ~pointwise_satf(stl.arg)(x, t)
@pointwise_satf.register(stl.AtomicPred) @pointwise_satf.register(stl.AtomicPred)
def _(phi): def pointwise_satf_(phi):
return lambda x, t: bitarray(x[str(phi.id)][tau] for tau in t) return lambda x, t: bitarray(x[str(phi.id)][tau] for tau in t)
@pointwise_satf.register(type(stl.TOP)) @pointwise_satf.register(type(stl.TOP))
def _(_): def pointwise_satf_top(_):
return lambda _, t: bitarray([True]*len(t)) return lambda _, t: bitarray([True] * len(t))
@pointwise_satf.register(type(stl.BOT)) @pointwise_satf.register(type(stl.BOT))
def _(_): def pointwise_satf_bot(_):
return lambda _, t: bitarray([False]*len(t)) return lambda _, t: bitarray([False] * len(t))
@pointwise_satf.register(stl.LinEq) @pointwise_satf.register(stl.LinEq)
def _(stl): def pointwise_satf_lineq(stl):
op = lambda a: op_lookup[stl.op](a, stl.const) def op(a):
return op_lookup[stl.op](a, stl.const)
return lambda x, t: bitarray(op(eval_terms(stl, x, tau)) for tau in t) return lambda x, t: bitarray(op(eval_terms(stl, x, tau)) for tau in t)

View file

@ -1,23 +1,23 @@
from hypothesis_cfg import ContextFreeGrammarStrategy
import hypothesis.strategies as st import hypothesis.strategies as st
from hypothesis.searchstrategy.strategies import SearchStrategy from hypothesis_cfg import ContextFreeGrammarStrategy
import stl import stl
GRAMMAR = { GRAMMAR = {
'phi': (('Unary', '(', 'phi', ')'), 'phi': (('Unary', '(', 'phi', ')'), ('(', 'phi', ')', 'Binary', '(', 'phi',
('(', 'phi', ')', 'Binary', '(', 'phi', ')'), ')'), ('AP', )),
('AP',)), 'Unary': (('~', ), ('G', ), ('F', ), ('X', )),
'Unary': (('~',), ('G',), ('F',), ('X',)), 'Binary': ((' | ', ), (' & ', ), (' U ', )),
'Binary': ((' | ',), (' & ',), (' U ',)), 'AP': (('AP1', ), ('AP2', ), ('AP3', ), ('AP4', ), ('AP5', )),
'AP': (('AP1',), ('AP2',), ('AP3',), ('AP4',), ('AP5',)),
} }
def to_stl(term): def to_stl(term):
return stl.parse(''.join(term)) return stl.parse(''.join(term))
SignalTemporalLogicStrategy = st.builds(
to_stl, ContextFreeGrammarStrategy( SignalTemporalLogicStrategy = st.builds(to_stl,
GRAMMAR, max_length=25, start='phi')) ContextFreeGrammarStrategy(
GRAMMAR,
max_length=25,
start='phi'))

View file

@ -1,4 +1,5 @@
from traces import TimeSeries, Domain from traces import Domain, TimeSeries
def from_pandas(df, compact=True): def from_pandas(df, compact=True):
'''TODO''' '''TODO'''

View file

@ -5,15 +5,11 @@
# TODO: Allow -x = -1*x # TODO: Allow -x = -1*x
from functools import partialmethod from functools import partialmethod
from collections import namedtuple
import operator as op
from parsimonious import Grammar, NodeVisitor
from funcy import flatten
from lenses import lens from lenses import lens
from parsimonious import Grammar, NodeVisitor
from stl import ast from stl import ast
from stl.utils import implies, xor, iff, env, alw from stl.utils import alw, env, iff, implies, xor
STL_GRAMMAR = Grammar(u''' STL_GRAMMAR = Grammar(u'''
phi = (timed_until / until / neg / next / g / f / lineq / AP / or / and phi = (timed_until / until / neg / next / g / f / lineq / AP / or / and
@ -60,9 +56,9 @@ __ = ~r"\s"*
EOL = "\\n" EOL = "\\n"
''') ''')
oo = float('inf') oo = float('inf')
class STLVisitor(NodeVisitor): class STLVisitor(NodeVisitor):
def __init__(self, H=oo): def __init__(self, H=oo):
super().__init__() super().__init__()
@ -137,7 +133,6 @@ class STLVisitor(NodeVisitor):
c = coeffs[0] if coeffs else 1 c = coeffs[0] if coeffs else 1
return ast.Var(coeff=c, id=iden) return ast.Var(coeff=c, id=iden)
def visit_terms(self, _, children): def visit_terms(self, _, children):
if isinstance(children[0], list): if isinstance(children[0], list):
term, _1, sgn, _2, terms = children[0] term, _1, sgn, _2, terms = children[0]

View file

@ -1,38 +0,0 @@
import operator as op
from stl.utils import set_params, param_lens
from stl.boolean_eval import pointwise_sat
from lenses import lens
def binsearch(stleval, *, tol=1e-3, lo, hi, polarity):
"""Only run search if tightest robustness was positive."""
# Early termination via bounds checks
if polarity and stleval(lo):
return lo
elif not polarity and stleval(hi):
return hi
while hi - lo > tol:
mid = lo + (hi - lo) / 2
r = stleval(mid)
lo, hi = (mid, hi) if r ^ polarity else (lo, mid)
# Want satisifiable formula
return hi if polarity else lo
def lex_param_project(stl, x, *, order, polarity, ranges, tol=1e-3):
val = {var: (ranges[var][0] if not polarity[var] else ranges[var][1]) for var in order}
p_lens = param_lens(stl)
def stleval_fact(var, val):
l = lens(val)[var]
return lambda p: pointwise_sat(set_params(stl, l.set(p)))(x, 0)
for var in order:
stleval = stleval_fact(var, val)
lo, hi = ranges[var]
param = binsearch(stleval, lo=lo, hi=hi, tol=tol, polarity=polarity[var])
val[var] = param
return val

View file

@ -1,6 +1,8 @@
import stl
import unittest import unittest
import stl
class TestSTLAST(unittest.TestCase): class TestSTLAST(unittest.TestCase):
def test_and(self): def test_and(self):
phi = stl.parse("x") phi = stl.parse("x")

View file

@ -1,13 +1,11 @@
import hypothesis.strategies as st
import traces
from hypothesis import given
import stl import stl
import stl.boolean_eval import stl.boolean_eval
import stl.fastboolean_eval import stl.fastboolean_eval
import traces
import unittest
from sympy import Symbol
import hypothesis.strategies as st
from hypothesis import given, note, assume, example
""" """
TODO: property based test that fasteval should be the same as slow TODO: property based test that fasteval should be the same as slow
@ -28,7 +26,11 @@ x = {
"A": traces.TimeSeries([(0, 1), (0.1, 1), (0.2, 4)]), "A": traces.TimeSeries([(0, 1), (0.1, 1), (0.2, 4)]),
"B": traces.TimeSeries([(0, 2), (0.1, 4), (0.2, 2)]), "B": traces.TimeSeries([(0, 2), (0.1, 4), (0.2, 2)]),
"C": traces.TimeSeries([(0, True), (0.1, True), (0.2, False)]), "C": traces.TimeSeries([(0, True), (0.1, True), (0.2, False)]),
'D': traces.TimeSeries({0.0: 2, 13.8: 3, 19.7: 2}), 'D': traces.TimeSeries({
0.0: 2,
13.8: 3,
19.7: 2
}),
} }
@ -56,4 +58,3 @@ def test_temporal_identities(phi):
stl_eval3 = stl.fastboolean_eval.pointwise_sat(~stl.alw(~phi, lo=0, hi=4)) stl_eval3 = stl.fastboolean_eval.pointwise_sat(~stl.alw(~phi, lo=0, hi=4))
stl_eval4 = stl.fastboolean_eval.pointwise_sat(stl.env(phi, lo=0, hi=4)) stl_eval4 = stl.fastboolean_eval.pointwise_sat(stl.env(phi, lo=0, hi=4))
assert stl_eval4(x, 0) == stl_eval3(x, 0) assert stl_eval4(x, 0) == stl_eval3(x, 0)

View file

@ -1,7 +1,7 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
import stl from hypothesis import event, given
from hypothesis import given, event
import stl
from stl.hypothesis import SignalTemporalLogicStrategy from stl.hypothesis import SignalTemporalLogicStrategy

View file

@ -1,38 +0,0 @@
import stl
import stl.synth
import traces
import unittest
from sympy import Symbol
oo = float('inf')
ex1 = ("A > a?", ("a?",), {"a?": (0, 10)}, {"a?": False}, {"a?": 1})
ex2 = ("F[0, b?](A > a?)", ("a?", "b?"), {"a?": (0, 10), "b?": (0, 5)},
{"a?": False, "b?": True}, {"a?": 4, "b?": 0.2})
ex3 = ("F[0, b?](A < 0)", ("b?",), {"b?": (0, 5)},
{"b?": True}, {"b?": 5})
ex4 = ("G[0, b?](A < 0)", ("b?",), {"b?": (0.1, 5)},
{"b?": False}, {"b?": 0.1})
ex5 = ("F[0, b?](A > 0)", ("b?",), {"b?": (0.1, 5)},
{"b?": True}, {"b?": 0.1})
ex6 = ("(A > a?) or (A > b?)", ("a?", "b?",), {"a?": (0, 2), "b?": (0, 2)},
{"a?": False, "b?": False}, {"a?": 2, "b?": 1})
x = {
"A": traces.TimeSeries([(0, 1), (0.1, 1), (0.2, 4)]),
"B": traces.TimeSeries([(0, 2), (0.1, 4), (0.2, 2)]),
"C": traces.TimeSeries([(0, True), (0.1, True), (0.2, False)]),
}
"""
class TestSTLRobustness(unittest.TestCase):
@params(ex1, ex2, ex3, ex4, ex5, ex6)
def test_lex_synth(self, phi_str, order, ranges, polarity, val):
phi = stl.parse(phi_str)
val2 = stl.synth.lex_param_project(
phi, x, order=order, ranges=ranges, polarity=polarity)
# check that the valuations are almost the same
for var in order:
self.assertAlmostEqual(val2[var], val[var], delta=0.01)
"""

View file

@ -1,94 +0,0 @@
import stl
import stl.utils
import pandas as pd
import unittest
from sympy import Symbol
ex1 = ("F[b?, 1]G[0, c?](x > a?)", {"a?", "b?", "c?"})
ex2 = ("G[0, c?](x > a?)", {"a?", "c?"})
ex3 = ("F[b?, 1]G[0, c?](x > a?)", {"a?", "b?", "c?"})
ex4 = ("F[b?, 1]G[0, c?](x > a?)", "F[2, 1]G[0, 3](x > 1)")
ex5 = ("G[0, c?](x > a?)", "G[0, 3](x > 1)")
val = {"a?": 1.0, "b?": 2.0, "c?": 3.0}
"""
class TestSTLUtils(unittest.TestCase):
@params(ex1, ex2, ex3)
def test_param_lens(self, phi_str, params):
phi = stl.parse(phi_str)
self.assertEqual(set(map(str, stl.utils.param_lens(phi).get_all())), params)
@params(ex4, ex5)
def test_set_params(self, phi_str, phi2_str):
phi = stl.parse(phi_str)
phi2 = stl.parse(phi2_str)
phi = stl.utils.set_params(phi, val)
self.assertEqual(set(map(str, stl.utils.param_lens(phi).get_all())), set())
self.assertEqual(phi, phi2)
@params(("x > 5", 1), ("~(x)", 2), ("(F[0,1](x)) & (~(G[0, 2](y)))", 6))
def test_walk(self, phi_str, l):
self.assertEqual(l, len(list(stl.walk(stl.parse(phi_str)))))
@params(([], False, False),([int], True, False), ([int, bool], True, True))
def test_type_pred(self, types, b1, b2):
pred = stl.utils.type_pred(*types)
self.assertFalse(pred(None))
self.assertEqual(pred(1), b1)
self.assertEqual(pred(True), b2)
@params(("(F[0,1]G[0, 4]((x > 3) or (y < 4))) and (x < 3)", 2))
def test_vars_in_phi(self, phi_str, l):
phi = stl.parse(phi_str)
self.assertEqual(len(stl.utils.vars_in_phi(phi)), l)
@params(("(F[0,1]G[0, 4]((x > 3) or (y < 4))) and (x < 3)", 3))
def test_terms_lens(self, phi_str, l):
phi = stl.parse(phi_str)
l2 = len(stl.terms_lens(phi).get_all())
self.assertEqual(l, l2)
@params(("(F[0,1]G[0, 4]((x > 3) | (y < 4))) & (x < 3)", 7, 12))
def test_f_neg_or_canonical_form(self, phi_str, pre_l, post_l):
phi = stl.parse(phi_str)
pre_l2 = len(list(stl.walk(phi)))
self.assertEqual(pre_l, pre_l2)
post_l2 = len(list(stl.walk(stl.utils.f_neg_or_canonical_form(phi))))
self.assertEqual(post_l, post_l2)
def test_andf(self):
phi = stl.parse("x")
self.assertEqual(phi, stl.andf(phi))
def test_orf(self):
phi = stl.parse("x")
self.assertEqual(phi, stl.orf(phi))
def test_inline_context(self):
context = {
stl.parse("x"): stl.parse("(z) & (y)"),
stl.parse("z"): stl.parse("y - x > 4")
}
context2 = {
stl.parse("x"): stl.parse("x"),
}
phi = stl.parse("x")
self.assertEqual(stl.utils.inline_context(phi, {}), phi)
self.assertEqual(stl.utils.inline_context(phi, context),
stl.parse("(y - x > 4) & (y)"))
phi2 = stl.parse("((x) & (z)) | (y)")
self.assertEqual(stl.utils.inline_context(phi2, context),
stl.parse("((y - x > 4) & (y) & (y - x > 4)) | (y)"))
# def test_to_from_mtl(self):
# raise NotImplementedError
# def test_get_polarity(self):
# raise NotImplementedError
# def test_canonical_polarity(self):
# raise NotImplementedError
"""

View file

@ -1,18 +1,18 @@
from typing import List, Type, Dict, Mapping, T, TypeVar
from collections import deque
import operator as op import operator as op
from collections import deque
from functools import reduce from functools import reduce
from typing import Dict, List, Mapping, T, Type, TypeVar
import lenses
from lenses import lens
import funcy as fn import funcy as fn
import sympy import sympy
import traces import traces
import lenses
import stl.ast import stl.ast
from stl.ast import (LinEq, And, Or, NaryOpSTL, F, G, Interval, Neg, from lenses import lens
AtomicPred, Param, AST) from stl.ast import (AST, And, AtomicPred, F, G, Interval, LinEq, NaryOpSTL,
from stl.types import STL, STL_Generator, MTL Neg, Or, Param)
from stl.types import MTL, STL, STL_Generator
Lens = TypeVar('Lens') Lens = TypeVar('Lens')
@ -37,7 +37,7 @@ def type_pred(*args: List[Type]) -> Mapping[Type, bool]:
return lambda x: type(x) in ast_types return lambda x: type(x) in ast_types
def ast_lens(phi: STL, bind=True, *, pred=None, focus_lens=None, def ast_lens(phi: STL, bind=True, *, pred=None, focus_lens=None,
getter=False) -> Lens: getter=False) -> Lens:
if focus_lens is None: if focus_lens is None:
focus_lens = lambda _: [lens] focus_lens = lambda _: [lens]
@ -86,9 +86,7 @@ def param_lens(phi: STL) -> Lens:
] ]
return (x for x in candidates if isinstance(x.get()(leaf), Param)) return (x for x in candidates if isinstance(x.get()(leaf), Param))
return ast_lens( return ast_lens(phi, pred=type_pred(LinEq, F, G), focus_lens=focus_lens)
phi, pred=type_pred(LinEq, F, G),
focus_lens=focus_lens)
def set_params(phi, val) -> STL: def set_params(phi, val) -> STL: