From 788e6ee58e55963ae45925cc74a6a92ad5869d52 Mon Sep 17 00:00:00 2001 From: Marcell Vazquez-Chanlatte Date: Tue, 3 Jan 2017 10:12:00 -0800 Subject: [PATCH] start adding Types for STL utilities --- stl/types.py | 12 ++++++++++++ stl/utils.py | 36 ++++++++++++++++++++---------------- 2 files changed, 32 insertions(+), 16 deletions(-) create mode 100644 stl/types.py diff --git a/stl/types.py b/stl/types.py new file mode 100644 index 0000000..74e6c31 --- /dev/null +++ b/stl/types.py @@ -0,0 +1,12 @@ +import typing + +import stl.ast as ast + +ML = typing.Union[ast.AtomicPred, ast.NaryOpSTL, ast.Neg] +SL = typing.Union[ast.LinEq, ML] + +STL = typing.Union[SL, ast.ModalOp] +MTL = typing.Union[ML, ast.ModalOp] + +PSTL = typing.NewType("PSTL", STL) +STL_Generator = typing.Generator[STL, None, STL] diff --git a/stl/utils.py b/stl/utils.py index f7d9a1e..83628c5 100644 --- a/stl/utils.py +++ b/stl/utils.py @@ -1,3 +1,4 @@ +from typing import List, Type, Dict, Mapping, T from collections import deque from lenses import lens, Lens @@ -5,28 +6,30 @@ import funcy as fn import sympy import stl.ast -from stl.ast import LinEq, And, Or, NaryOpSTL, F, G, Interval, Neg +from stl.ast import (LinEq, And, Or, NaryOpSTL, F, G, Interval, Neg, + AtomicPred) +from stl.types import STL, STL_Generator, MTL -def walk(stl, bfs=False): +def walk(phi:STL, bfs:bool=False) -> STL_Generator: """Walks Ast. Defaults to DFS unless BFS flag is set.""" pop = deque.popleft if bfs else deque.pop - children = deque([stl]) + children = deque([phi]) while len(children) != 0: node = pop(children) yield node children.extend(node.children()) -def tree(stl): - return {x:set(x.children()) for x in walk(stl) if x.children()} +def tree(phi:STL) -> Dict[STL, STL]: + return {x:set(x.children()) for x in walk(phi) if x.children()} -def type_pred(*args): +def type_pred(*args:List[Type]) -> Mapping[Type, bool]: ast_types = set(args) return lambda x: type(x) in ast_types -def _child_lens(psi, focus): +def _child_lens(psi:STL, focus:Lens) -> STL_Generator: if psi is None: return if isinstance(psi, NaryOpSTL): @@ -36,7 +39,8 @@ def _child_lens(psi, focus): yield focus.arg -def ast_lens(phi:"STL", bind=True, *, pred, focus_lens=None) -> lens: +def ast_lens(phi:STL, bind:bool=True, *, + pred:Mapping[T, bool], focus_lens:Lens=None) -> Lens: if focus_lens is None: focus_lens = lambda x: [lens()] tls = list(fn.flatten(_ast_lens(phi, pred=pred, focus_lens=focus_lens))) @@ -61,11 +65,11 @@ lineq_lens = fn.partial(ast_lens, pred=type_pred(LinEq)) AP_lens = fn.partial(ast_lens, pred=type_pred(stl.ast.AtomicPred)) and_or_lens = fn.partial(ast_lens, pred=type_pred(And, Or)) -def terms_lens(phi:"STL", bind=True) -> lens: +def terms_lens(phi:STL, bind:bool=True) -> Lens: return lineq_lens(phi, bind).terms.each_() -def param_lens(phi): +def param_lens(phi:STL) -> Lens: is_sym = lambda x: isinstance(x, sympy.Symbol) def focus_lens(leaf): return [lens().const] if isinstance(leaf, LinEq) else [lens().interval[0], lens().interval[1]] @@ -74,7 +78,7 @@ def param_lens(phi): focus_lens=focus_lens).filter_(is_sym) -def symbol_lens(phi): +def symbol_lens(phi:STL) -> Lens: is_sym = lambda x: isinstance(x, sympy.Symbol) def focus_lens(leaf): spacial = [lens().const] + lens().terms.each_().id.get_all() @@ -84,12 +88,12 @@ def symbol_lens(phi): return ast_lens(phi, pred=type_pred(LinEq, F, G), focus_lens=focus_lens).filter_(is_sym) -def set_params(stl_or_lens, val): +def set_params(stl_or_lens, val) -> STL: l = stl_or_lens if isinstance(stl_or_lens, Lens) else param_lens(stl_or_lens) - return l.modify(lambda x: val[str(x)] if str(x) in val else x) + return l.modify(lambda x: val.get(x, val.get(str(x), x))) -def f_neg_or_canonical_form(phi): +def f_neg_or_canonical_form(phi:STL) -> STL: if isinstance(phi, LinEq): return phi @@ -112,7 +116,7 @@ def f_neg_or_canonical_form(phi): raise NotImplementedError -def to_mtl(phi): +def to_mtl(phi:STL) -> MTL: focus = lineq_lens(phi) to_ap = lambda i: stl.ast.AtomicPred("AP{}".format(i)) ap_map = {to_ap(i): leq for i, leq in enumerate(focus.get_all())} @@ -120,6 +124,6 @@ def to_mtl(phi): return focus.modify(lineq_map.get), ap_map -def from_mtl(phi, ap_map): +def from_mtl(phi:MTL, ap_map:Dict[AtomicPred, LinEq]) -> STL: focus = AP_lens(phi) return focus.modify(ap_map.get)