diff --git a/stl/robustness.py b/stl/robustness.py index c9475c0..14963ba 100644 --- a/stl/robustness.py +++ b/stl/robustness.py @@ -66,5 +66,7 @@ def eval_terms(lineq, x, t): def eval_term(x, t): - # TODO(lift interpolation much higher) - return lambda term: term.coeff*np.interp(t, x.index, x[term.id.name]) + def _eval_term(term): + coeff = float(term.coeff) if term.coeff.is_number else term.coeff + return coeff*np.interp(t, x.index, x[term.id.name]) + return _eval_term diff --git a/stl/utils.py b/stl/utils.py index 6f169c9..5f7b185 100644 --- a/stl/utils.py +++ b/stl/utils.py @@ -33,7 +33,9 @@ def type_pred(*args:List[Type]) -> Mapping[Type, bool]: def _child_lens(psi:STL, focus:Lens) -> STL_Generator: if psi is None: return - if isinstance(psi, NaryOpSTL): + elif psi is stl.TOP or psi is stl.BOT: + return + elif isinstance(psi, NaryOpSTL): for j, _ in enumerate(psi.args): yield focus.args[j] else: