diff --git a/stl/__init__.py b/stl/__init__.py index 8d25cd4..b4b073b 100644 --- a/stl/__init__.py +++ b/stl/__init__.py @@ -1,6 +1,6 @@ from stl.utils import terms_lens, lineq_lens, walk, tree, and_or_lens from stl.ast import dt_sym, t_sym -from stl.ast import LinEq, Interval, NaryOpSTL, Or, And, F, G, ModalOp, Neg, Var +from stl.ast import LinEq, Interval, NaryOpSTL, Or, And, F, G, ModalOp, Neg, Var, AtomicPred from stl.parser import parse from stl.synth import lex_param_project from stl.boolean_eval import pointwise_sat diff --git a/stl/ast.py b/stl/ast.py index 7a424ac..0e80291 100644 --- a/stl/ast.py +++ b/stl/ast.py @@ -13,6 +13,14 @@ str_to_varkind = {"x": VarKind.x, "u": VarKind.u, "w": VarKind.w} dt_sym = Symbol('dt', positive=True) t_sym = Symbol('t', positive=True) +class AtomicPred(namedtuple("AP", ["id"])): + def __repr__(self): + return "{}".format(self.id) + + def children(self): + return [] + + class LinEq(namedtuple("LinEquality", ["terms", "op", "const"])): def __repr__(self): n = len(self.terms) diff --git a/stl/boolean_eval.py b/stl/boolean_eval.py index 2912f7f..f1f0772 100644 --- a/stl/boolean_eval.py +++ b/stl/boolean_eval.py @@ -5,6 +5,7 @@ from functools import singledispatch import operator as op import numpy as np +import sympy as smp from lenses import lens import stl.ast @@ -52,6 +53,11 @@ op_lookup = { } +@pointwise_sat.register(stl.AtomicPred) +def _(stl): + return lambda x, t: x[term.id][t] + + @pointwise_sat.register(stl.LinEq) def _(stl): op = op_lookup[stl.op] diff --git a/stl/parser.py b/stl/parser.py index 8d3bc75..81ccd51 100644 --- a/stl/parser.py +++ b/stl/parser.py @@ -21,7 +21,7 @@ from sympy import Symbol, Number from stl import ast STL_GRAMMAR = Grammar(u''' -phi = (g / f / lineq / or / and / paren_phi) +phi = (g / f / lineq / AP / or / and / paren_phi) paren_phi = "(" __ phi __ ")" @@ -47,6 +47,8 @@ time = prime / time_index time_index = "[" "t" __ pm __ const "]" prime = "'" +AP = ~r"[a-zA-z\d]+" + pm = "+" / "-" dt = "dt" unbound = id "?" @@ -145,6 +147,9 @@ class STLVisitor(NodeVisitor): def visit_pm(self, node, _): return Number(1) if node.text == "+" else Number(-1) + def visit_AP(self, node, _): + return ast.AtomicPred(node.text) + def parse(stl_str:str, rule:str="phi") -> "STL": return STLVisitor().visit(STL_GRAMMAR[rule].parse(stl_str))