first hand tested version of smooth_robustness w.o. automatic canonical form

This commit is contained in:
Marcell Vazquez-Chanlatte 2016-12-15 20:23:56 -08:00
parent cd3fc97eab
commit 39dcc82912
2 changed files with 30 additions and 16 deletions

View file

@ -13,7 +13,7 @@ from stl.ast import t_sym
from stl.utils import walk from stl.utils import walk
from stl.robustness import op_lookup from stl.robustness import op_lookup
Param = namedtuple("Param", ["L", "h", "B", "id_map", "eps"]) Param = namedtuple("Param", ["L", "h", "B", "eps"])
@singledispatch @singledispatch
def node_base(_, _1, _2): def node_base(_, _1, _2):
@ -21,12 +21,12 @@ def node_base(_, _1, _2):
@node_base.register(stl.ast.Or) @node_base.register(stl.ast.Or)
def node_base(_, eps, _1): def _(_, eps, _1):
return len(stl.args)**(1/eps) return len(stl.args)**(1/eps)
@node_base.register(stl.ast.F) @node_base.register(stl.ast.F)
def node_base(_, eps, L): def _(_, eps, L):
lo, hi = stl.interval lo, hi = stl.interval
return sym.ceil((hi - lo)*L/eps)**(2/eps) return sym.ceil((hi - lo)*L/eps)**(2/eps)
@ -40,13 +40,17 @@ def admissible_params(phi, eps, L):
B = max(node_base(n, eps, L) for n in walk(phi)), B = max(node_base(n, eps, L) for n in walk(phi)),
return B, h return B, h
def new_symbol_set(ss):
indices = set(ss[id_map].keys())
non_indicies = set(v.name for k, v in ss.items() if v != "id_map")
return indices | non_indicies
def symbolic_params(phi, eps, L): def symbolic_params(phi, eps, L):
return Param( return Param(
L=sym.Symbol("L"), L=sym.Dummy("L"),
h=sym.Symbol("h"), h=sym.Dummy("h"),
B="B", B=sym.Dummy("B"),
id_map={n:i for i, n in enumerate(walk(phi))}, eps=sym.Dummy("eps"),
eps=sym.symbol("eps")
) )
@ -87,9 +91,8 @@ def eval_term(term):
@alpha.register(stl.LinEq) @alpha.register(stl.LinEq)
def _(phi, p): def _(phi, p):
op = op_lookup[phi.op] op = op_lookup[phi.op]
B = eps_to_base(eps/depth, N)
x = op(eval_terms(phi), phi.const) x = op(eval_terms(phi), phi.const)
return B**x return p.B**x
@alpha.register(stl.Neg) @alpha.register(stl.Neg)
@ -99,21 +102,21 @@ def _(phi, p):
@alpha.register(stl.Or) @alpha.register(stl.Or)
def _(phi, p): def _(phi, p):
return sum(alpha(psi, p) for psi in psi in phi.args) return sum(alpha(psi, p) for psi in phi.args)
def F_params(phi, p, r): def F_params(phi, p, r):
hi, lo = phi.interval hi, lo = phi.interval
N = sym.ceiling((hi - lo) / p.h) N = sym.ceiling((hi - lo) / p.h)
i = sym.Symbol("i_{}".format(p.id_map[phi]))
x = lambda k: r.subs({t_sym: t_sym+k+lo}) x = lambda k: r.subs({t_sym: t_sym+k+lo})
i = sym.Dummy("i")
return N, i, x return N, i, x
@alpha.register(stl.F) @alpha.register(stl.F)
def _(phi, p): def _(phi, p):
N, i, x = F_params(phi, p, alpha(phi.arg, p)) N, i, x = F_params(phi, p, alpha(phi.arg, p))
x_ij = sym.sqrt(p.B**(L*h)*x(i)*x(i+1)) x_ij = sym.sqrt(p.B**(p.L*p.h)*x(i)*x(i+1))
return sym.summation(x_ij, (i, 0, N-1)) return sym.summation(x_ij, (i, 0, N-1))
# Beta implementation # Beta implementation
@ -131,7 +134,7 @@ def _(phi, p):
@beta.register(stl.Or) @beta.register(stl.Or)
def _(phi, p): def _(phi, p):
return alpha(phi)/len(phi.args) return alpha(phi, p)/len(phi.args)
@beta.register(stl.F) @beta.register(stl.F)

View file

@ -66,11 +66,22 @@ def terms_lens(phi:"STL", bind=True) -> lens:
def param_lens(phi): def param_lens(phi):
is_sym = lambda x: isinstance(x, sympy.Symbol) is_sym = lambda x: isinstance(x, sympy.Symbol)
def focus_lens(leaf): def focus_lens(leaf):
return [lens().const] if isinstance(leaf, LinEq) else [lens().interval[0], lens().interval[1]] return [lens().const, lens().id] if isinstance(leaf, LinEq) else [lens().interval[0], lens().interval[1]]
return ast_lens(phi, pred=type_pred(LinEq, F, G), focus_lens=focus_lens).filter_(is_sym) return ast_lens(phi, pred=type_pred(LinEq, F, G),
focus_lens=focus_lens).filter_(is_sym)
def symbol_lens(phi):
is_sym = lambda x: isinstance(x, sympy.Symbol)
def focus_lens(leaf):
spacial = [lens().const] + lens().terms.each_().id.get_all()
temporal = [lens().interval[0], lens().interval[1]]
return spacial if isinstance(leaf, LinEq) else temp
return ast_lens(phi, pred=type_pred(LinEq, F, G),
focus_lens=focus_lens).filter_(is_sym)
def set_params(stl_or_lens, val): def set_params(stl_or_lens, val):
l = stl_or_lens if isinstance(stl_or_lens, Lens) else param_lens(stl_or_lens) l = stl_or_lens if isinstance(stl_or_lens, Lens) else param_lens(stl_or_lens)
return l.modify(lambda x: val[str(x)] if str(x) in val else x) return l.modify(lambda x: val[str(x)] if str(x) in val else x)