diff --git a/stl/smooth_robustness.py b/stl/smooth_robustness.py index 6fcf669..4082489 100644 --- a/stl/smooth_robustness.py +++ b/stl/smooth_robustness.py @@ -13,7 +13,7 @@ from stl.ast import t_sym from stl.utils import walk from stl.robustness import op_lookup -Param = namedtuple("Param", ["L", "h", "B", "id_map"]) +Param = namedtuple("Param", ["L", "h", "B", "id_map", "eps"]) @singledispatch def node_base(_, _1, _2): @@ -36,17 +36,37 @@ def sample_rate(eps, L): def admissible_params(phi, eps, L): + h = sample_rate(eps, L), + B = max(node_base(n, eps, L) for n in walk(phi)), + return B, h + +def symbolic_params(phi, eps, L): return Param( - L=L, - h=sample_rate(eps, L), - B=max(node_base(n, eps, L) for n in walk(phi)), - id_map={n:i for i, n in enumerate(walk(phi))} + L=sym.Symbol("L"), + h=sym.Symbol("h"), + B="B", + id_map={n:i for i, n in enumerate(walk(phi))}, + eps=sym.symbol("eps") ) -def smooth_robustness(phi, eps, L): - p = admissible_params(phi, eps, L) +def smooth_robustness(phi, *, L=None, eps=None): + # TODO: Return symbollic formula if flag + p = symbolic_params(phi, eps, L) lo, hi = beta(phi, p), alpha(phi, p) + subs = {} + if L is not None: + subs[p.L] = L + if eps is not None: + subs[p.eps] = eps + if L is not None and eps is not None: + B, h = admissible_params(phi, eps, L) + subs[p.B] = B + subs[p.h] = h + lo, hi = lo.subs(subs), hi.subs(subs) + else: + B = p.B + return sym.log(lo, B), sym.log(hi, B)