diff --git a/stl/utils.py b/stl/utils.py index b5414e1..6c4e2ff 100644 --- a/stl/utils.py +++ b/stl/utils.py @@ -101,13 +101,14 @@ def implicit_validity_domain(phi, trace): def require_discretizable(func): @wraps(func) def _func(phi, dt, *args, **kwargs): - assert is_discretizable(phi, dt) + if 'horizon' not in kwargs: + assert is_discretizable(phi, dt) return func(phi, dt, *args, **kwargs) return _func -def scope(phi, dt, *, _t=0): +def scope(phi, dt, *, _t=0, horizon=oo): if isinstance(phi, Next): _t += dt elif isinstance(phi, (G, F)): @@ -115,32 +116,44 @@ def scope(phi, dt, *, _t=0): elif isinstance(phi, Until): _t += float('inf') - return max((scope(c, dt, _t=_t) for c in phi.children), default=_t) + _scope = max((scope(c, dt, _t=_t) for c in phi.children), default=_t) + return min(_scope, horizon) # Code to discretize a bounded STL formula @require_discretizable -def discretize(phi, dt, distribute=False): - phi = _discretize(phi, dt) +def discretize(phi, dt, distribute=False, *, horizon=None): + if horizon is None: + horizon = oo + + phi = _discretize(phi, dt, horizon) return _distribute_next(phi) if distribute else phi -def _discretize(phi, dt): +def _discretize(phi, dt, horizon): if isinstance(phi, (LinEq, AtomicPred, _Top, _Bot)): return phi - children = tuple(_discretize(arg, dt) for arg in phi.children) - if isinstance(phi, (And, Or)): - return bind(phi).args.set(children) - elif isinstance(phi, (Neg, Next)): - return bind(phi).arg.set(children[0]) + if not isinstance(phi, (F, G, Until)): + children = tuple(_discretize(arg, dt, horizon) for arg in phi.children) + if isinstance(phi, (And, Or)): + return bind(phi).args.set(children) + elif isinstance(phi, (Neg, Next)): + return bind(phi).arg.set(children[0]) + + raise NotImplementedError + + elif isinstance(phi, Until): + raise NotImplementedError # Only remaining cases are G and F - psi = children[0] - l, u = round(phi.interval.lower / dt), round(phi.interval.upper / dt) - psis = (next(psi, i) for i in range(l, u + 1)) + upper = min(phi.interval.upper, horizon) + l, u = round(phi.interval.lower / dt), round(upper / dt) + + psis = (next(_discretize(phi.arg, dt, horizon - i), i) + for i in range(l, u + 1)) opf = andf if isinstance(phi, G) else orf return opf(*psis) @@ -174,7 +187,6 @@ def is_discretizable(phi, dt): _interval_discretizable(c.interval, dt) for c in phi.walk() if isinstance(c, (F, G))) - # EDSL