yapf + pylint + add style checks to tests
This commit is contained in:
parent
d52fffe826
commit
a1ca4c6579
18 changed files with 130 additions and 299 deletions
2
.gitignore
vendored
2
.gitignore
vendored
|
|
@ -1,3 +1,5 @@
|
|||
dist/*
|
||||
.cache/*
|
||||
.hypothesis/*
|
||||
.coverage*
|
||||
htmlcov/*
|
||||
|
|
@ -2,4 +2,4 @@
|
|||
python:
|
||||
- "3.6"
|
||||
install: "pip install -r requirements.txt"
|
||||
script: pytest
|
||||
script: pytest --flake8 --cov=stl -x --cov-report=html --isort
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
-e git://github.com/mvcisback/hypothesis-cfg@master#egg=hypothesis-cfg
|
||||
-e git://github.com/mvcisback/multidim-threshold@master#egg=multidim-threshold
|
||||
bitarray==0.8.1
|
||||
funcy==1.9.1
|
||||
lenses==0.3.0
|
||||
|
|
@ -7,4 +8,16 @@ parsimonious==0.7.0
|
|||
sympy==1.0
|
||||
traces==0.3.1
|
||||
pytest==3.2.2
|
||||
hypothesis==3.30.3
|
||||
hypothesis==3.32.1
|
||||
|
||||
pytest==3.2.3
|
||||
pytest-bpdb==0.1.4
|
||||
pytest-cache==1.0
|
||||
pytest-colordots==1.1
|
||||
pytest-cov==2.5.1
|
||||
pytest-flake8==0.9
|
||||
pytest-forked==0.2
|
||||
pytest-isort==0.1.0
|
||||
pytest-pep8==1.0.6
|
||||
pytest-xdist==1.20.1
|
||||
python-dateutil==2.6.1
|
||||
|
|
|
|||
5
setup.py
5
setup.py
|
|
@ -1,6 +1,7 @@
|
|||
from setuptools import setup, find_packages
|
||||
from setuptools import find_packages, setup
|
||||
|
||||
setup(name='py-stl',
|
||||
setup(
|
||||
name='py-stl',
|
||||
version='0.2',
|
||||
description='TODO',
|
||||
url='http://github.com/mvcisback/py-stl',
|
||||
|
|
|
|||
|
|
@ -1,9 +1,9 @@
|
|||
# flake8: noqa
|
||||
from stl.utils import terms_lens, lineq_lens, walk, and_or_lens
|
||||
from stl.utils import alw, env, andf, orf
|
||||
from stl.ast import dt_sym, t_sym, TOP, BOT
|
||||
from stl.ast import (LinEq, Interval, NaryOpSTL, Or, And, F, G,
|
||||
ModalOp, Neg, Var, AtomicPred, Until)
|
||||
from stl.ast import (LinEq, Interval, NaryOpSTL, Or, And, F, G, ModalOp, Neg,
|
||||
Var, AtomicPred, Until)
|
||||
from stl.parser import parse
|
||||
from stl.fastboolean_eval import pointwise_sat
|
||||
from stl.synth import lex_param_project
|
||||
from stl.types import STL
|
||||
|
|
|
|||
|
|
@ -1,9 +1,7 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# TODO: supress + given a + (-b). i.e. want a - b
|
||||
|
||||
from collections import namedtuple, deque
|
||||
from itertools import repeat
|
||||
from enum import Enum
|
||||
from collections import namedtuple
|
||||
|
||||
import funcy as fn
|
||||
from sympy import Symbol
|
||||
|
|
@ -13,7 +11,9 @@ t_sym = Symbol('t', positive=True)
|
|||
|
||||
|
||||
def flatten_binary(phi, op, dropT, shortT):
|
||||
f = lambda x: x.args if isinstance(x, op) else [x]
|
||||
def f(x):
|
||||
return x.args if isinstance(x, op) else [x]
|
||||
|
||||
args = [arg for arg in phi.args if arg is not dropT]
|
||||
|
||||
if any(arg is shortT for arg in args):
|
||||
|
|
@ -231,7 +231,6 @@ class Param(namedtuple('Param', ['name']), AST):
|
|||
def __repr__(self):
|
||||
return self.name
|
||||
|
||||
|
||||
def __hash__(self):
|
||||
# TODO: compute hash based on contents
|
||||
return hash(repr(self))
|
||||
|
|
|
|||
|
|
@ -1,38 +1,42 @@
|
|||
# TODO: figure out how to deduplicate this with robustness
|
||||
# - Abstract as working on distributive lattice
|
||||
|
||||
from functools import singledispatch
|
||||
import operator as op
|
||||
from functools import singledispatch
|
||||
|
||||
import funcy as fn
|
||||
from lenses import lens
|
||||
|
||||
import stl.ast
|
||||
import stl
|
||||
import stl.ast
|
||||
from lenses import lens
|
||||
|
||||
oo = float('inf')
|
||||
|
||||
|
||||
def pointwise_sat(phi):
|
||||
ap_names = [z.id.name for z in stl.utils.AP_lens(phi).Each().collect()]
|
||||
|
||||
def _eval_stl(x, t):
|
||||
evaluated = stl.utils.eval_lineqs(phi, x)
|
||||
evaluated.update(fn.project(x, ap_names))
|
||||
return eval_stl(phi)(evaluated, t)
|
||||
|
||||
return _eval_stl
|
||||
|
||||
|
||||
@singledispatch
|
||||
def eval_stl(stl):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@eval_stl.register(stl.Or)
|
||||
def _(phi):
|
||||
def eval_stl_or(phi):
|
||||
fs = [eval_stl(arg) for arg in phi.args]
|
||||
return lambda x, t: any(f(x,t) for f in fs)
|
||||
return lambda x, t: any(f(x, t) for f in fs)
|
||||
|
||||
|
||||
@eval_stl.register(stl.And)
|
||||
def _(stl):
|
||||
def eval_stl_and(stl):
|
||||
fs = [eval_stl(arg) for arg in stl.args]
|
||||
return lambda x, t: all(f(x, t) for f in fs)
|
||||
|
||||
|
|
@ -57,13 +61,14 @@ def get_times(x, tau, lo=None, hi=None):
|
|||
|
||||
|
||||
@eval_stl.register(stl.Until)
|
||||
def _(stl):
|
||||
def eval_stl_until(stl):
|
||||
def _until(x, t):
|
||||
f1, f2 = eval_stl(stl.arg1), eval_stl(stl.arg2)
|
||||
for tau in get_times(x, t):
|
||||
if not f1(x, tau):
|
||||
return f2(x, tau)
|
||||
return False
|
||||
|
||||
return _until
|
||||
|
||||
|
||||
|
|
@ -80,17 +85,17 @@ def eval_unary_temporal_op(phi, always=True):
|
|||
|
||||
|
||||
@eval_stl.register(stl.F)
|
||||
def _(phi):
|
||||
def eval_stl_f(phi):
|
||||
return eval_unary_temporal_op(phi, always=False)
|
||||
|
||||
|
||||
@eval_stl.register(stl.G)
|
||||
def _(phi):
|
||||
def eval_stl_g(phi):
|
||||
return eval_unary_temporal_op(phi, always=True)
|
||||
|
||||
|
||||
@eval_stl.register(stl.Neg)
|
||||
def _(stl):
|
||||
def eval_stl_neg(stl):
|
||||
f = eval_stl(stl.arg)
|
||||
return lambda x, t: not f(x, t)
|
||||
|
||||
|
|
@ -105,22 +110,22 @@ op_lookup = {
|
|||
|
||||
|
||||
@eval_stl.register(stl.AtomicPred)
|
||||
def _(stl):
|
||||
def eval_stl_ap(stl):
|
||||
return lambda x, t: x[str(stl.id)][t]
|
||||
|
||||
|
||||
@eval_stl.register(type(stl.TOP))
|
||||
def _(_):
|
||||
def eval_stl_top(_):
|
||||
return lambda *_: True
|
||||
|
||||
|
||||
@eval_stl.register(type(stl.BOT))
|
||||
def _(_):
|
||||
def eval_stl_bot(_):
|
||||
return lambda *_: False
|
||||
|
||||
|
||||
@eval_stl.register(stl.LinEq)
|
||||
def _(lineq):
|
||||
def eval_stl_lineq(lineq):
|
||||
return lambda x, t: x[lineq][t]
|
||||
|
||||
|
||||
|
|
@ -130,4 +135,4 @@ def eval_terms(lineq, x, t):
|
|||
|
||||
|
||||
def eval_term(term, x, t):
|
||||
return float(term.coeff)*x[term.id.name][t]
|
||||
return float(term.coeff) * x[term.id][t]
|
||||
|
|
|
|||
|
|
@ -1,19 +1,22 @@
|
|||
from functools import singledispatch, reduce
|
||||
from functools import reduce, singledispatch
|
||||
from operator import and_, or_
|
||||
|
||||
from bitarray import bitarray
|
||||
|
||||
import stl.ast
|
||||
from stl.boolean_eval import eval_terms, op_lookup, get_times
|
||||
from stl.boolean_eval import eval_terms, get_times, op_lookup
|
||||
|
||||
|
||||
def pointwise_sat(stl):
|
||||
f = pointwise_satf(stl)
|
||||
return lambda x, t: bool(int(f(x, [t]).to01()))
|
||||
|
||||
|
||||
@singledispatch
|
||||
def pointwise_satf(stl):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
def bool_op(stl, conjunction=False):
|
||||
binop = and_ if conjunction else or_
|
||||
fs = [pointwise_satf(arg) for arg in stl.args]
|
||||
|
|
@ -21,78 +24,59 @@ def bool_op(stl, conjunction=False):
|
|||
|
||||
|
||||
@pointwise_satf.register(stl.Or)
|
||||
def _(stl):
|
||||
def pointwise_satf_or(stl):
|
||||
return bool_op(stl, conjunction=False)
|
||||
|
||||
|
||||
@pointwise_satf.register(stl.And)
|
||||
def _(stl):
|
||||
def pointwise_satf_and(stl):
|
||||
return bool_op(stl, conjunction=True)
|
||||
|
||||
|
||||
def temporal_op(stl, lo, hi, conjunction=False):
|
||||
fold = bitarray.all if conjunction else bitarray.any
|
||||
f = pointwise_satf(stl.arg)
|
||||
def sat_comp(x,t):
|
||||
|
||||
def sat_comp(x, t):
|
||||
return bitarray(fold(f(x, get_times(x, tau, lo, hi))) for tau in t)
|
||||
|
||||
return sat_comp
|
||||
|
||||
|
||||
@pointwise_satf.register(stl.Until)
|
||||
def _(stl):
|
||||
f1, f2 = pointwise_satf(stl.arg1), pointwise_satf(stl.arg2)
|
||||
def __until(x, t):
|
||||
f1, f2 = pointwise_satf(stl.arg1), pointwise_satf(stl.arg2)
|
||||
|
||||
state = False
|
||||
times = get_times(x, t[0])
|
||||
for phi, tau in zip(reversed(f1(x, times)), reversed(times)):
|
||||
if not phi:
|
||||
state = f2(x, [tau])
|
||||
|
||||
if tau in t:
|
||||
yield state
|
||||
|
||||
def _until(x, t):
|
||||
retval = bitarray(__until(x, t))
|
||||
retval.reverse()
|
||||
return retval
|
||||
|
||||
return _until
|
||||
|
||||
|
||||
@pointwise_satf.register(stl.F)
|
||||
def _(stl):
|
||||
def pointwise_satf_f(stl):
|
||||
lo, hi = stl.interval
|
||||
return temporal_op(stl, lo, hi, conjunction=False)
|
||||
|
||||
|
||||
@pointwise_satf.register(stl.G)
|
||||
def _(stl):
|
||||
def pointwise_satf_g(stl):
|
||||
lo, hi = stl.interval
|
||||
return temporal_op(stl, lo, hi, conjunction=True)
|
||||
|
||||
|
||||
@pointwise_satf.register(stl.Neg)
|
||||
def _(stl):
|
||||
return lambda x,t: ~pointwise_satf(stl.arg)(x, t)
|
||||
def pointwise_satf_neg(stl):
|
||||
return lambda x, t: ~pointwise_satf(stl.arg)(x, t)
|
||||
|
||||
|
||||
@pointwise_satf.register(stl.AtomicPred)
|
||||
def _(phi):
|
||||
def pointwise_satf_(phi):
|
||||
return lambda x, t: bitarray(x[str(phi.id)][tau] for tau in t)
|
||||
|
||||
|
||||
@pointwise_satf.register(type(stl.TOP))
|
||||
def _(_):
|
||||
return lambda _, t: bitarray([True]*len(t))
|
||||
def pointwise_satf_top(_):
|
||||
return lambda _, t: bitarray([True] * len(t))
|
||||
|
||||
|
||||
@pointwise_satf.register(type(stl.BOT))
|
||||
def _(_):
|
||||
return lambda _, t: bitarray([False]*len(t))
|
||||
def pointwise_satf_bot(_):
|
||||
return lambda _, t: bitarray([False] * len(t))
|
||||
|
||||
|
||||
@pointwise_satf.register(stl.LinEq)
|
||||
def _(stl):
|
||||
op = lambda a: op_lookup[stl.op](a, stl.const)
|
||||
def pointwise_satf_lineq(stl):
|
||||
def op(a):
|
||||
return op_lookup[stl.op](a, stl.const)
|
||||
return lambda x, t: bitarray(op(eval_terms(stl, x, tau)) for tau in t)
|
||||
|
|
|
|||
|
|
@ -1,23 +1,23 @@
|
|||
from hypothesis_cfg import ContextFreeGrammarStrategy
|
||||
|
||||
import hypothesis.strategies as st
|
||||
from hypothesis.searchstrategy.strategies import SearchStrategy
|
||||
from hypothesis_cfg import ContextFreeGrammarStrategy
|
||||
|
||||
import stl
|
||||
|
||||
|
||||
GRAMMAR = {
|
||||
'phi': (('Unary', '(', 'phi', ')'),
|
||||
('(', 'phi', ')', 'Binary', '(', 'phi', ')'),
|
||||
('AP',)),
|
||||
'Unary': (('~',), ('G',), ('F',), ('X',)),
|
||||
'Binary': ((' | ',), (' & ',), (' U ',)),
|
||||
'AP': (('AP1',), ('AP2',), ('AP3',), ('AP4',), ('AP5',)),
|
||||
'phi': (('Unary', '(', 'phi', ')'), ('(', 'phi', ')', 'Binary', '(', 'phi',
|
||||
')'), ('AP', )),
|
||||
'Unary': (('~', ), ('G', ), ('F', ), ('X', )),
|
||||
'Binary': ((' | ', ), (' & ', ), (' U ', )),
|
||||
'AP': (('AP1', ), ('AP2', ), ('AP3', ), ('AP4', ), ('AP5', )),
|
||||
}
|
||||
|
||||
|
||||
def to_stl(term):
|
||||
return stl.parse(''.join(term))
|
||||
|
||||
SignalTemporalLogicStrategy = st.builds(
|
||||
to_stl, ContextFreeGrammarStrategy(
|
||||
GRAMMAR, max_length=25, start='phi'))
|
||||
|
||||
SignalTemporalLogicStrategy = st.builds(to_stl,
|
||||
ContextFreeGrammarStrategy(
|
||||
GRAMMAR,
|
||||
max_length=25,
|
||||
start='phi'))
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
from traces import TimeSeries, Domain
|
||||
from traces import Domain, TimeSeries
|
||||
|
||||
|
||||
def from_pandas(df, compact=True):
|
||||
'''TODO'''
|
||||
|
|
|
|||
|
|
@ -5,15 +5,11 @@
|
|||
# TODO: Allow -x = -1*x
|
||||
|
||||
from functools import partialmethod
|
||||
from collections import namedtuple
|
||||
import operator as op
|
||||
|
||||
from parsimonious import Grammar, NodeVisitor
|
||||
from funcy import flatten
|
||||
from lenses import lens
|
||||
|
||||
from parsimonious import Grammar, NodeVisitor
|
||||
from stl import ast
|
||||
from stl.utils import implies, xor, iff, env, alw
|
||||
from stl.utils import alw, env, iff, implies, xor
|
||||
|
||||
STL_GRAMMAR = Grammar(u'''
|
||||
phi = (timed_until / until / neg / next / g / f / lineq / AP / or / and
|
||||
|
|
@ -60,9 +56,9 @@ __ = ~r"\s"*
|
|||
EOL = "\\n"
|
||||
''')
|
||||
|
||||
|
||||
oo = float('inf')
|
||||
|
||||
|
||||
class STLVisitor(NodeVisitor):
|
||||
def __init__(self, H=oo):
|
||||
super().__init__()
|
||||
|
|
@ -137,7 +133,6 @@ class STLVisitor(NodeVisitor):
|
|||
c = coeffs[0] if coeffs else 1
|
||||
return ast.Var(coeff=c, id=iden)
|
||||
|
||||
|
||||
def visit_terms(self, _, children):
|
||||
if isinstance(children[0], list):
|
||||
term, _1, sgn, _2, terms = children[0]
|
||||
|
|
|
|||
38
stl/synth.py
38
stl/synth.py
|
|
@ -1,38 +0,0 @@
|
|||
import operator as op
|
||||
|
||||
from stl.utils import set_params, param_lens
|
||||
from stl.boolean_eval import pointwise_sat
|
||||
|
||||
from lenses import lens
|
||||
|
||||
def binsearch(stleval, *, tol=1e-3, lo, hi, polarity):
|
||||
"""Only run search if tightest robustness was positive."""
|
||||
# Early termination via bounds checks
|
||||
if polarity and stleval(lo):
|
||||
return lo
|
||||
elif not polarity and stleval(hi):
|
||||
return hi
|
||||
|
||||
while hi - lo > tol:
|
||||
mid = lo + (hi - lo) / 2
|
||||
r = stleval(mid)
|
||||
lo, hi = (mid, hi) if r ^ polarity else (lo, mid)
|
||||
|
||||
# Want satisifiable formula
|
||||
return hi if polarity else lo
|
||||
|
||||
|
||||
def lex_param_project(stl, x, *, order, polarity, ranges, tol=1e-3):
|
||||
val = {var: (ranges[var][0] if not polarity[var] else ranges[var][1]) for var in order}
|
||||
p_lens = param_lens(stl)
|
||||
def stleval_fact(var, val):
|
||||
l = lens(val)[var]
|
||||
return lambda p: pointwise_sat(set_params(stl, l.set(p)))(x, 0)
|
||||
|
||||
for var in order:
|
||||
stleval = stleval_fact(var, val)
|
||||
lo, hi = ranges[var]
|
||||
param = binsearch(stleval, lo=lo, hi=hi, tol=tol, polarity=polarity[var])
|
||||
val[var] = param
|
||||
|
||||
return val
|
||||
|
|
@ -1,6 +1,8 @@
|
|||
import stl
|
||||
import unittest
|
||||
|
||||
import stl
|
||||
|
||||
|
||||
class TestSTLAST(unittest.TestCase):
|
||||
def test_and(self):
|
||||
phi = stl.parse("x")
|
||||
|
|
|
|||
|
|
@ -1,13 +1,11 @@
|
|||
import hypothesis.strategies as st
|
||||
import traces
|
||||
|
||||
from hypothesis import given
|
||||
|
||||
import stl
|
||||
import stl.boolean_eval
|
||||
import stl.fastboolean_eval
|
||||
import traces
|
||||
import unittest
|
||||
from sympy import Symbol
|
||||
|
||||
import hypothesis.strategies as st
|
||||
from hypothesis import given, note, assume, example
|
||||
|
||||
|
||||
"""
|
||||
TODO: property based test that fasteval should be the same as slow
|
||||
|
|
@ -28,7 +26,11 @@ x = {
|
|||
"A": traces.TimeSeries([(0, 1), (0.1, 1), (0.2, 4)]),
|
||||
"B": traces.TimeSeries([(0, 2), (0.1, 4), (0.2, 2)]),
|
||||
"C": traces.TimeSeries([(0, True), (0.1, True), (0.2, False)]),
|
||||
'D': traces.TimeSeries({0.0: 2, 13.8: 3, 19.7: 2}),
|
||||
'D': traces.TimeSeries({
|
||||
0.0: 2,
|
||||
13.8: 3,
|
||||
19.7: 2
|
||||
}),
|
||||
}
|
||||
|
||||
|
||||
|
|
@ -56,4 +58,3 @@ def test_temporal_identities(phi):
|
|||
stl_eval3 = stl.fastboolean_eval.pointwise_sat(~stl.alw(~phi, lo=0, hi=4))
|
||||
stl_eval4 = stl.fastboolean_eval.pointwise_sat(stl.env(phi, lo=0, hi=4))
|
||||
assert stl_eval4(x, 0) == stl_eval3(x, 0)
|
||||
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
import stl
|
||||
from hypothesis import given, event
|
||||
from hypothesis import event, given
|
||||
|
||||
import stl
|
||||
from stl.hypothesis import SignalTemporalLogicStrategy
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1,38 +0,0 @@
|
|||
import stl
|
||||
import stl.synth
|
||||
import traces
|
||||
import unittest
|
||||
from sympy import Symbol
|
||||
|
||||
oo = float('inf')
|
||||
|
||||
ex1 = ("A > a?", ("a?",), {"a?": (0, 10)}, {"a?": False}, {"a?": 1})
|
||||
ex2 = ("F[0, b?](A > a?)", ("a?", "b?"), {"a?": (0, 10), "b?": (0, 5)},
|
||||
{"a?": False, "b?": True}, {"a?": 4, "b?": 0.2})
|
||||
ex3 = ("F[0, b?](A < 0)", ("b?",), {"b?": (0, 5)},
|
||||
{"b?": True}, {"b?": 5})
|
||||
ex4 = ("G[0, b?](A < 0)", ("b?",), {"b?": (0.1, 5)},
|
||||
{"b?": False}, {"b?": 0.1})
|
||||
ex5 = ("F[0, b?](A > 0)", ("b?",), {"b?": (0.1, 5)},
|
||||
{"b?": True}, {"b?": 0.1})
|
||||
ex6 = ("(A > a?) or (A > b?)", ("a?", "b?",), {"a?": (0, 2), "b?": (0, 2)},
|
||||
{"a?": False, "b?": False}, {"a?": 2, "b?": 1})
|
||||
|
||||
x = {
|
||||
"A": traces.TimeSeries([(0, 1), (0.1, 1), (0.2, 4)]),
|
||||
"B": traces.TimeSeries([(0, 2), (0.1, 4), (0.2, 2)]),
|
||||
"C": traces.TimeSeries([(0, True), (0.1, True), (0.2, False)]),
|
||||
}
|
||||
|
||||
"""
|
||||
class TestSTLRobustness(unittest.TestCase):
|
||||
@params(ex1, ex2, ex3, ex4, ex5, ex6)
|
||||
def test_lex_synth(self, phi_str, order, ranges, polarity, val):
|
||||
phi = stl.parse(phi_str)
|
||||
val2 = stl.synth.lex_param_project(
|
||||
phi, x, order=order, ranges=ranges, polarity=polarity)
|
||||
|
||||
# check that the valuations are almost the same
|
||||
for var in order:
|
||||
self.assertAlmostEqual(val2[var], val[var], delta=0.01)
|
||||
"""
|
||||
|
|
@ -1,94 +0,0 @@
|
|||
import stl
|
||||
import stl.utils
|
||||
import pandas as pd
|
||||
import unittest
|
||||
from sympy import Symbol
|
||||
|
||||
ex1 = ("F[b?, 1]G[0, c?](x > a?)", {"a?", "b?", "c?"})
|
||||
ex2 = ("G[0, c?](x > a?)", {"a?", "c?"})
|
||||
ex3 = ("F[b?, 1]G[0, c?](x > a?)", {"a?", "b?", "c?"})
|
||||
ex4 = ("F[b?, 1]G[0, c?](x > a?)", "F[2, 1]G[0, 3](x > 1)")
|
||||
ex5 = ("G[0, c?](x > a?)", "G[0, 3](x > 1)")
|
||||
|
||||
val = {"a?": 1.0, "b?": 2.0, "c?": 3.0}
|
||||
"""
|
||||
class TestSTLUtils(unittest.TestCase):
|
||||
@params(ex1, ex2, ex3)
|
||||
def test_param_lens(self, phi_str, params):
|
||||
phi = stl.parse(phi_str)
|
||||
self.assertEqual(set(map(str, stl.utils.param_lens(phi).get_all())), params)
|
||||
|
||||
@params(ex4, ex5)
|
||||
def test_set_params(self, phi_str, phi2_str):
|
||||
phi = stl.parse(phi_str)
|
||||
phi2 = stl.parse(phi2_str)
|
||||
phi = stl.utils.set_params(phi, val)
|
||||
|
||||
self.assertEqual(set(map(str, stl.utils.param_lens(phi).get_all())), set())
|
||||
self.assertEqual(phi, phi2)
|
||||
|
||||
@params(("x > 5", 1), ("~(x)", 2), ("(F[0,1](x)) & (~(G[0, 2](y)))", 6))
|
||||
def test_walk(self, phi_str, l):
|
||||
self.assertEqual(l, len(list(stl.walk(stl.parse(phi_str)))))
|
||||
|
||||
@params(([], False, False),([int], True, False), ([int, bool], True, True))
|
||||
def test_type_pred(self, types, b1, b2):
|
||||
pred = stl.utils.type_pred(*types)
|
||||
self.assertFalse(pred(None))
|
||||
self.assertEqual(pred(1), b1)
|
||||
self.assertEqual(pred(True), b2)
|
||||
|
||||
@params(("(F[0,1]G[0, 4]((x > 3) or (y < 4))) and (x < 3)", 2))
|
||||
def test_vars_in_phi(self, phi_str, l):
|
||||
phi = stl.parse(phi_str)
|
||||
self.assertEqual(len(stl.utils.vars_in_phi(phi)), l)
|
||||
|
||||
@params(("(F[0,1]G[0, 4]((x > 3) or (y < 4))) and (x < 3)", 3))
|
||||
def test_terms_lens(self, phi_str, l):
|
||||
phi = stl.parse(phi_str)
|
||||
l2 = len(stl.terms_lens(phi).get_all())
|
||||
self.assertEqual(l, l2)
|
||||
|
||||
|
||||
@params(("(F[0,1]G[0, 4]((x > 3) | (y < 4))) & (x < 3)", 7, 12))
|
||||
def test_f_neg_or_canonical_form(self, phi_str, pre_l, post_l):
|
||||
phi = stl.parse(phi_str)
|
||||
pre_l2 = len(list(stl.walk(phi)))
|
||||
self.assertEqual(pre_l, pre_l2)
|
||||
post_l2 = len(list(stl.walk(stl.utils.f_neg_or_canonical_form(phi))))
|
||||
self.assertEqual(post_l, post_l2)
|
||||
|
||||
def test_andf(self):
|
||||
phi = stl.parse("x")
|
||||
self.assertEqual(phi, stl.andf(phi))
|
||||
|
||||
def test_orf(self):
|
||||
phi = stl.parse("x")
|
||||
self.assertEqual(phi, stl.orf(phi))
|
||||
|
||||
def test_inline_context(self):
|
||||
context = {
|
||||
stl.parse("x"): stl.parse("(z) & (y)"),
|
||||
stl.parse("z"): stl.parse("y - x > 4")
|
||||
}
|
||||
context2 = {
|
||||
stl.parse("x"): stl.parse("x"),
|
||||
}
|
||||
phi = stl.parse("x")
|
||||
self.assertEqual(stl.utils.inline_context(phi, {}), phi)
|
||||
self.assertEqual(stl.utils.inline_context(phi, context),
|
||||
stl.parse("(y - x > 4) & (y)"))
|
||||
|
||||
phi2 = stl.parse("((x) & (z)) | (y)")
|
||||
self.assertEqual(stl.utils.inline_context(phi2, context),
|
||||
stl.parse("((y - x > 4) & (y) & (y - x > 4)) | (y)"))
|
||||
|
||||
# def test_to_from_mtl(self):
|
||||
# raise NotImplementedError
|
||||
|
||||
# def test_get_polarity(self):
|
||||
# raise NotImplementedError
|
||||
|
||||
# def test_canonical_polarity(self):
|
||||
# raise NotImplementedError
|
||||
"""
|
||||
18
stl/utils.py
18
stl/utils.py
|
|
@ -1,18 +1,18 @@
|
|||
from typing import List, Type, Dict, Mapping, T, TypeVar
|
||||
from collections import deque
|
||||
import operator as op
|
||||
from collections import deque
|
||||
from functools import reduce
|
||||
from typing import Dict, List, Mapping, T, Type, TypeVar
|
||||
|
||||
import lenses
|
||||
from lenses import lens
|
||||
import funcy as fn
|
||||
import sympy
|
||||
import traces
|
||||
|
||||
import lenses
|
||||
import stl.ast
|
||||
from stl.ast import (LinEq, And, Or, NaryOpSTL, F, G, Interval, Neg,
|
||||
AtomicPred, Param, AST)
|
||||
from stl.types import STL, STL_Generator, MTL
|
||||
from lenses import lens
|
||||
from stl.ast import (AST, And, AtomicPred, F, G, Interval, LinEq, NaryOpSTL,
|
||||
Neg, Or, Param)
|
||||
from stl.types import MTL, STL, STL_Generator
|
||||
|
||||
Lens = TypeVar('Lens')
|
||||
|
||||
|
|
@ -86,9 +86,7 @@ def param_lens(phi: STL) -> Lens:
|
|||
]
|
||||
return (x for x in candidates if isinstance(x.get()(leaf), Param))
|
||||
|
||||
return ast_lens(
|
||||
phi, pred=type_pred(LinEq, F, G),
|
||||
focus_lens=focus_lens)
|
||||
return ast_lens(phi, pred=type_pred(LinEq, F, G), focus_lens=focus_lens)
|
||||
|
||||
|
||||
def set_params(phi, val) -> STL:
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue