diff --git a/stl/utils.py b/stl/utils.py index 88b0395..3a2b938 100644 --- a/stl/utils.py +++ b/stl/utils.py @@ -48,7 +48,7 @@ def _ast_lens(phi, *, pred, focus=lens(), focus_lens): psi = focus.get(state=phi) ret_lens = [focus.add_lens(l) for l in focus_lens(psi)] if pred(psi) else [] - if isinstance(psi, LinEq): + if isinstance(psi, (LinEq, stl.ast.AtomicPred)): return ret_lens child_lenses = list(_child_lens(psi, focus=focus)) @@ -58,6 +58,7 @@ def _ast_lens(phi, *, pred, focus=lens(), focus_lens): lineq_lens = fn.partial(ast_lens, pred=type_pred(LinEq)) +AP_lens = fn.partial(ast_lens, pred=type_pred(stl.ast.AtomicPred)) and_or_lens = fn.partial(ast_lens, pred=type_pred(And, Or)) def terms_lens(phi:"STL", bind=True) -> lens: @@ -117,4 +118,9 @@ def to_mtl(phi): ap_map = {to_ap(i): leq for i, leq in enumerate(focus.get_all())} lineq_map = {v:k for k,v in ap_map.items()} return focus.modify(lineq_map.get), ap_map + + +def from_mtl(phi, ap_map): + focus = lineq_lens(phi) + return focus.modify(ap_map.get)