From 07bf9f50e42cafda338937dbe8ce0a51f33479b1 Mon Sep 17 00:00:00 2001 From: Marcell Vazquez-Chanlatte Date: Thu, 20 Apr 2017 21:08:44 -0700 Subject: [PATCH] cache pointwise_sat call --- stl/__init__.py | 3 ++- stl/boolean_eval.py | 21 +++++++++++++++++---- 2 files changed, 19 insertions(+), 5 deletions(-) diff --git a/stl/__init__.py b/stl/__init__.py index a64d0fa..a4e7b1e 100644 --- a/stl/__init__.py +++ b/stl/__init__.py @@ -1,7 +1,8 @@ from stl.utils import terms_lens, lineq_lens, walk, and_or_lens from stl.utils import alw, env, andf, orf from stl.ast import dt_sym, t_sym, TOP, BOT -from stl.ast import LinEq, Interval, NaryOpSTL, Or, And, F, G, ModalOp, Neg, Var, AtomicPred +from stl.ast import (LinEq, Interval, NaryOpSTL, Or, And, F, G, + ModalOp, Neg, Var, AtomicPred, Until) from stl.parser import parse from stl.fastboolean_eval import pointwise_sat from stl.synth import lex_param_project diff --git a/stl/boolean_eval.py b/stl/boolean_eval.py index 502c008..e1884fc 100644 --- a/stl/boolean_eval.py +++ b/stl/boolean_eval.py @@ -16,31 +16,44 @@ def pointwise_sat(stl): @pointwise_sat.register(stl.Or) def _(stl): - return lambda x, t: any(pointwise_sat(arg)(x, t) for arg in stl.args) + fs = [pointwise_sat(arg) for arg in stl.args] + return lambda x, t: any(f(x, t) for f in fs) @pointwise_sat.register(stl.And) def _(stl): - return lambda x, t: all(pointwise_sat(arg)(x, t) for arg in stl.args) + fs = [pointwise_sat(arg) for arg in stl.args] + return lambda x, t: all(f(x, t) for f in fs) + + +@pointwise_sat.register(stl.Until) +def _(stl): + def _until(x, t): + phi = (pointwise_sat(phi)(x, t) for t in x.index) + return lambda x, t: any((pointwise_sat(stl.arg)(x, min(t + t2, x.index[-1])) + for t2 in x[lo:hi].index)) @pointwise_sat.register(stl.F) def _(stl): lo, hi = stl.interval - return lambda x, t: any((pointwise_sat(stl.arg)(x, min(t + t2, x.index[-1])) + f = pointwise_sat(stl.arg) + return lambda x, t: any((f(x, min(t + t2, x.index[-1])) for t2 in x[lo:hi].index)) @pointwise_sat.register(stl.G) def _(stl): lo, hi = stl.interval + f = pointwise_sat(stl.arg) return lambda x, t: all((pointwise_sat(stl.arg)(x, min(t + t2, x.index[-1])) for t2 in x[lo:hi].index)) @pointwise_sat.register(stl.Neg) def _(stl): - return lambda x, t: not pointwise_sat(stl.arg)(x, t) + f = pointwise_sat(stl.arg) + return lambda x, t: not f(x, t) op_lookup = {