added test scaffolds
This commit is contained in:
parent
08bc671401
commit
ed7b084bd1
4 changed files with 39 additions and 18 deletions
|
|
@ -1,4 +1,4 @@
|
||||||
from stl.utils import terms_lens, lineq_lens, walk, tree, and_or_lens
|
from stl.utils import terms_lens, lineq_lens, walk, and_or_lens
|
||||||
from stl.utils import alw, env, andf, orf
|
from stl.utils import alw, env, andf, orf
|
||||||
from stl.ast import dt_sym, t_sym
|
from stl.ast import dt_sym, t_sym
|
||||||
from stl.ast import LinEq, Interval, NaryOpSTL, Or, And, F, G, ModalOp, Neg, Var, AtomicPred
|
from stl.ast import LinEq, Interval, NaryOpSTL, Or, And, F, G, ModalOp, Neg, Var, AtomicPred
|
||||||
|
|
|
||||||
|
|
@ -5,17 +5,17 @@ import unittest
|
||||||
from sympy import Symbol
|
from sympy import Symbol
|
||||||
|
|
||||||
ex1 = ('x1 > 2', stl.LinEq(
|
ex1 = ('x1 > 2', stl.LinEq(
|
||||||
(stl.Var(1, Symbol("x1"), stl.ast.t_sym),),
|
(stl.Var(1, Symbol("x1"), stl.t_sym),),
|
||||||
">",
|
">",
|
||||||
2.0
|
2.0
|
||||||
))
|
))
|
||||||
ex1_ = ('x1 > a?', stl.LinEq(
|
ex1_ = ('x1 > a?', stl.LinEq(
|
||||||
(stl.Var(1, Symbol("x1"), stl.ast.t_sym),),
|
(stl.Var(1, Symbol("x1"), stl.t_sym),),
|
||||||
">",
|
">",
|
||||||
Symbol("a?")
|
Symbol("a?")
|
||||||
))
|
))
|
||||||
|
|
||||||
ex1__ = ('x1', stl.AtomicPred('x1'))
|
ex1__ = ('x1', stl.AtomicPred('x1', stl.t_sym))
|
||||||
|
|
||||||
i1 = stl.Interval(0., 1.)
|
i1 = stl.Interval(0., 1.)
|
||||||
i1_ = stl.Interval(0., Symbol("b?"))
|
i1_ = stl.Interval(0., Symbol("b?"))
|
||||||
|
|
|
||||||
|
|
@ -28,3 +28,31 @@ class TestSTLUtils(unittest.TestCase):
|
||||||
|
|
||||||
self.assertEqual(set(map(str, stl.utils.param_lens(phi).get_all())), set())
|
self.assertEqual(set(map(str, stl.utils.param_lens(phi).get_all())), set())
|
||||||
self.assertEqual(phi, phi2)
|
self.assertEqual(phi, phi2)
|
||||||
|
|
||||||
|
def test_walk(self):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
|
def test_type_pred(self):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
|
def test_ast_lens(self):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
|
def test_terms_lens(self):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
|
def test_f_neg_or_canonical_form(self):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def test_to_from_mtl(self):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def test_get_polarity(self):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def test_canonical_polarity(self):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
|
||||||
21
stl/utils.py
21
stl/utils.py
|
|
@ -22,10 +22,6 @@ def walk(phi:STL, bfs:bool=False) -> STL_Generator:
|
||||||
children.extend(node.children())
|
children.extend(node.children())
|
||||||
|
|
||||||
|
|
||||||
def tree(phi:STL) -> Dict[STL, STL]:
|
|
||||||
return {x:set(x.children()) for x in walk(phi) if x.children()}
|
|
||||||
|
|
||||||
|
|
||||||
def type_pred(*args:List[Type]) -> Mapping[Type, bool]:
|
def type_pred(*args:List[Type]) -> Mapping[Type, bool]:
|
||||||
ast_types = set(args)
|
ast_types = set(args)
|
||||||
return lambda x: type(x) in ast_types
|
return lambda x: type(x) in ast_types
|
||||||
|
|
@ -80,16 +76,6 @@ def param_lens(phi:STL) -> Lens:
|
||||||
focus_lens=focus_lens).filter_(is_sym)
|
focus_lens=focus_lens).filter_(is_sym)
|
||||||
|
|
||||||
|
|
||||||
def symbol_lens(phi:STL) -> Lens:
|
|
||||||
is_sym = lambda x: isinstance(x, sympy.Symbol)
|
|
||||||
def focus_lens(leaf):
|
|
||||||
spacial = [lens().const] + lens().terms.each_().id.get_all()
|
|
||||||
temporal = [lens().interval[0], lens().interval[1]]
|
|
||||||
return spacial if isinstance(leaf, LinEq) else temp
|
|
||||||
|
|
||||||
return ast_lens(phi, pred=type_pred(LinEq, F, G),
|
|
||||||
focus_lens=focus_lens).filter_(is_sym)
|
|
||||||
|
|
||||||
def set_params(stl_or_lens, val) -> STL:
|
def set_params(stl_or_lens, val) -> STL:
|
||||||
l = stl_or_lens if isinstance(stl_or_lens, Lens) else param_lens(stl_or_lens)
|
l = stl_or_lens if isinstance(stl_or_lens, Lens) else param_lens(stl_or_lens)
|
||||||
return l.modify(lambda x: val.get(x, val.get(str(x), x)))
|
return l.modify(lambda x: val.get(x, val.get(str(x), x)))
|
||||||
|
|
@ -131,6 +117,13 @@ def from_mtl(phi:MTL, ap_map:Dict[AtomicPred, LinEq]) -> STL:
|
||||||
return focus.modify(ap_map.get)
|
return focus.modify(ap_map.get)
|
||||||
|
|
||||||
|
|
||||||
|
def get_polarity(phi, traces=None):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def canonical_polarity(phi, traces=None):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
# EDSL
|
# EDSL
|
||||||
|
|
||||||
def alw(phi, *, lo, hi):
|
def alw(phi, *, lo, hi):
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue