diff --git a/stl/__init__.py b/stl/__init__.py index 61bd2ca..b8ef2fb 100644 --- a/stl/__init__.py +++ b/stl/__init__.py @@ -1,8 +1,8 @@ # flake8: noqa -from stl.utils import alw, env, andf, orf from stl.ast import TOP, BOT from stl.ast import (LinEq, Interval, NaryOpSTL, Or, And, F, G, ModalOp, Neg, Var, AtomicPred, Until) from stl.parser import parse from stl.fastboolean_eval import pointwise_sat from stl.types import STL +from stl.utils import alw, env, andf, orf diff --git a/stl/featurize.py b/stl/featurize.py deleted file mode 100644 index dd8e152..0000000 --- a/stl/featurize.py +++ /dev/null @@ -1,14 +0,0 @@ -from stl import pointwise_sat - - -def ordered_evaluator(phi): - params = {ap.name for ap in phi.params} - order = tuple(params) - - def vec_to_dict(theta): - return {k: v for k, v in zip(order, theta)} - - def eval_phi(theta, x): - return pointwise_sat(phi.set_params(vec_to_dict(theta)))(x, 0) - - return eval_phi, order diff --git a/stl/test_boolean_eval.py b/stl/test_boolean_eval.py index c9afc7b..3261924 100644 --- a/stl/test_boolean_eval.py +++ b/stl/test_boolean_eval.py @@ -112,3 +112,11 @@ def test_fastboolean_equiv(phi): stl_eval3 = stl.fastboolean_eval.pointwise_sat(~stl.alw(~phi, lo=0, hi=4)) stl_eval4 = stl.fastboolean_eval.pointwise_sat(stl.env(phi, lo=0, hi=4)) assert stl_eval4(x, 0) == stl_eval3(x, 0) + + +def test_implicit_validity_domain_rigid(): + phi = stl.parse('G[0, a?](x > b?)') + vals = {'a?': 3, 'b?': 20} + stl_eval = stl.pointwise_sat(phi.set_params(vals)) + oracle, order = stl.utils.implicit_validity_domain(phi, x) + assert stl_eval(x, 0) == oracle([vals.get(k) for k in order]) diff --git a/stl/test_utils.py b/stl/test_utils.py index fe09a5a..845b1b7 100644 --- a/stl/test_utils.py +++ b/stl/test_utils.py @@ -43,3 +43,21 @@ def test_inline_context_rigid(): def test_inline_context(phi): phi2 = phi.inline_context(CONTEXT) assert not (APS & phi2.atomic_predicates) + + +def test_linear_stl_lipschitz_rigid(): + phi = stl.parse('(x + 3y - 4z < 3)') + assert stl.utils.linear_stl_lipschitz(phi) == (8) + + +@given(SignalTemporalLogicStrategy, SignalTemporalLogicStrategy) +def test_linear_stl_lipschitz(phi1, phi2): + lip1 = stl.utils.linear_stl_lipschitz(phi1) + lip2 = stl.utils.linear_stl_lipschitz(phi2) + phi3 = phi1 | phi2 + assert stl.utils.linear_stl_lipschitz(phi3) == max(lip1, lip2) + + +@given(SignalTemporalLogicStrategy, SignalTemporalLogicStrategy) +def test_timed_until_smoke_test(phi1, phi2): + stl.utils.timed_until(phi1, phi2, lo=2, hi=20) diff --git a/stl/utils.py b/stl/utils.py index c46e3f1..e65598f 100644 --- a/stl/utils.py +++ b/stl/utils.py @@ -42,12 +42,15 @@ def f_neg_or_canonical_form(phi: STL) -> STL: def _lineq_lipschitz(lineq): - return sum(map(abs, bind(lineq).Each().terms.Each().coeff.collect())) + return sum(map(abs, bind(lineq).terms.Each().coeff.collect())) def linear_stl_lipschitz(phi): """Infinity norm lipschitz bound of linear inequality predicate.""" - return float(max(map(_lineq_lipschitz, phi.lineqs))) + if any(isinstance(psi, (AtomicPred, _Top, _Bot)) for psi in phi.walk()): + return float('inf') + + return float(max(map(_lineq_lipschitz, phi.lineqs), default=float('inf'))) op_lookup = { @@ -59,11 +62,6 @@ op_lookup = { } -def get_times(x): - times = set.union(*({t for t, _ in v.items()} for v in x.values())) - return sorted(times) - - def const_trace(x, start=0): return traces.TimeSeries([(start, x)], domain=traces.Domain(start, oo)) @@ -87,6 +85,19 @@ def eval_lineqs(phi, x): return {lineq: eval_lineq(lineq, x, domain) for lineq in lineqs} +def implicit_validity_domain(phi, trace): + params = {ap.name for ap in phi.params} + order = tuple(params) + + def vec_to_dict(theta): + return {k: v for k, v in zip(order, theta)} + + def oracle(theta): + return stl.pointwise_sat(phi.set_params(vec_to_dict(theta)))(trace, 0) + + return oracle, order + + # EDSL