From a1ca4c6579b6b4f6bfb57b561e38406d0339c502 Mon Sep 17 00:00:00 2001 From: Marcell Vazquez-Chanlatte Date: Thu, 26 Oct 2017 22:00:03 -0700 Subject: [PATCH] yapf + pylint + add style checks to tests --- .gitignore | 4 +- .travis.yml | 2 +- requirements.txt | 15 ++++++- setup.py | 35 +++++++-------- stl/__init__.py | 6 +-- stl/ast.py | 11 +++-- stl/boolean_eval.py | 37 +++++++++------- stl/fastboolean_eval.py | 62 ++++++++++---------------- stl/hypothesis.py | 26 +++++------ stl/load.py | 3 +- stl/parser.py | 11 ++--- stl/synth.py | 38 ---------------- stl/test_ast.py | 4 +- stl/test_boolean_eval.py | 19 ++++---- stl/test_parser.py | 4 +- stl/test_synth.py | 38 ---------------- stl/test_utils.py | 94 ---------------------------------------- stl/utils.py | 20 ++++----- 18 files changed, 130 insertions(+), 299 deletions(-) delete mode 100644 stl/synth.py delete mode 100644 stl/test_synth.py delete mode 100644 stl/test_utils.py diff --git a/.gitignore b/.gitignore index 4e904c1..3382d0b 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,5 @@ dist/* .cache/* -.hypothesis/* \ No newline at end of file +.hypothesis/* +.coverage* +htmlcov/* \ No newline at end of file diff --git a/.travis.yml b/.travis.yml index 078714c..cb48ec0 100644 --- a/.travis.yml +++ b/.travis.yml @@ -2,4 +2,4 @@ python: - "3.6" install: "pip install -r requirements.txt" - script: pytest \ No newline at end of file + script: pytest --flake8 --cov=stl -x --cov-report=html --isort diff --git a/requirements.txt b/requirements.txt index dbc7c12..39e06ee 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,5 @@ -e git://github.com/mvcisback/hypothesis-cfg@master#egg=hypothesis-cfg +-e git://github.com/mvcisback/multidim-threshold@master#egg=multidim-threshold bitarray==0.8.1 funcy==1.9.1 lenses==0.3.0 @@ -7,4 +8,16 @@ parsimonious==0.7.0 sympy==1.0 traces==0.3.1 pytest==3.2.2 -hypothesis==3.30.3 +hypothesis==3.32.1 + +pytest==3.2.3 +pytest-bpdb==0.1.4 +pytest-cache==1.0 +pytest-colordots==1.1 +pytest-cov==2.5.1 +pytest-flake8==0.9 +pytest-forked==0.2 +pytest-isort==0.1.0 +pytest-pep8==1.0.6 +pytest-xdist==1.20.1 +python-dateutil==2.6.1 diff --git a/setup.py b/setup.py index 8f95cc6..cf300b1 100644 --- a/setup.py +++ b/setup.py @@ -1,19 +1,20 @@ -from setuptools import setup, find_packages +from setuptools import find_packages, setup -setup(name='py-stl', - version='0.2', - description='TODO', - url='http://github.com/mvcisback/py-stl', - author='Marcell Vazquez-Chanlatte', - author_email='marcell.vc@eecs.berkeley.edu', - license='MIT', - install_requires=[ - 'funcy', - 'parsimonious', - 'lenses', - 'sympy', - 'bitarray', - 'traces', - ], - packages=find_packages(), +setup( + name='py-stl', + version='0.2', + description='TODO', + url='http://github.com/mvcisback/py-stl', + author='Marcell Vazquez-Chanlatte', + author_email='marcell.vc@eecs.berkeley.edu', + license='MIT', + install_requires=[ + 'funcy', + 'parsimonious', + 'lenses', + 'sympy', + 'bitarray', + 'traces', + ], + packages=find_packages(), ) diff --git a/stl/__init__.py b/stl/__init__.py index 86c7db8..12f9229 100644 --- a/stl/__init__.py +++ b/stl/__init__.py @@ -1,9 +1,9 @@ +# flake8: noqa from stl.utils import terms_lens, lineq_lens, walk, and_or_lens from stl.utils import alw, env, andf, orf from stl.ast import dt_sym, t_sym, TOP, BOT -from stl.ast import (LinEq, Interval, NaryOpSTL, Or, And, F, G, - ModalOp, Neg, Var, AtomicPred, Until) +from stl.ast import (LinEq, Interval, NaryOpSTL, Or, And, F, G, ModalOp, Neg, + Var, AtomicPred, Until) from stl.parser import parse from stl.fastboolean_eval import pointwise_sat -from stl.synth import lex_param_project from stl.types import STL diff --git a/stl/ast.py b/stl/ast.py index 70d25d0..68da66c 100644 --- a/stl/ast.py +++ b/stl/ast.py @@ -1,9 +1,7 @@ # -*- coding: utf-8 -*- # TODO: supress + given a + (-b). i.e. want a - b -from collections import namedtuple, deque -from itertools import repeat -from enum import Enum +from collections import namedtuple import funcy as fn from sympy import Symbol @@ -13,7 +11,9 @@ t_sym = Symbol('t', positive=True) def flatten_binary(phi, op, dropT, shortT): - f = lambda x: x.args if isinstance(x, op) else [x] + def f(x): + return x.args if isinstance(x, op) else [x] + args = [arg for arg in phi.args if arg is not dropT] if any(arg is shortT for arg in args): @@ -153,7 +153,7 @@ class And(NaryOpSTL): class ModalOp(namedtuple('ModalOp', ['interval', 'arg']), AST): __slots__ = () OP = '?' - + def __repr__(self): return f"{self.OP}{self.interval}({self.arg})" @@ -231,7 +231,6 @@ class Param(namedtuple('Param', ['name']), AST): def __repr__(self): return self.name - def __hash__(self): # TODO: compute hash based on contents return hash(repr(self)) diff --git a/stl/boolean_eval.py b/stl/boolean_eval.py index 9a3d207..7dbbd4f 100644 --- a/stl/boolean_eval.py +++ b/stl/boolean_eval.py @@ -1,38 +1,42 @@ # TODO: figure out how to deduplicate this with robustness # - Abstract as working on distributive lattice -from functools import singledispatch import operator as op +from functools import singledispatch import funcy as fn -from lenses import lens -import stl.ast import stl +import stl.ast +from lenses import lens oo = float('inf') + def pointwise_sat(phi): ap_names = [z.id.name for z in stl.utils.AP_lens(phi).Each().collect()] + def _eval_stl(x, t): evaluated = stl.utils.eval_lineqs(phi, x) evaluated.update(fn.project(x, ap_names)) return eval_stl(phi)(evaluated, t) + return _eval_stl + @singledispatch def eval_stl(stl): raise NotImplementedError @eval_stl.register(stl.Or) -def _(phi): +def eval_stl_or(phi): fs = [eval_stl(arg) for arg in phi.args] - return lambda x, t: any(f(x,t) for f in fs) + return lambda x, t: any(f(x, t) for f in fs) @eval_stl.register(stl.And) -def _(stl): +def eval_stl_and(stl): fs = [eval_stl(arg) for arg in stl.args] return lambda x, t: all(f(x, t) for f in fs) @@ -57,13 +61,14 @@ def get_times(x, tau, lo=None, hi=None): @eval_stl.register(stl.Until) -def _(stl): +def eval_stl_until(stl): def _until(x, t): f1, f2 = eval_stl(stl.arg1), eval_stl(stl.arg2) for tau in get_times(x, t): if not f1(x, tau): return f2(x, tau) return False + return _until @@ -73,24 +78,24 @@ def eval_unary_temporal_op(phi, always=True): if lo > hi: retval = True if always else False return lambda x, t: retval - f = eval_stl(phi.arg) + f = eval_stl(phi.arg) if hi == lo: return lambda x, t: f(x, t) return lambda x, t: fold(f(x, tau) for tau in get_times(x, t, lo, hi)) @eval_stl.register(stl.F) -def _(phi): +def eval_stl_f(phi): return eval_unary_temporal_op(phi, always=False) @eval_stl.register(stl.G) -def _(phi): +def eval_stl_g(phi): return eval_unary_temporal_op(phi, always=True) @eval_stl.register(stl.Neg) -def _(stl): +def eval_stl_neg(stl): f = eval_stl(stl.arg) return lambda x, t: not f(x, t) @@ -105,22 +110,22 @@ op_lookup = { @eval_stl.register(stl.AtomicPred) -def _(stl): +def eval_stl_ap(stl): return lambda x, t: x[str(stl.id)][t] @eval_stl.register(type(stl.TOP)) -def _(_): +def eval_stl_top(_): return lambda *_: True @eval_stl.register(type(stl.BOT)) -def _(_): +def eval_stl_bot(_): return lambda *_: False @eval_stl.register(stl.LinEq) -def _(lineq): +def eval_stl_lineq(lineq): return lambda x, t: x[lineq][t] @@ -130,4 +135,4 @@ def eval_terms(lineq, x, t): def eval_term(term, x, t): - return float(term.coeff)*x[term.id.name][t] + return float(term.coeff) * x[term.id][t] diff --git a/stl/fastboolean_eval.py b/stl/fastboolean_eval.py index 09ee957..b9de5d9 100644 --- a/stl/fastboolean_eval.py +++ b/stl/fastboolean_eval.py @@ -1,19 +1,22 @@ -from functools import singledispatch, reduce +from functools import reduce, singledispatch from operator import and_, or_ from bitarray import bitarray import stl.ast -from stl.boolean_eval import eval_terms, op_lookup, get_times +from stl.boolean_eval import eval_terms, get_times, op_lookup + def pointwise_sat(stl): f = pointwise_satf(stl) return lambda x, t: bool(int(f(x, [t]).to01())) + @singledispatch def pointwise_satf(stl): raise NotImplementedError + def bool_op(stl, conjunction=False): binop = and_ if conjunction else or_ fs = [pointwise_satf(arg) for arg in stl.args] @@ -21,78 +24,59 @@ def bool_op(stl, conjunction=False): @pointwise_satf.register(stl.Or) -def _(stl): +def pointwise_satf_or(stl): return bool_op(stl, conjunction=False) @pointwise_satf.register(stl.And) -def _(stl): +def pointwise_satf_and(stl): return bool_op(stl, conjunction=True) def temporal_op(stl, lo, hi, conjunction=False): fold = bitarray.all if conjunction else bitarray.any f = pointwise_satf(stl.arg) - def sat_comp(x,t): + + def sat_comp(x, t): return bitarray(fold(f(x, get_times(x, tau, lo, hi))) for tau in t) + return sat_comp -@pointwise_satf.register(stl.Until) -def _(stl): - f1, f2 = pointwise_satf(stl.arg1), pointwise_satf(stl.arg2) - def __until(x, t): - f1, f2 = pointwise_satf(stl.arg1), pointwise_satf(stl.arg2) - - state = False - times = get_times(x, t[0]) - for phi, tau in zip(reversed(f1(x, times)), reversed(times)): - if not phi: - state = f2(x, [tau]) - - if tau in t: - yield state - - def _until(x, t): - retval = bitarray(__until(x, t)) - retval.reverse() - return retval - - return _until - - @pointwise_satf.register(stl.F) -def _(stl): +def pointwise_satf_f(stl): lo, hi = stl.interval return temporal_op(stl, lo, hi, conjunction=False) @pointwise_satf.register(stl.G) -def _(stl): +def pointwise_satf_g(stl): lo, hi = stl.interval return temporal_op(stl, lo, hi, conjunction=True) @pointwise_satf.register(stl.Neg) -def _(stl): - return lambda x,t: ~pointwise_satf(stl.arg)(x, t) +def pointwise_satf_neg(stl): + return lambda x, t: ~pointwise_satf(stl.arg)(x, t) @pointwise_satf.register(stl.AtomicPred) -def _(phi): +def pointwise_satf_(phi): return lambda x, t: bitarray(x[str(phi.id)][tau] for tau in t) + @pointwise_satf.register(type(stl.TOP)) -def _(_): - return lambda _, t: bitarray([True]*len(t)) +def pointwise_satf_top(_): + return lambda _, t: bitarray([True] * len(t)) @pointwise_satf.register(type(stl.BOT)) -def _(_): - return lambda _, t: bitarray([False]*len(t)) +def pointwise_satf_bot(_): + return lambda _, t: bitarray([False] * len(t)) @pointwise_satf.register(stl.LinEq) -def _(stl): - op = lambda a: op_lookup[stl.op](a, stl.const) +def pointwise_satf_lineq(stl): + def op(a): + return op_lookup[stl.op](a, stl.const) return lambda x, t: bitarray(op(eval_terms(stl, x, tau)) for tau in t) diff --git a/stl/hypothesis.py b/stl/hypothesis.py index 02b75eb..77b4cd8 100644 --- a/stl/hypothesis.py +++ b/stl/hypothesis.py @@ -1,23 +1,23 @@ -from hypothesis_cfg import ContextFreeGrammarStrategy - import hypothesis.strategies as st -from hypothesis.searchstrategy.strategies import SearchStrategy +from hypothesis_cfg import ContextFreeGrammarStrategy import stl - GRAMMAR = { - 'phi': (('Unary', '(', 'phi', ')'), - ('(', 'phi', ')', 'Binary', '(', 'phi', ')'), - ('AP',)), - 'Unary': (('~',), ('G',), ('F',), ('X',)), - 'Binary': ((' | ',), (' & ',), (' U ',)), - 'AP': (('AP1',), ('AP2',), ('AP3',), ('AP4',), ('AP5',)), + 'phi': (('Unary', '(', 'phi', ')'), ('(', 'phi', ')', 'Binary', '(', 'phi', + ')'), ('AP', )), + 'Unary': (('~', ), ('G', ), ('F', ), ('X', )), + 'Binary': ((' | ', ), (' & ', ), (' U ', )), + 'AP': (('AP1', ), ('AP2', ), ('AP3', ), ('AP4', ), ('AP5', )), } + def to_stl(term): return stl.parse(''.join(term)) -SignalTemporalLogicStrategy = st.builds( - to_stl, ContextFreeGrammarStrategy( - GRAMMAR, max_length=25, start='phi')) + +SignalTemporalLogicStrategy = st.builds(to_stl, + ContextFreeGrammarStrategy( + GRAMMAR, + max_length=25, + start='phi')) diff --git a/stl/load.py b/stl/load.py index 5e43a8d..89dc019 100644 --- a/stl/load.py +++ b/stl/load.py @@ -1,4 +1,5 @@ -from traces import TimeSeries, Domain +from traces import Domain, TimeSeries + def from_pandas(df, compact=True): '''TODO''' diff --git a/stl/parser.py b/stl/parser.py index 7a5dada..34f633e 100644 --- a/stl/parser.py +++ b/stl/parser.py @@ -5,15 +5,11 @@ # TODO: Allow -x = -1*x from functools import partialmethod -from collections import namedtuple -import operator as op -from parsimonious import Grammar, NodeVisitor -from funcy import flatten from lenses import lens - +from parsimonious import Grammar, NodeVisitor from stl import ast -from stl.utils import implies, xor, iff, env, alw +from stl.utils import alw, env, iff, implies, xor STL_GRAMMAR = Grammar(u''' phi = (timed_until / until / neg / next / g / f / lineq / AP / or / and @@ -60,9 +56,9 @@ __ = ~r"\s"* EOL = "\\n" ''') - oo = float('inf') + class STLVisitor(NodeVisitor): def __init__(self, H=oo): super().__init__() @@ -137,7 +133,6 @@ class STLVisitor(NodeVisitor): c = coeffs[0] if coeffs else 1 return ast.Var(coeff=c, id=iden) - def visit_terms(self, _, children): if isinstance(children[0], list): term, _1, sgn, _2, terms = children[0] diff --git a/stl/synth.py b/stl/synth.py deleted file mode 100644 index 1637009..0000000 --- a/stl/synth.py +++ /dev/null @@ -1,38 +0,0 @@ -import operator as op - -from stl.utils import set_params, param_lens -from stl.boolean_eval import pointwise_sat - -from lenses import lens - -def binsearch(stleval, *, tol=1e-3, lo, hi, polarity): - """Only run search if tightest robustness was positive.""" - # Early termination via bounds checks - if polarity and stleval(lo): - return lo - elif not polarity and stleval(hi): - return hi - - while hi - lo > tol: - mid = lo + (hi - lo) / 2 - r = stleval(mid) - lo, hi = (mid, hi) if r ^ polarity else (lo, mid) - - # Want satisifiable formula - return hi if polarity else lo - - -def lex_param_project(stl, x, *, order, polarity, ranges, tol=1e-3): - val = {var: (ranges[var][0] if not polarity[var] else ranges[var][1]) for var in order} - p_lens = param_lens(stl) - def stleval_fact(var, val): - l = lens(val)[var] - return lambda p: pointwise_sat(set_params(stl, l.set(p)))(x, 0) - - for var in order: - stleval = stleval_fact(var, val) - lo, hi = ranges[var] - param = binsearch(stleval, lo=lo, hi=hi, tol=tol, polarity=polarity[var]) - val[var] = param - - return val diff --git a/stl/test_ast.py b/stl/test_ast.py index d707387..cb50c6e 100644 --- a/stl/test_ast.py +++ b/stl/test_ast.py @@ -1,6 +1,8 @@ -import stl import unittest +import stl + + class TestSTLAST(unittest.TestCase): def test_and(self): phi = stl.parse("x") diff --git a/stl/test_boolean_eval.py b/stl/test_boolean_eval.py index b3a96c1..a77ff5c 100644 --- a/stl/test_boolean_eval.py +++ b/stl/test_boolean_eval.py @@ -1,13 +1,11 @@ +import hypothesis.strategies as st +import traces + +from hypothesis import given + import stl import stl.boolean_eval import stl.fastboolean_eval -import traces -import unittest -from sympy import Symbol - -import hypothesis.strategies as st -from hypothesis import given, note, assume, example - """ TODO: property based test that fasteval should be the same as slow @@ -28,7 +26,11 @@ x = { "A": traces.TimeSeries([(0, 1), (0.1, 1), (0.2, 4)]), "B": traces.TimeSeries([(0, 2), (0.1, 4), (0.2, 2)]), "C": traces.TimeSeries([(0, True), (0.1, True), (0.2, False)]), - 'D': traces.TimeSeries({0.0: 2, 13.8: 3, 19.7: 2}), + 'D': traces.TimeSeries({ + 0.0: 2, + 13.8: 3, + 19.7: 2 + }), } @@ -56,4 +58,3 @@ def test_temporal_identities(phi): stl_eval3 = stl.fastboolean_eval.pointwise_sat(~stl.alw(~phi, lo=0, hi=4)) stl_eval4 = stl.fastboolean_eval.pointwise_sat(stl.env(phi, lo=0, hi=4)) assert stl_eval4(x, 0) == stl_eval3(x, 0) - diff --git a/stl/test_parser.py b/stl/test_parser.py index 7d6c644..b66c317 100644 --- a/stl/test_parser.py +++ b/stl/test_parser.py @@ -1,7 +1,7 @@ # -*- coding: utf-8 -*- -import stl -from hypothesis import given, event +from hypothesis import event, given +import stl from stl.hypothesis import SignalTemporalLogicStrategy diff --git a/stl/test_synth.py b/stl/test_synth.py deleted file mode 100644 index ec195e6..0000000 --- a/stl/test_synth.py +++ /dev/null @@ -1,38 +0,0 @@ -import stl -import stl.synth -import traces -import unittest -from sympy import Symbol - -oo = float('inf') - -ex1 = ("A > a?", ("a?",), {"a?": (0, 10)}, {"a?": False}, {"a?": 1}) -ex2 = ("F[0, b?](A > a?)", ("a?", "b?"), {"a?": (0, 10), "b?": (0, 5)}, - {"a?": False, "b?": True}, {"a?": 4, "b?": 0.2}) -ex3 = ("F[0, b?](A < 0)", ("b?",), {"b?": (0, 5)}, - {"b?": True}, {"b?": 5}) -ex4 = ("G[0, b?](A < 0)", ("b?",), {"b?": (0.1, 5)}, - {"b?": False}, {"b?": 0.1}) -ex5 = ("F[0, b?](A > 0)", ("b?",), {"b?": (0.1, 5)}, - {"b?": True}, {"b?": 0.1}) -ex6 = ("(A > a?) or (A > b?)", ("a?", "b?",), {"a?": (0, 2), "b?": (0, 2)}, - {"a?": False, "b?": False}, {"a?": 2, "b?": 1}) - -x = { - "A": traces.TimeSeries([(0, 1), (0.1, 1), (0.2, 4)]), - "B": traces.TimeSeries([(0, 2), (0.1, 4), (0.2, 2)]), - "C": traces.TimeSeries([(0, True), (0.1, True), (0.2, False)]), -} - -""" -class TestSTLRobustness(unittest.TestCase): - @params(ex1, ex2, ex3, ex4, ex5, ex6) - def test_lex_synth(self, phi_str, order, ranges, polarity, val): - phi = stl.parse(phi_str) - val2 = stl.synth.lex_param_project( - phi, x, order=order, ranges=ranges, polarity=polarity) - - # check that the valuations are almost the same - for var in order: - self.assertAlmostEqual(val2[var], val[var], delta=0.01) -""" diff --git a/stl/test_utils.py b/stl/test_utils.py deleted file mode 100644 index 7c0c3f8..0000000 --- a/stl/test_utils.py +++ /dev/null @@ -1,94 +0,0 @@ -import stl -import stl.utils -import pandas as pd -import unittest -from sympy import Symbol - -ex1 = ("F[b?, 1]G[0, c?](x > a?)", {"a?", "b?", "c?"}) -ex2 = ("G[0, c?](x > a?)", {"a?", "c?"}) -ex3 = ("F[b?, 1]G[0, c?](x > a?)", {"a?", "b?", "c?"}) -ex4 = ("F[b?, 1]G[0, c?](x > a?)", "F[2, 1]G[0, 3](x > 1)") -ex5 = ("G[0, c?](x > a?)", "G[0, 3](x > 1)") - -val = {"a?": 1.0, "b?": 2.0, "c?": 3.0} -""" -class TestSTLUtils(unittest.TestCase): - @params(ex1, ex2, ex3) - def test_param_lens(self, phi_str, params): - phi = stl.parse(phi_str) - self.assertEqual(set(map(str, stl.utils.param_lens(phi).get_all())), params) - - @params(ex4, ex5) - def test_set_params(self, phi_str, phi2_str): - phi = stl.parse(phi_str) - phi2 = stl.parse(phi2_str) - phi = stl.utils.set_params(phi, val) - - self.assertEqual(set(map(str, stl.utils.param_lens(phi).get_all())), set()) - self.assertEqual(phi, phi2) - - @params(("x > 5", 1), ("~(x)", 2), ("(F[0,1](x)) & (~(G[0, 2](y)))", 6)) - def test_walk(self, phi_str, l): - self.assertEqual(l, len(list(stl.walk(stl.parse(phi_str))))) - - @params(([], False, False),([int], True, False), ([int, bool], True, True)) - def test_type_pred(self, types, b1, b2): - pred = stl.utils.type_pred(*types) - self.assertFalse(pred(None)) - self.assertEqual(pred(1), b1) - self.assertEqual(pred(True), b2) - - @params(("(F[0,1]G[0, 4]((x > 3) or (y < 4))) and (x < 3)", 2)) - def test_vars_in_phi(self, phi_str, l): - phi = stl.parse(phi_str) - self.assertEqual(len(stl.utils.vars_in_phi(phi)), l) - - @params(("(F[0,1]G[0, 4]((x > 3) or (y < 4))) and (x < 3)", 3)) - def test_terms_lens(self, phi_str, l): - phi = stl.parse(phi_str) - l2 = len(stl.terms_lens(phi).get_all()) - self.assertEqual(l, l2) - - - @params(("(F[0,1]G[0, 4]((x > 3) | (y < 4))) & (x < 3)", 7, 12)) - def test_f_neg_or_canonical_form(self, phi_str, pre_l, post_l): - phi = stl.parse(phi_str) - pre_l2 = len(list(stl.walk(phi))) - self.assertEqual(pre_l, pre_l2) - post_l2 = len(list(stl.walk(stl.utils.f_neg_or_canonical_form(phi)))) - self.assertEqual(post_l, post_l2) - - def test_andf(self): - phi = stl.parse("x") - self.assertEqual(phi, stl.andf(phi)) - - def test_orf(self): - phi = stl.parse("x") - self.assertEqual(phi, stl.orf(phi)) - - def test_inline_context(self): - context = { - stl.parse("x"): stl.parse("(z) & (y)"), - stl.parse("z"): stl.parse("y - x > 4") - } - context2 = { - stl.parse("x"): stl.parse("x"), - } - phi = stl.parse("x") - self.assertEqual(stl.utils.inline_context(phi, {}), phi) - self.assertEqual(stl.utils.inline_context(phi, context), - stl.parse("(y - x > 4) & (y)")) - - phi2 = stl.parse("((x) & (z)) | (y)") - self.assertEqual(stl.utils.inline_context(phi2, context), - stl.parse("((y - x > 4) & (y) & (y - x > 4)) | (y)")) - -# def test_to_from_mtl(self): -# raise NotImplementedError - -# def test_get_polarity(self): -# raise NotImplementedError - -# def test_canonical_polarity(self): -# raise NotImplementedError -""" diff --git a/stl/utils.py b/stl/utils.py index 57a67b5..6cd3572 100644 --- a/stl/utils.py +++ b/stl/utils.py @@ -1,18 +1,18 @@ -from typing import List, Type, Dict, Mapping, T, TypeVar -from collections import deque import operator as op +from collections import deque from functools import reduce +from typing import Dict, List, Mapping, T, Type, TypeVar -import lenses -from lenses import lens import funcy as fn import sympy import traces +import lenses import stl.ast -from stl.ast import (LinEq, And, Or, NaryOpSTL, F, G, Interval, Neg, - AtomicPred, Param, AST) -from stl.types import STL, STL_Generator, MTL +from lenses import lens +from stl.ast import (AST, And, AtomicPred, F, G, Interval, LinEq, NaryOpSTL, + Neg, Or, Param) +from stl.types import MTL, STL, STL_Generator Lens = TypeVar('Lens') @@ -37,7 +37,7 @@ def type_pred(*args: List[Type]) -> Mapping[Type, bool]: return lambda x: type(x) in ast_types -def ast_lens(phi: STL, bind=True, *, pred=None, focus_lens=None, +def ast_lens(phi: STL, bind=True, *, pred=None, focus_lens=None, getter=False) -> Lens: if focus_lens is None: focus_lens = lambda _: [lens] @@ -86,9 +86,7 @@ def param_lens(phi: STL) -> Lens: ] return (x for x in candidates if isinstance(x.get()(leaf), Param)) - return ast_lens( - phi, pred=type_pred(LinEq, F, G), - focus_lens=focus_lens) + return ast_lens(phi, pred=type_pred(LinEq, F, G), focus_lens=focus_lens) def set_params(phi, val) -> STL: