start adding Types for STL utilities

This commit is contained in:
Marcell Vazquez-Chanlatte 2017-01-03 10:12:00 -08:00
parent a5be0b17bb
commit 788e6ee58e
2 changed files with 32 additions and 16 deletions

12
stl/types.py Normal file
View file

@ -0,0 +1,12 @@
import typing
import stl.ast as ast
ML = typing.Union[ast.AtomicPred, ast.NaryOpSTL, ast.Neg]
SL = typing.Union[ast.LinEq, ML]
STL = typing.Union[SL, ast.ModalOp]
MTL = typing.Union[ML, ast.ModalOp]
PSTL = typing.NewType("PSTL", STL)
STL_Generator = typing.Generator[STL, None, STL]

View file

@ -1,3 +1,4 @@
from typing import List, Type, Dict, Mapping, T
from collections import deque from collections import deque
from lenses import lens, Lens from lenses import lens, Lens
@ -5,28 +6,30 @@ import funcy as fn
import sympy import sympy
import stl.ast import stl.ast
from stl.ast import LinEq, And, Or, NaryOpSTL, F, G, Interval, Neg from stl.ast import (LinEq, And, Or, NaryOpSTL, F, G, Interval, Neg,
AtomicPred)
from stl.types import STL, STL_Generator, MTL
def walk(stl, bfs=False): def walk(phi:STL, bfs:bool=False) -> STL_Generator:
"""Walks Ast. Defaults to DFS unless BFS flag is set.""" """Walks Ast. Defaults to DFS unless BFS flag is set."""
pop = deque.popleft if bfs else deque.pop pop = deque.popleft if bfs else deque.pop
children = deque([stl]) children = deque([phi])
while len(children) != 0: while len(children) != 0:
node = pop(children) node = pop(children)
yield node yield node
children.extend(node.children()) children.extend(node.children())
def tree(stl): def tree(phi:STL) -> Dict[STL, STL]:
return {x:set(x.children()) for x in walk(stl) if x.children()} return {x:set(x.children()) for x in walk(phi) if x.children()}
def type_pred(*args): def type_pred(*args:List[Type]) -> Mapping[Type, bool]:
ast_types = set(args) ast_types = set(args)
return lambda x: type(x) in ast_types return lambda x: type(x) in ast_types
def _child_lens(psi, focus): def _child_lens(psi:STL, focus:Lens) -> STL_Generator:
if psi is None: if psi is None:
return return
if isinstance(psi, NaryOpSTL): if isinstance(psi, NaryOpSTL):
@ -36,7 +39,8 @@ def _child_lens(psi, focus):
yield focus.arg yield focus.arg
def ast_lens(phi:"STL", bind=True, *, pred, focus_lens=None) -> lens: def ast_lens(phi:STL, bind:bool=True, *,
pred:Mapping[T, bool], focus_lens:Lens=None) -> Lens:
if focus_lens is None: if focus_lens is None:
focus_lens = lambda x: [lens()] focus_lens = lambda x: [lens()]
tls = list(fn.flatten(_ast_lens(phi, pred=pred, focus_lens=focus_lens))) tls = list(fn.flatten(_ast_lens(phi, pred=pred, focus_lens=focus_lens)))
@ -61,11 +65,11 @@ lineq_lens = fn.partial(ast_lens, pred=type_pred(LinEq))
AP_lens = fn.partial(ast_lens, pred=type_pred(stl.ast.AtomicPred)) AP_lens = fn.partial(ast_lens, pred=type_pred(stl.ast.AtomicPred))
and_or_lens = fn.partial(ast_lens, pred=type_pred(And, Or)) and_or_lens = fn.partial(ast_lens, pred=type_pred(And, Or))
def terms_lens(phi:"STL", bind=True) -> lens: def terms_lens(phi:STL, bind:bool=True) -> Lens:
return lineq_lens(phi, bind).terms.each_() return lineq_lens(phi, bind).terms.each_()
def param_lens(phi): def param_lens(phi:STL) -> Lens:
is_sym = lambda x: isinstance(x, sympy.Symbol) is_sym = lambda x: isinstance(x, sympy.Symbol)
def focus_lens(leaf): def focus_lens(leaf):
return [lens().const] if isinstance(leaf, LinEq) else [lens().interval[0], lens().interval[1]] return [lens().const] if isinstance(leaf, LinEq) else [lens().interval[0], lens().interval[1]]
@ -74,7 +78,7 @@ def param_lens(phi):
focus_lens=focus_lens).filter_(is_sym) focus_lens=focus_lens).filter_(is_sym)
def symbol_lens(phi): def symbol_lens(phi:STL) -> Lens:
is_sym = lambda x: isinstance(x, sympy.Symbol) is_sym = lambda x: isinstance(x, sympy.Symbol)
def focus_lens(leaf): def focus_lens(leaf):
spacial = [lens().const] + lens().terms.each_().id.get_all() spacial = [lens().const] + lens().terms.each_().id.get_all()
@ -84,12 +88,12 @@ def symbol_lens(phi):
return ast_lens(phi, pred=type_pred(LinEq, F, G), return ast_lens(phi, pred=type_pred(LinEq, F, G),
focus_lens=focus_lens).filter_(is_sym) focus_lens=focus_lens).filter_(is_sym)
def set_params(stl_or_lens, val): def set_params(stl_or_lens, val) -> STL:
l = stl_or_lens if isinstance(stl_or_lens, Lens) else param_lens(stl_or_lens) l = stl_or_lens if isinstance(stl_or_lens, Lens) else param_lens(stl_or_lens)
return l.modify(lambda x: val[str(x)] if str(x) in val else x) return l.modify(lambda x: val.get(x, val.get(str(x), x)))
def f_neg_or_canonical_form(phi): def f_neg_or_canonical_form(phi:STL) -> STL:
if isinstance(phi, LinEq): if isinstance(phi, LinEq):
return phi return phi
@ -112,7 +116,7 @@ def f_neg_or_canonical_form(phi):
raise NotImplementedError raise NotImplementedError
def to_mtl(phi): def to_mtl(phi:STL) -> MTL:
focus = lineq_lens(phi) focus = lineq_lens(phi)
to_ap = lambda i: stl.ast.AtomicPred("AP{}".format(i)) to_ap = lambda i: stl.ast.AtomicPred("AP{}".format(i))
ap_map = {to_ap(i): leq for i, leq in enumerate(focus.get_all())} ap_map = {to_ap(i): leq for i, leq in enumerate(focus.get_all())}
@ -120,6 +124,6 @@ def to_mtl(phi):
return focus.modify(lineq_map.get), ap_map return focus.modify(lineq_map.get), ap_map
def from_mtl(phi, ap_map): def from_mtl(phi:MTL, ap_map:Dict[AtomicPred, LinEq]) -> STL:
focus = AP_lens(phi) focus = AP_lens(phi)
return focus.modify(ap_map.get) return focus.modify(ap_map.get)