diff --git a/robustness.py b/robustness.py new file mode 100644 index 0000000..994303f --- /dev/null +++ b/robustness.py @@ -0,0 +1,64 @@ +from functools import singledispatch +from operator import sub, add + +from lenses import lens + +import stl.ast + + +@singledispatch +def pointwise_robustness(stl): + raise NotImplementedError + + +@pointwise_robustness.register(stl.Or) +def _(stl): + return lambda x, t: max(pointwise_robustness(arg)(x, t) for arg in stl.args) + + +@pointwise_robustness.register(stl.And) +def _(stl): + return lambda x, t: min(pointwise_robustness(arg)(x, t) for arg in stl.args) + + +@pointwise_robustness.register(stl.F) +def _(stl): + lo, hi = stl.interval + return lambda x, t: max(pointwise_robustness(stl.arg)(x, t + t2) + for t2 in x[lo:hi].index) + + +@pointwise_robustness.register(stl.G) +def _(stl): + lo, hi = stl.interval + return lambda x, t: min(pointwise_robustness(stl.arg)(x, t + t2) + for t2 in x[lo:hi].index) + + +@pointwise_robustness.register(stl.Neg) +def _(stl): + return lambda x, t: -pointwise_robustness(arg)(x, t) + + +op_lookup = { + ">": sub, + ">=": sub, + "<": add, + "<=": add, + "=": lambda a, b: -abs(a - b), +} + + +@pointwise_robustness.register(stl.LinEq) +def _(stl): + op = op_lookup[stl.op] + return lambda x, t: op(eval_terms(stl, x, t), stl.const) + + +def eval_terms(lineq, x, t): + psi = lens(lineq).terms.each_().modify(eval_term(x, t)) + return sum(psi.terms) + + +def eval_term(x, t): + return lambda term: term.coeff*x[term.id.name][t] diff --git a/utils.py b/utils.py index dd705ea..1b0019e 100644 --- a/utils.py +++ b/utils.py @@ -1,3 +1,5 @@ +from collections import deque + from lenses import lens import funcy as fn