start adding Types for STL utilities

This commit is contained in:
Marcell Vazquez-Chanlatte 2017-01-03 10:12:00 -08:00
parent a5be0b17bb
commit 788e6ee58e
2 changed files with 32 additions and 16 deletions

View file

@ -1,3 +1,4 @@
from typing import List, Type, Dict, Mapping, T
from collections import deque
from lenses import lens, Lens
@ -5,28 +6,30 @@ import funcy as fn
import sympy
import stl.ast
from stl.ast import LinEq, And, Or, NaryOpSTL, F, G, Interval, Neg
from stl.ast import (LinEq, And, Or, NaryOpSTL, F, G, Interval, Neg,
AtomicPred)
from stl.types import STL, STL_Generator, MTL
def walk(stl, bfs=False):
def walk(phi:STL, bfs:bool=False) -> STL_Generator:
"""Walks Ast. Defaults to DFS unless BFS flag is set."""
pop = deque.popleft if bfs else deque.pop
children = deque([stl])
children = deque([phi])
while len(children) != 0:
node = pop(children)
yield node
children.extend(node.children())
def tree(stl):
return {x:set(x.children()) for x in walk(stl) if x.children()}
def tree(phi:STL) -> Dict[STL, STL]:
return {x:set(x.children()) for x in walk(phi) if x.children()}
def type_pred(*args):
def type_pred(*args:List[Type]) -> Mapping[Type, bool]:
ast_types = set(args)
return lambda x: type(x) in ast_types
def _child_lens(psi, focus):
def _child_lens(psi:STL, focus:Lens) -> STL_Generator:
if psi is None:
return
if isinstance(psi, NaryOpSTL):
@ -36,7 +39,8 @@ def _child_lens(psi, focus):
yield focus.arg
def ast_lens(phi:"STL", bind=True, *, pred, focus_lens=None) -> lens:
def ast_lens(phi:STL, bind:bool=True, *,
pred:Mapping[T, bool], focus_lens:Lens=None) -> Lens:
if focus_lens is None:
focus_lens = lambda x: [lens()]
tls = list(fn.flatten(_ast_lens(phi, pred=pred, focus_lens=focus_lens)))
@ -61,11 +65,11 @@ lineq_lens = fn.partial(ast_lens, pred=type_pred(LinEq))
AP_lens = fn.partial(ast_lens, pred=type_pred(stl.ast.AtomicPred))
and_or_lens = fn.partial(ast_lens, pred=type_pred(And, Or))
def terms_lens(phi:"STL", bind=True) -> lens:
def terms_lens(phi:STL, bind:bool=True) -> Lens:
return lineq_lens(phi, bind).terms.each_()
def param_lens(phi):
def param_lens(phi:STL) -> Lens:
is_sym = lambda x: isinstance(x, sympy.Symbol)
def focus_lens(leaf):
return [lens().const] if isinstance(leaf, LinEq) else [lens().interval[0], lens().interval[1]]
@ -74,7 +78,7 @@ def param_lens(phi):
focus_lens=focus_lens).filter_(is_sym)
def symbol_lens(phi):
def symbol_lens(phi:STL) -> Lens:
is_sym = lambda x: isinstance(x, sympy.Symbol)
def focus_lens(leaf):
spacial = [lens().const] + lens().terms.each_().id.get_all()
@ -84,12 +88,12 @@ def symbol_lens(phi):
return ast_lens(phi, pred=type_pred(LinEq, F, G),
focus_lens=focus_lens).filter_(is_sym)
def set_params(stl_or_lens, val):
def set_params(stl_or_lens, val) -> STL:
l = stl_or_lens if isinstance(stl_or_lens, Lens) else param_lens(stl_or_lens)
return l.modify(lambda x: val[str(x)] if str(x) in val else x)
return l.modify(lambda x: val.get(x, val.get(str(x), x)))
def f_neg_or_canonical_form(phi):
def f_neg_or_canonical_form(phi:STL) -> STL:
if isinstance(phi, LinEq):
return phi
@ -112,7 +116,7 @@ def f_neg_or_canonical_form(phi):
raise NotImplementedError
def to_mtl(phi):
def to_mtl(phi:STL) -> MTL:
focus = lineq_lens(phi)
to_ap = lambda i: stl.ast.AtomicPred("AP{}".format(i))
ap_map = {to_ap(i): leq for i, leq in enumerate(focus.get_all())}
@ -120,6 +124,6 @@ def to_mtl(phi):
return focus.modify(lineq_map.get), ap_map
def from_mtl(phi, ap_map):
def from_mtl(phi:MTL, ap_map:Dict[AtomicPred, LinEq]) -> STL:
focus = AP_lens(phi)
return focus.modify(ap_map.get)