added test scaffolds

This commit is contained in:
Marcell Vazquez-Chanlatte 2017-02-19 15:48:28 -08:00
parent 08bc671401
commit ed7b084bd1
4 changed files with 39 additions and 18 deletions

View file

@ -1,4 +1,4 @@
from stl.utils import terms_lens, lineq_lens, walk, tree, and_or_lens from stl.utils import terms_lens, lineq_lens, walk, and_or_lens
from stl.utils import alw, env, andf, orf from stl.utils import alw, env, andf, orf
from stl.ast import dt_sym, t_sym from stl.ast import dt_sym, t_sym
from stl.ast import LinEq, Interval, NaryOpSTL, Or, And, F, G, ModalOp, Neg, Var, AtomicPred from stl.ast import LinEq, Interval, NaryOpSTL, Or, And, F, G, ModalOp, Neg, Var, AtomicPred

View file

@ -5,17 +5,17 @@ import unittest
from sympy import Symbol from sympy import Symbol
ex1 = ('x1 > 2', stl.LinEq( ex1 = ('x1 > 2', stl.LinEq(
(stl.Var(1, Symbol("x1"), stl.ast.t_sym),), (stl.Var(1, Symbol("x1"), stl.t_sym),),
">", ">",
2.0 2.0
)) ))
ex1_ = ('x1 > a?', stl.LinEq( ex1_ = ('x1 > a?', stl.LinEq(
(stl.Var(1, Symbol("x1"), stl.ast.t_sym),), (stl.Var(1, Symbol("x1"), stl.t_sym),),
">", ">",
Symbol("a?") Symbol("a?")
)) ))
ex1__ = ('x1', stl.AtomicPred('x1')) ex1__ = ('x1', stl.AtomicPred('x1', stl.t_sym))
i1 = stl.Interval(0., 1.) i1 = stl.Interval(0., 1.)
i1_ = stl.Interval(0., Symbol("b?")) i1_ = stl.Interval(0., Symbol("b?"))

View file

@ -28,3 +28,31 @@ class TestSTLUtils(unittest.TestCase):
self.assertEqual(set(map(str, stl.utils.param_lens(phi).get_all())), set()) self.assertEqual(set(map(str, stl.utils.param_lens(phi).get_all())), set())
self.assertEqual(phi, phi2) self.assertEqual(phi, phi2)
def test_walk(self):
raise NotImplementedError
def test_type_pred(self):
raise NotImplementedError
def test_ast_lens(self):
raise NotImplementedError
def test_terms_lens(self):
raise NotImplementedError
def test_f_neg_or_canonical_form(self):
raise NotImplementedError
def test_to_from_mtl(self):
raise NotImplementedError
def test_get_polarity(self):
raise NotImplementedError
def test_canonical_polarity(self):
raise NotImplementedError

View file

@ -22,10 +22,6 @@ def walk(phi:STL, bfs:bool=False) -> STL_Generator:
children.extend(node.children()) children.extend(node.children())
def tree(phi:STL) -> Dict[STL, STL]:
return {x:set(x.children()) for x in walk(phi) if x.children()}
def type_pred(*args:List[Type]) -> Mapping[Type, bool]: def type_pred(*args:List[Type]) -> Mapping[Type, bool]:
ast_types = set(args) ast_types = set(args)
return lambda x: type(x) in ast_types return lambda x: type(x) in ast_types
@ -80,16 +76,6 @@ def param_lens(phi:STL) -> Lens:
focus_lens=focus_lens).filter_(is_sym) focus_lens=focus_lens).filter_(is_sym)
def symbol_lens(phi:STL) -> Lens:
is_sym = lambda x: isinstance(x, sympy.Symbol)
def focus_lens(leaf):
spacial = [lens().const] + lens().terms.each_().id.get_all()
temporal = [lens().interval[0], lens().interval[1]]
return spacial if isinstance(leaf, LinEq) else temp
return ast_lens(phi, pred=type_pred(LinEq, F, G),
focus_lens=focus_lens).filter_(is_sym)
def set_params(stl_or_lens, val) -> STL: def set_params(stl_or_lens, val) -> STL:
l = stl_or_lens if isinstance(stl_or_lens, Lens) else param_lens(stl_or_lens) l = stl_or_lens if isinstance(stl_or_lens, Lens) else param_lens(stl_or_lens)
return l.modify(lambda x: val.get(x, val.get(str(x), x))) return l.modify(lambda x: val.get(x, val.get(str(x), x)))
@ -131,6 +117,13 @@ def from_mtl(phi:MTL, ap_map:Dict[AtomicPred, LinEq]) -> STL:
return focus.modify(ap_map.get) return focus.modify(ap_map.get)
def get_polarity(phi, traces=None):
raise NotImplementedError
def canonical_polarity(phi, traces=None):
raise NotImplementedError
# EDSL # EDSL
def alw(phi, *, lo, hi): def alw(phi, *, lo, hi):