reimplement + test implicit validity domain function
This commit is contained in:
parent
fd555661a0
commit
d8bdab4e6a
5 changed files with 45 additions and 22 deletions
|
|
@ -1,8 +1,8 @@
|
||||||
# flake8: noqa
|
# flake8: noqa
|
||||||
from stl.utils import alw, env, andf, orf
|
|
||||||
from stl.ast import TOP, BOT
|
from stl.ast import TOP, BOT
|
||||||
from stl.ast import (LinEq, Interval, NaryOpSTL, Or, And, F, G, ModalOp, Neg,
|
from stl.ast import (LinEq, Interval, NaryOpSTL, Or, And, F, G, ModalOp, Neg,
|
||||||
Var, AtomicPred, Until)
|
Var, AtomicPred, Until)
|
||||||
from stl.parser import parse
|
from stl.parser import parse
|
||||||
from stl.fastboolean_eval import pointwise_sat
|
from stl.fastboolean_eval import pointwise_sat
|
||||||
from stl.types import STL
|
from stl.types import STL
|
||||||
|
from stl.utils import alw, env, andf, orf
|
||||||
|
|
|
||||||
|
|
@ -1,14 +0,0 @@
|
||||||
from stl import pointwise_sat
|
|
||||||
|
|
||||||
|
|
||||||
def ordered_evaluator(phi):
|
|
||||||
params = {ap.name for ap in phi.params}
|
|
||||||
order = tuple(params)
|
|
||||||
|
|
||||||
def vec_to_dict(theta):
|
|
||||||
return {k: v for k, v in zip(order, theta)}
|
|
||||||
|
|
||||||
def eval_phi(theta, x):
|
|
||||||
return pointwise_sat(phi.set_params(vec_to_dict(theta)))(x, 0)
|
|
||||||
|
|
||||||
return eval_phi, order
|
|
||||||
|
|
@ -112,3 +112,11 @@ def test_fastboolean_equiv(phi):
|
||||||
stl_eval3 = stl.fastboolean_eval.pointwise_sat(~stl.alw(~phi, lo=0, hi=4))
|
stl_eval3 = stl.fastboolean_eval.pointwise_sat(~stl.alw(~phi, lo=0, hi=4))
|
||||||
stl_eval4 = stl.fastboolean_eval.pointwise_sat(stl.env(phi, lo=0, hi=4))
|
stl_eval4 = stl.fastboolean_eval.pointwise_sat(stl.env(phi, lo=0, hi=4))
|
||||||
assert stl_eval4(x, 0) == stl_eval3(x, 0)
|
assert stl_eval4(x, 0) == stl_eval3(x, 0)
|
||||||
|
|
||||||
|
|
||||||
|
def test_implicit_validity_domain_rigid():
|
||||||
|
phi = stl.parse('G[0, a?](x > b?)')
|
||||||
|
vals = {'a?': 3, 'b?': 20}
|
||||||
|
stl_eval = stl.pointwise_sat(phi.set_params(vals))
|
||||||
|
oracle, order = stl.utils.implicit_validity_domain(phi, x)
|
||||||
|
assert stl_eval(x, 0) == oracle([vals.get(k) for k in order])
|
||||||
|
|
|
||||||
|
|
@ -43,3 +43,21 @@ def test_inline_context_rigid():
|
||||||
def test_inline_context(phi):
|
def test_inline_context(phi):
|
||||||
phi2 = phi.inline_context(CONTEXT)
|
phi2 = phi.inline_context(CONTEXT)
|
||||||
assert not (APS & phi2.atomic_predicates)
|
assert not (APS & phi2.atomic_predicates)
|
||||||
|
|
||||||
|
|
||||||
|
def test_linear_stl_lipschitz_rigid():
|
||||||
|
phi = stl.parse('(x + 3y - 4z < 3)')
|
||||||
|
assert stl.utils.linear_stl_lipschitz(phi) == (8)
|
||||||
|
|
||||||
|
|
||||||
|
@given(SignalTemporalLogicStrategy, SignalTemporalLogicStrategy)
|
||||||
|
def test_linear_stl_lipschitz(phi1, phi2):
|
||||||
|
lip1 = stl.utils.linear_stl_lipschitz(phi1)
|
||||||
|
lip2 = stl.utils.linear_stl_lipschitz(phi2)
|
||||||
|
phi3 = phi1 | phi2
|
||||||
|
assert stl.utils.linear_stl_lipschitz(phi3) == max(lip1, lip2)
|
||||||
|
|
||||||
|
|
||||||
|
@given(SignalTemporalLogicStrategy, SignalTemporalLogicStrategy)
|
||||||
|
def test_timed_until_smoke_test(phi1, phi2):
|
||||||
|
stl.utils.timed_until(phi1, phi2, lo=2, hi=20)
|
||||||
|
|
|
||||||
25
stl/utils.py
25
stl/utils.py
|
|
@ -42,12 +42,15 @@ def f_neg_or_canonical_form(phi: STL) -> STL:
|
||||||
|
|
||||||
|
|
||||||
def _lineq_lipschitz(lineq):
|
def _lineq_lipschitz(lineq):
|
||||||
return sum(map(abs, bind(lineq).Each().terms.Each().coeff.collect()))
|
return sum(map(abs, bind(lineq).terms.Each().coeff.collect()))
|
||||||
|
|
||||||
|
|
||||||
def linear_stl_lipschitz(phi):
|
def linear_stl_lipschitz(phi):
|
||||||
"""Infinity norm lipschitz bound of linear inequality predicate."""
|
"""Infinity norm lipschitz bound of linear inequality predicate."""
|
||||||
return float(max(map(_lineq_lipschitz, phi.lineqs)))
|
if any(isinstance(psi, (AtomicPred, _Top, _Bot)) for psi in phi.walk()):
|
||||||
|
return float('inf')
|
||||||
|
|
||||||
|
return float(max(map(_lineq_lipschitz, phi.lineqs), default=float('inf')))
|
||||||
|
|
||||||
|
|
||||||
op_lookup = {
|
op_lookup = {
|
||||||
|
|
@ -59,11 +62,6 @@ op_lookup = {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def get_times(x):
|
|
||||||
times = set.union(*({t for t, _ in v.items()} for v in x.values()))
|
|
||||||
return sorted(times)
|
|
||||||
|
|
||||||
|
|
||||||
def const_trace(x, start=0):
|
def const_trace(x, start=0):
|
||||||
return traces.TimeSeries([(start, x)], domain=traces.Domain(start, oo))
|
return traces.TimeSeries([(start, x)], domain=traces.Domain(start, oo))
|
||||||
|
|
||||||
|
|
@ -87,6 +85,19 @@ def eval_lineqs(phi, x):
|
||||||
return {lineq: eval_lineq(lineq, x, domain) for lineq in lineqs}
|
return {lineq: eval_lineq(lineq, x, domain) for lineq in lineqs}
|
||||||
|
|
||||||
|
|
||||||
|
def implicit_validity_domain(phi, trace):
|
||||||
|
params = {ap.name for ap in phi.params}
|
||||||
|
order = tuple(params)
|
||||||
|
|
||||||
|
def vec_to_dict(theta):
|
||||||
|
return {k: v for k, v in zip(order, theta)}
|
||||||
|
|
||||||
|
def oracle(theta):
|
||||||
|
return stl.pointwise_sat(phi.set_params(vec_to_dict(theta)))(trace, 0)
|
||||||
|
|
||||||
|
return oracle, order
|
||||||
|
|
||||||
|
|
||||||
# EDSL
|
# EDSL
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue