reimplement + test implicit validity domain function
This commit is contained in:
parent
fd555661a0
commit
d8bdab4e6a
5 changed files with 45 additions and 22 deletions
25
stl/utils.py
25
stl/utils.py
|
|
@ -42,12 +42,15 @@ def f_neg_or_canonical_form(phi: STL) -> STL:
|
|||
|
||||
|
||||
def _lineq_lipschitz(lineq):
|
||||
return sum(map(abs, bind(lineq).Each().terms.Each().coeff.collect()))
|
||||
return sum(map(abs, bind(lineq).terms.Each().coeff.collect()))
|
||||
|
||||
|
||||
def linear_stl_lipschitz(phi):
|
||||
"""Infinity norm lipschitz bound of linear inequality predicate."""
|
||||
return float(max(map(_lineq_lipschitz, phi.lineqs)))
|
||||
if any(isinstance(psi, (AtomicPred, _Top, _Bot)) for psi in phi.walk()):
|
||||
return float('inf')
|
||||
|
||||
return float(max(map(_lineq_lipschitz, phi.lineqs), default=float('inf')))
|
||||
|
||||
|
||||
op_lookup = {
|
||||
|
|
@ -59,11 +62,6 @@ op_lookup = {
|
|||
}
|
||||
|
||||
|
||||
def get_times(x):
|
||||
times = set.union(*({t for t, _ in v.items()} for v in x.values()))
|
||||
return sorted(times)
|
||||
|
||||
|
||||
def const_trace(x, start=0):
|
||||
return traces.TimeSeries([(start, x)], domain=traces.Domain(start, oo))
|
||||
|
||||
|
|
@ -87,6 +85,19 @@ def eval_lineqs(phi, x):
|
|||
return {lineq: eval_lineq(lineq, x, domain) for lineq in lineqs}
|
||||
|
||||
|
||||
def implicit_validity_domain(phi, trace):
|
||||
params = {ap.name for ap in phi.params}
|
||||
order = tuple(params)
|
||||
|
||||
def vec_to_dict(theta):
|
||||
return {k: v for k, v in zip(order, theta)}
|
||||
|
||||
def oracle(theta):
|
||||
return stl.pointwise_sat(phi.set_params(vec_to_dict(theta)))(trace, 0)
|
||||
|
||||
return oracle, order
|
||||
|
||||
|
||||
# EDSL
|
||||
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue