reimplement + test implicit validity domain function

This commit is contained in:
Marcell Vazquez-Chanlatte 2017-11-14 10:50:25 -08:00
parent fd555661a0
commit d8bdab4e6a
5 changed files with 45 additions and 22 deletions

View file

@ -1,8 +1,8 @@
# flake8: noqa # flake8: noqa
from stl.utils import alw, env, andf, orf
from stl.ast import TOP, BOT from stl.ast import TOP, BOT
from stl.ast import (LinEq, Interval, NaryOpSTL, Or, And, F, G, ModalOp, Neg, from stl.ast import (LinEq, Interval, NaryOpSTL, Or, And, F, G, ModalOp, Neg,
Var, AtomicPred, Until) Var, AtomicPred, Until)
from stl.parser import parse from stl.parser import parse
from stl.fastboolean_eval import pointwise_sat from stl.fastboolean_eval import pointwise_sat
from stl.types import STL from stl.types import STL
from stl.utils import alw, env, andf, orf

View file

@ -1,14 +0,0 @@
from stl import pointwise_sat
def ordered_evaluator(phi):
params = {ap.name for ap in phi.params}
order = tuple(params)
def vec_to_dict(theta):
return {k: v for k, v in zip(order, theta)}
def eval_phi(theta, x):
return pointwise_sat(phi.set_params(vec_to_dict(theta)))(x, 0)
return eval_phi, order

View file

@ -112,3 +112,11 @@ def test_fastboolean_equiv(phi):
stl_eval3 = stl.fastboolean_eval.pointwise_sat(~stl.alw(~phi, lo=0, hi=4)) stl_eval3 = stl.fastboolean_eval.pointwise_sat(~stl.alw(~phi, lo=0, hi=4))
stl_eval4 = stl.fastboolean_eval.pointwise_sat(stl.env(phi, lo=0, hi=4)) stl_eval4 = stl.fastboolean_eval.pointwise_sat(stl.env(phi, lo=0, hi=4))
assert stl_eval4(x, 0) == stl_eval3(x, 0) assert stl_eval4(x, 0) == stl_eval3(x, 0)
def test_implicit_validity_domain_rigid():
phi = stl.parse('G[0, a?](x > b?)')
vals = {'a?': 3, 'b?': 20}
stl_eval = stl.pointwise_sat(phi.set_params(vals))
oracle, order = stl.utils.implicit_validity_domain(phi, x)
assert stl_eval(x, 0) == oracle([vals.get(k) for k in order])

View file

@ -43,3 +43,21 @@ def test_inline_context_rigid():
def test_inline_context(phi): def test_inline_context(phi):
phi2 = phi.inline_context(CONTEXT) phi2 = phi.inline_context(CONTEXT)
assert not (APS & phi2.atomic_predicates) assert not (APS & phi2.atomic_predicates)
def test_linear_stl_lipschitz_rigid():
phi = stl.parse('(x + 3y - 4z < 3)')
assert stl.utils.linear_stl_lipschitz(phi) == (8)
@given(SignalTemporalLogicStrategy, SignalTemporalLogicStrategy)
def test_linear_stl_lipschitz(phi1, phi2):
lip1 = stl.utils.linear_stl_lipschitz(phi1)
lip2 = stl.utils.linear_stl_lipschitz(phi2)
phi3 = phi1 | phi2
assert stl.utils.linear_stl_lipschitz(phi3) == max(lip1, lip2)
@given(SignalTemporalLogicStrategy, SignalTemporalLogicStrategy)
def test_timed_until_smoke_test(phi1, phi2):
stl.utils.timed_until(phi1, phi2, lo=2, hi=20)

View file

@ -42,12 +42,15 @@ def f_neg_or_canonical_form(phi: STL) -> STL:
def _lineq_lipschitz(lineq): def _lineq_lipschitz(lineq):
return sum(map(abs, bind(lineq).Each().terms.Each().coeff.collect())) return sum(map(abs, bind(lineq).terms.Each().coeff.collect()))
def linear_stl_lipschitz(phi): def linear_stl_lipschitz(phi):
"""Infinity norm lipschitz bound of linear inequality predicate.""" """Infinity norm lipschitz bound of linear inequality predicate."""
return float(max(map(_lineq_lipschitz, phi.lineqs))) if any(isinstance(psi, (AtomicPred, _Top, _Bot)) for psi in phi.walk()):
return float('inf')
return float(max(map(_lineq_lipschitz, phi.lineqs), default=float('inf')))
op_lookup = { op_lookup = {
@ -59,11 +62,6 @@ op_lookup = {
} }
def get_times(x):
times = set.union(*({t for t, _ in v.items()} for v in x.values()))
return sorted(times)
def const_trace(x, start=0): def const_trace(x, start=0):
return traces.TimeSeries([(start, x)], domain=traces.Domain(start, oo)) return traces.TimeSeries([(start, x)], domain=traces.Domain(start, oo))
@ -87,6 +85,19 @@ def eval_lineqs(phi, x):
return {lineq: eval_lineq(lineq, x, domain) for lineq in lineqs} return {lineq: eval_lineq(lineq, x, domain) for lineq in lineqs}
def implicit_validity_domain(phi, trace):
params = {ap.name for ap in phi.params}
order = tuple(params)
def vec_to_dict(theta):
return {k: v for k, v in zip(order, theta)}
def oracle(theta):
return stl.pointwise_sat(phi.set_params(vec_to_dict(theta)))(trace, 0)
return oracle, order
# EDSL # EDSL