simplify parser + start removing lineq code

This commit is contained in:
Marcell Vazquez-Chanlatte 2018-09-06 01:19:06 -07:00
parent 7798fe679e
commit b9b10ac835
11 changed files with 143 additions and 247 deletions

View file

@ -72,9 +72,6 @@ class AST(object):
yield leaf.interval[0] yield leaf.interval[0]
if isinstance(leaf.interval[1], Param): if isinstance(leaf.interval[1], Param):
yield leaf.interval[1] yield leaf.interval[1]
elif isinstance(leaf, LinEq):
if isinstance(leaf.const, Param):
yield leaf.const
return set(fn.mapcat(get_params, self.walk())) return set(fn.mapcat(get_params, self.walk()))
@ -82,20 +79,10 @@ class AST(object):
phi = param_lens(self) phi = param_lens(self)
return phi.modify(lambda x: float(val.get(x, val.get(str(x), x)))) return phi.modify(lambda x: float(val.get(x, val.get(str(x), x))))
@property
def lineqs(self):
return set(lineq_lens.collect()(self))
@property @property
def atomic_predicates(self): def atomic_predicates(self):
return set(AP_lens.collect()(self)) return set(AP_lens.collect()(self))
@property
def var_names(self):
symbols = set(bind(self.lineqs).Each().terms.Each().collect())
symbols |= self.atomic_predicates
return set(bind(symbols).Each().id.collect())
def inline_context(self, context): def inline_context(self, context):
phi, phi2 = self, None phi, phi2 = self, None
@ -116,7 +103,7 @@ class _Top(AST):
__slots__ = () __slots__ = ()
def __repr__(self): def __repr__(self):
return "" return "1"
def __invert__(self): def __invert__(self):
return BOT return BOT
@ -126,7 +113,7 @@ class _Bot(AST):
__slots__ = () __slots__ = ()
def __repr__(self): def __repr__(self):
return "" return "0"
def __invert__(self): def __invert__(self):
return TOP return TOP
@ -192,7 +179,7 @@ class NaryOpSTL(namedtuple('NaryOp', ['args']), AST):
OP = "?" OP = "?"
def __repr__(self): def __repr__(self):
return f" {self.OP} ".join(f"({x})" for x in self.args) return "(" + f" {self.OP} ".join(f"{x}" for x in self.args) + ")"
@property @property
def children(self): def children(self):
@ -202,7 +189,7 @@ class NaryOpSTL(namedtuple('NaryOp', ['args']), AST):
class Or(NaryOpSTL): class Or(NaryOpSTL):
__slots__ = () __slots__ = ()
OP = "" OP = "|"
def __hash__(self): def __hash__(self):
# TODO: compute hash based on contents # TODO: compute hash based on contents
@ -212,7 +199,7 @@ class Or(NaryOpSTL):
class And(NaryOpSTL): class And(NaryOpSTL):
__slots__ = () __slots__ = ()
OP = "" OP = "&"
def __hash__(self): def __hash__(self):
# TODO: compute hash based on contents # TODO: compute hash based on contents
@ -224,7 +211,9 @@ class ModalOp(namedtuple('ModalOp', ['interval', 'arg']), AST):
OP = '?' OP = '?'
def __repr__(self): def __repr__(self):
return f"{self.OP}{self.interval}({self.arg})" if self.interval.lower == 0 and self.interval.upper == float('inf'):
return f"{self.OP}{self.arg}"
return f"{self.OP}{self.interval}{self.arg}"
@property @property
def children(self): def children(self):
@ -233,7 +222,7 @@ class ModalOp(namedtuple('ModalOp', ['interval', 'arg']), AST):
class F(ModalOp): class F(ModalOp):
__slots__ = () __slots__ = ()
OP = "" OP = "< >"
def __hash__(self): def __hash__(self):
# TODO: compute hash based on contents # TODO: compute hash based on contents
@ -242,7 +231,7 @@ class F(ModalOp):
class G(ModalOp): class G(ModalOp):
__slots__ = () __slots__ = ()
OP = "" OP = "[ ]"
def __hash__(self): def __hash__(self):
# TODO: compute hash based on contents # TODO: compute hash based on contents
@ -253,7 +242,7 @@ class Until(namedtuple('ModalOp', ['arg1', 'arg2']), AST):
__slots__ = () __slots__ = ()
def __repr__(self): def __repr__(self):
return f"({self.arg1}) U ({self.arg2})" return f"({self.arg1} U {self.arg2})"
@property @property
def children(self): def children(self):
@ -268,7 +257,7 @@ class Neg(namedtuple('Neg', ['arg']), AST):
__slots__ = () __slots__ = ()
def __repr__(self): def __repr__(self):
return f"¬({self.arg})" return f"~{self.arg}"
@property @property
def children(self): def children(self):
@ -283,7 +272,7 @@ class Next(namedtuple('Next', ['arg']), AST):
__slots__ = () __slots__ = ()
def __repr__(self): def __repr__(self):
return f"◯({self.arg})" return f"@{self.arg}"
@property @property
def children(self): def children(self):

View file

@ -25,9 +25,7 @@ def pointwise_sat(phi, dt=0.1):
ap_names = [z.id for z in phi.atomic_predicates] ap_names = [z.id for z in phi.atomic_predicates]
def _eval_stl(x, t=0): def _eval_stl(x, t=0):
evaluated = stl.utils.eval_lineqs(phi, x) evaluated = fn.project(x, ap_names)
evaluated.update(fn.project(x, ap_names))
return bool(eval_stl(phi, dt)(evaluated)[t]) return bool(eval_stl(phi, dt)(evaluated)[t])
return _eval_stl return _eval_stl
@ -107,6 +105,8 @@ def eval_stl_f(phi, dt):
def eval_stl_g(phi, dt): def eval_stl_g(phi, dt):
f = eval_stl(phi.arg, dt) f = eval_stl(phi.arg, dt)
a, b = phi.interval a, b = phi.interval
if b < a:
return lambda _: TRUE_TRACE
def process_intervals(x): def process_intervals(x):
# Need to add last interval # Need to add last interval

View file

@ -9,26 +9,11 @@ from lenses import bind
import stl.ast import stl.ast
oo = float('inf') oo = float('inf')
op_lookup = {
">": op.gt,
">=": op.ge,
"<": op.lt,
"<=": op.le,
"=": op.eq,
}
def get_times(x, tau, lo, hi):
def eval_terms(lineq, x, t):
terms = bind(lineq).terms.Each().collect()
return sum(eval_term(term, x, t) for term in terms)
def eval_term(term, x, t):
return float(term.coeff) * x[term.id][t]
def get_times(x, tau, lo=None, hi=None):
end = min(v.domain.end() for v in x.values()) end = min(v.domain.end() for v in x.values())
lo, hi = map(float, (lo, hi))
hi = hi + tau if hi + tau <= end else end hi = hi + tau if hi + tau <= end else end
lo = lo + tau if lo + tau <= end else end lo = lo + tau if lo + tau <= end else end
@ -112,10 +97,3 @@ def pointwise_satf_top(_):
@pointwise_satf.register(type(stl.BOT)) @pointwise_satf.register(type(stl.BOT))
def pointwise_satf_bot(_): def pointwise_satf_bot(_):
return lambda _, t: bitarray([False] * len(t)) return lambda _, t: bitarray([False] * len(t))
@pointwise_satf.register(stl.LinEq)
def pointwise_satf_lineq(stl):
def op(a):
return op_lookup[stl.op](a, stl.const)
return lambda x, t: bitarray(op(eval_terms(stl, x, tau)) for tau in t)

View file

@ -4,19 +4,21 @@ from hypothesis_cfg import ContextFreeGrammarStrategy
import stl import stl
GRAMMAR = { GRAMMAR = {
'phi': (('Unary', '(', 'phi', ')'), ('(', 'phi', ')', 'Binary', '(', 'phi', 'phi': (
')'), ('AP', ), ('LINEQ', ), ('', ), ('Unary', 'phi'),
('', )), ('(', 'phi', 'Binary', 'phi', ')'),
('AP', ), ('0', ), ('1', )
),
'Unary': (('~', ), ('G', 'Interval'), ('F', 'Interval'), ('X', )), 'Unary': (('~', ), ('G', 'Interval'), ('F', 'Interval'), ('X', )),
'Interval': (('', ), ('[1, 3]', )), 'Interval': (('', ), ('[1, 3]', )),
'Binary': ((' | ', ), (' & ', ), (' U ', ), (' -> ', ), (' <-> ',), 'Binary': (
(' ^ ',)), (' | ', ), (' & ', ), (' -> ', ), (' <-> ',), (' ^ ',),
'AP': (('AP1', ), ('AP2', ), ('AP3', ), ('AP4', ), ('AP5', )), (' U ',),
'LINEQ': (('x > 4', ), ('y < 2', ), ('y >= 3', ), ('x + 2.0y >= 2', )), ),
'AP': (('ap1', ), ('ap2', ), ('ap3', ), ('ap4', ), ('ap5', )),
} }
SignalTemporalLogicStrategy = st.builds(lambda term: stl.parse(''.join(term)), SignalTemporalLogicStrategy = st.builds(
ContextFreeGrammarStrategy( lambda term: stl.parse(''.join(term)),
GRAMMAR, ContextFreeGrammarStrategy(GRAMMAR, max_length=14, start='phi')
max_length=14, )
start='phi'))

View file

@ -2,58 +2,53 @@
# TODO: allow multiplication to be distributive # TODO: allow multiplication to be distributive
# TODO: support variables on both sides of ineq # TODO: support variables on both sides of ineq
import operator as op
from functools import partialmethod from functools import partialmethod, reduce
from lenses import bind from lenses import bind
from parsimonious import Grammar, NodeVisitor from parsimonious import Grammar, NodeVisitor
from stl import ast from stl import ast
from stl.utils import iff, implies, xor from stl.utils import iff, implies, xor, timed_until
STL_GRAMMAR = Grammar(u''' STL_GRAMMAR = Grammar(u'''
phi = (timed_until / until / neg / next / g / f / lineq / AP / or / and phi = (neg / paren_phi / next / bot / top
/ implies / xor / iff / paren_phi / bot / top) / xor_outer / iff_outer / implies_outer / and_outer / or_outer
/ timed_until / until / g / f / AP)
paren_phi = "(" __ phi __ ")" paren_phi = "(" __ phi __ ")"
neg = ("~" / "¬") __ phi
next = ("@" / "X") __ phi
or = paren_phi _ ("" / "or" / "|") _ (or / paren_phi) and_outer = "(" __ and_inner __ ")"
and = paren_phi _ ("" / "and" / "&") _ (and / paren_phi) and_inner = (phi __ ("" / "and" / "&") __ and_inner) / phi
implies = paren_phi _ ("" / "->") _ (and / paren_phi)
iff = paren_phi _ ("" / "<->" / "iff") _ (and / paren_phi)
xor = paren_phi _ ("" / "^" / "xor") _ (and / paren_phi)
neg = ("~" / "¬") paren_phi or_outer = "(" __ or_inner __ ")"
next = next_sym paren_phi or_inner = (phi __ ("" / "or" / "|") __ or_inner) / phi
f = F interval? phi
g = G interval? phi
until = paren_phi __ U __ paren_phi
timed_until = paren_phi __ U interval __ paren_phi
next_sym = "X" / "" implies_outer = "(" __ implies_inner __ ")"
F = "F" / "" implies_inner = (phi __ ("" / "->") __ implies_inner) / phi
G = "G" / ""
U = "U"
iff_outer = "(" __ iff_inner __ ")"
iff_inner = (phi __ ("" / "<->" / "iff") __ iff_inner) / phi
xor_outer = "(" __ xor_inner __ ")"
xor_inner = (phi __ ("" / "^" / "xor") __ xor_inner) / phi
f = ("< >" / "F") interval? __ phi
g = ("[ ]" / "G") interval? __ phi
until = "(" __ phi _ "U" _ phi __ ")"
timed_until = "(" __ phi _ "U" interval _ phi __ ")"
interval = "[" __ const_or_unbound __ "," __ const_or_unbound __ "]" interval = "[" __ const_or_unbound __ "," __ const_or_unbound __ "]"
const_or_unbound = unbound / "inf" / const const_or_unbound = const / "inf" / id
lineq = terms _ op _ const_or_unbound AP = ~r"[a-z\d]+"
term = const? var
terms = (term __ pm __ terms) / term
var = id bot = "0"
AP = ~r"[a-zA-z\d]+" top = "1"
bot = "" id = ~r"[a-z\d]+"
top = ""
pm = "+" / "-"
dt = "dt"
unbound = id "?"
id = ~r"[a-zA-z\d]+"
const = ~r"[-+]?(\d*\.\d+|\d+)" const = ~r"[-+]?(\d*\.\d+|\d+)"
op = ">=" / "<=" / "<" / ">" / "="
_ = ~r"\s"+ _ = ~r"\s"+
__ = ~r"\s"* __ = ~r"\s"*
EOL = "\\n" EOL = "\\n"
@ -67,6 +62,32 @@ class STLVisitor(NodeVisitor):
super().__init__() super().__init__()
self.default_interval = ast.Interval(0.0, H) self.default_interval = ast.Interval(0.0, H)
def binop_inner(self, _, children):
if isinstance(children[0], ast.AST):
return children
((left, _, _, _, right),) = children
return [left] + right
def binop_outer(self, _, children, *, binop):
return reduce(binop, children[2])
def visit_const_or_unbound(self, node, children):
child = children[0]
return ast.Param(child) if isinstance(child, str) else float(node.text)
visit_and_inner = binop_inner
visit_iff_inner = binop_inner
visit_implies_inner = binop_inner
visit_or_inner = binop_inner
visit_xor_inner = binop_inner
visit_and_outer = partialmethod(binop_outer, binop=op.and_)
visit_iff_outer = partialmethod(binop_outer, binop=iff)
visit_implies_outer = partialmethod(binop_outer, binop=implies)
visit_or_outer = partialmethod(binop_outer, binop=op.or_)
visit_xor_outer = partialmethod(binop_outer, binop=xor)
def generic_visit(self, _, children): def generic_visit(self, _, children):
return children return children
@ -83,80 +104,41 @@ class STLVisitor(NodeVisitor):
return ast.TOP return ast.TOP
def visit_interval(self, _, children): def visit_interval(self, _, children):
_, _, (left, ), _, _, _, (right, ), _, _ = children _, _, left, _, _, _, right, _, _ = children
left = left if left != [] else float("inf")
right = right if right != [] else float("inf")
return ast.Interval(left, right) return ast.Interval(left, right)
def get_text(self, node, _): def get_text(self, node, _):
return node.text return node.text
def visit_unbound(self, node, _):
return ast.Param(node.text)
visit_op = get_text visit_op = get_text
def unary_temp_op_visitor(self, _, children, op): def unary_temp_op_visitor(self, _, children, op):
_, i, phi = children _, i, _, phi = children
i = self.default_interval if not i else i[0] i = self.default_interval if not i else i[0]
return op(i, phi) return op(i, phi)
def binop_visitor(self, _, children, op):
phi1, _, _, _, (phi2, ) = children
argL = list(phi1.args) if isinstance(phi1, op) else [phi1]
argR = list(phi2.args) if isinstance(phi2, op) else [phi2]
return op(tuple(argL + argR))
def sugar_binop_visitor(self, _, children, op):
phi1, _, _, _, (phi2, ) = children
return op(phi1, phi2)
visit_f = partialmethod(unary_temp_op_visitor, op=ast.F) visit_f = partialmethod(unary_temp_op_visitor, op=ast.F)
visit_g = partialmethod(unary_temp_op_visitor, op=ast.G) visit_g = partialmethod(unary_temp_op_visitor, op=ast.G)
visit_or = partialmethod(binop_visitor, op=ast.Or)
visit_and = partialmethod(binop_visitor, op=ast.And)
visit_xor = partialmethod(sugar_binop_visitor, op=xor)
visit_iff = partialmethod(sugar_binop_visitor, op=iff)
visit_implies = partialmethod(sugar_binop_visitor, op=implies)
def visit_until(self, _, children): def visit_until(self, _, children):
phi1, _, _, _, phi2 = children _, _, phi1, _, _, _, phi2, _, _ = children
return ast.Until(phi1, phi2) return ast.Until(phi1, phi2)
def visit_timed_until(self, _, children):
_, _, phi1, _, _, itvl, _, phi2, _, _ = children
return timed_until(phi1, phi2, itvl.lower, itvl.upper)
def visit_id(self, name, _): def visit_id(self, name, _):
return name.text return name.text
def visit_const(self, const, children):
return float(const.text)
def visit_term(self, _, children):
coeffs, iden = children
c = coeffs[0] if coeffs else 1
return ast.Var(coeff=c, id=iden)
def visit_terms(self, _, children):
if isinstance(children[0], list):
term, _1, sgn, _2, terms = children[0]
terms = bind(terms)[0].coeff * sgn
return [term] + terms
else:
return children
def visit_lineq(self, _, children):
terms, _1, op, _2, const = children
return ast.LinEq(tuple(terms), op, const[0])
def visit_pm(self, node, _):
return 1 if node.text == "+" else -1
def visit_AP(self, *args): def visit_AP(self, *args):
return ast.AtomicPred(self.visit_id(*args)) return ast.AtomicPred(self.visit_id(*args))
def visit_neg(self, _, children): def visit_neg(self, _, children):
return ~children[1] return ~children[2]
def visit_next(self, _, children): def visit_next(self, _, children):
return ast.Next(children[1]) return ast.Next(children[2])
def parse(stl_str: str, rule: str = "phi", H=oo) -> "STL": def parse(stl_str: str, rule: str = "phi", H=oo) -> "STL":

View file

@ -24,30 +24,7 @@ def test_identities(phi):
assert (phi | phi) | phi == phi | (phi | phi) assert (phi | phi) | phi == phi | (phi | phi)
assert ~~phi == phi assert ~~phi == phi
def test_lineqs_unittest():
phi = stl.parse('(G[0, 1](x + y > a?)) & (F[1,2](z - x > 0))')
assert len(phi.lineqs) == 2
assert phi.lineqs == {stl.parse('x + y > a?'), stl.parse('z - x > 0')}
phi = stl.parse('(G[0, 1](x + y > a?)) U (F[1,2](z - x > 0))')
assert len(phi.lineqs) == 2
assert phi.lineqs == {stl.parse('x + y > a?'), stl.parse('z - x > 0')}
phi = stl.parse('G(⊥)')
assert phi.lineqs == set()
phi = stl.parse('F()')
assert phi.lineqs == set()
def test_walk(): def test_walk():
phi = stl.parse( phi = stl.parse(
'((G[0, 1](x + y > a?)) & (F[1,2](z - x > 0))) | ((X(AP1)) U (AP2))') '(([ ][0, 1] ap1 & < >[1,2] ap2) | (@ap1 U ap2))')
assert len(list((~phi).walk())) == 11 assert len(list((~phi).walk())) == 11
def test_var_names():
phi = stl.parse(
'((G[0, 1](x + y > a?)) & (F[1,2](z - x > 0))) | ((X(AP1)) U (AP2))')
assert phi.var_names == {'x', 'y', 'z', 'x', 'AP1', 'AP2'}

View file

@ -27,17 +27,19 @@ x = {
traces.TimeSeries([(0, 1), (0.1, 1), (0.2, 4)], domain=(0, 10)), traces.TimeSeries([(0, 1), (0.1, 1), (0.2, 4)], domain=(0, 10)),
"y": "y":
traces.TimeSeries([(0, 2), (0.1, 4), (0.2, 2)], domain=(0, 10)), traces.TimeSeries([(0, 2), (0.1, 4), (0.2, 2)], domain=(0, 10)),
"AP1": "ap1":
traces.TimeSeries([(0, True), (0.1, True), (0.2, False)], domain=(0, 10)), traces.TimeSeries([(0, True), (0.1, True), (0.2, False)], domain=(0, 10)),
"AP2": "ap2":
traces.TimeSeries([(0, False), (0.2, True), (0.5, False)], domain=(0, 10)), traces.TimeSeries([(0, False), (0.2, True), (0.5, False)], domain=(0, 10)),
"AP3": "ap3":
traces.TimeSeries([(0, True), (0.1, True), (0.3, False)], domain=(0, 10)), traces.TimeSeries([(0, True), (0.1, True), (0.3, False)], domain=(0, 10)),
"AP4": "ap4":
traces.TimeSeries( traces.TimeSeries(
[(0, False), (0.1, False), (0.3, False)], domain=(0, 10)), [(0, False), (0.1, False), (0.3, False)], domain=(0, 10)),
"AP5": "ap5":
traces.TimeSeries([(0, False), (0.1, False), (0.3, True)], domain=(0, 10)), traces.TimeSeries([(0, False), (0.1, False), (0.3, True)], domain=(0, 10)),
"ap6":
traces.TimeSeries([(0, True)], domain=(0, 10)),
} }
@ -47,32 +49,32 @@ def test_eval_smoke_tests(phi):
stl_eval10 = stl.boolean_eval.pointwise_sat(~stl.ast.Next(phi)) stl_eval10 = stl.boolean_eval.pointwise_sat(~stl.ast.Next(phi))
assert stl_eval9(x, 0) != stl_eval10(x, 0) assert stl_eval9(x, 0) != stl_eval10(x, 0)
phi4 = stl.parse('~(AP4)') phi4 = stl.parse('~ap4')
stl_eval11 = stl.boolean_eval.pointwise_sat(phi4) stl_eval11 = stl.boolean_eval.pointwise_sat(phi4)
assert stl_eval11(x, 0) assert stl_eval11(x, 0)
phi5 = stl.parse('G[0.1, 0.03](~(AP4))') phi5 = stl.parse('G[0.1, 0.03] ~ap4')
stl_eval12 = stl.boolean_eval.pointwise_sat(phi5) stl_eval12 = stl.boolean_eval.pointwise_sat(phi5)
assert stl_eval12(x, 0) assert stl_eval12(x, 0)
phi6 = stl.parse('G[0.1, 0.03](~(AP5))') phi6 = stl.parse('G[0.1, 0.03] ~ap5')
stl_eval13 = stl.boolean_eval.pointwise_sat(phi6) stl_eval13 = stl.boolean_eval.pointwise_sat(phi6)
assert stl_eval13(x, 0) assert stl_eval13(x, 0)
assert not stl_eval13(x, 0.4) assert stl_eval13(x, 0.4)
phi7 = stl.parse('G(~(AP4))') phi7 = stl.parse('G ~ap4')
stl_eval14 = stl.boolean_eval.pointwise_sat(phi7) stl_eval14 = stl.boolean_eval.pointwise_sat(phi7)
assert stl_eval14(x, 0) assert stl_eval14(x, 0)
phi8 = stl.parse('F(AP5)') phi8 = stl.parse('F ap5')
stl_eval15 = stl.boolean_eval.pointwise_sat(phi8) stl_eval15 = stl.boolean_eval.pointwise_sat(phi8)
assert stl_eval15(x, 0) assert stl_eval15(x, 0)
phi9 = stl.parse('(AP1) U (AP2)') phi9 = stl.parse('(ap1 U ap2)')
stl_eval16 = stl.boolean_eval.pointwise_sat(phi9) stl_eval16 = stl.boolean_eval.pointwise_sat(phi9)
assert stl_eval16(x, 0) assert stl_eval16(x, 0)
phi10 = stl.parse('(AP2) U (AP2)') phi10 = stl.parse('(ap2 U ap2)')
stl_eval17 = stl.boolean_eval.pointwise_sat(phi10) stl_eval17 = stl.boolean_eval.pointwise_sat(phi10)
assert not stl_eval17(x, 0) assert not stl_eval17(x, 0)
@ -111,7 +113,7 @@ def test_fastboolean_equiv(phi):
def test_fastboolean_smoketest(): def test_fastboolean_smoketest():
phi = stl.parse( phi = stl.parse(
'(G[0, 4](x > 0)) & ((F[2, 1](AP1)) | (AP2)) & (G[0,0](AP2))') '(((G[0, 4] ap6 & F[2, 1] ap1) | ap2) & G[0,0](ap2))')
stl_eval = stl.fastboolean_eval.pointwise_sat(phi) stl_eval = stl.fastboolean_eval.pointwise_sat(phi)
assert not stl_eval(x, 0) assert not stl_eval(x, 0)
@ -121,13 +123,5 @@ def test_fastboolean_smoketest():
def test_callable_interface(): def test_callable_interface():
phi = stl.parse( phi = stl.parse(
'(G[0, 4](x > 0)) & ((F[2, 1](AP1)) | (AP2)) & (G[0,0](AP2))') '(((G[0, 4] ap6 & F[2, 1] ap1) | ap2) & G[0,0](ap2))')
assert not phi(x, 0) assert not phi(x, 0)
def test_implicit_validity_domain_rigid():
phi = stl.parse('(G[0, a?](x > b?)) & ((F(AP1)) | (AP2))')
vals = {'a?': 3, 'b?': 20}
stl_eval = stl.pointwise_sat(phi.set_params(vals))
oracle, order = stl.utils.implicit_validity_domain(phi, x)
assert stl_eval(x, 0) == oracle([vals.get(k) for k in order])

View file

@ -7,12 +7,12 @@ from hypothesis import given
@given(st.integers(), st.integers(), st.integers()) @given(st.integers(), st.integers(), st.integers())
def test_params1(a, b, c): def test_params1(a, b, c):
phi = stl.parse('G[a?, b?](x > c?)') phi = stl.parse('G[a, b] x')
assert {x.name for x in phi.params} == {'a?', 'b?', 'c?'} assert {x.name for x in phi.params} == {'a', 'b'}
phi2 = phi.set_params({'a?': a, 'b?': b, 'c?': c}) phi2 = phi.set_params({'a': a, 'b': b})
assert phi2.params == set() assert phi2.params == set()
assert phi2 == stl.parse(f'G[{a}, {b}](x > {c})') assert phi2 == stl.parse(f'G[{a}, {b}](x)')
@given(SignalTemporalLogicStrategy) @given(SignalTemporalLogicStrategy)

View file

@ -17,6 +17,6 @@ def test_hash_inheritance(phi):
def test_sugar_smoke(): def test_sugar_smoke():
stl.parse('(x) <-> (x)') stl.parse('(x <-> x)')
stl.parse('(x) -> (x)') stl.parse('(x -> x)')
stl.parse('(x) ^ (x)') stl.parse('(x ^ x)')

View file

@ -5,11 +5,11 @@ from hypothesis import given
from pytest import raises from pytest import raises
CONTEXT = { CONTEXT = {
stl.parse('AP1'): stl.parse('F(x > 4)'), stl.parse('ap1'): stl.parse('x'),
stl.parse('AP2'): stl.parse('(AP1) U (AP1)'), stl.parse('ap2'): stl.parse('(y U z)'),
stl.parse('AP3'): stl.parse('y < 4'), stl.parse('ap3'): stl.parse('x'),
stl.parse('AP4'): stl.parse('y < 3'), stl.parse('ap4'): stl.parse('(x -> y -> z)'),
stl.parse('AP5'): stl.parse('y + x > 4'), stl.parse('ap5'): stl.parse('(ap1 <-> y <-> z)'),
} }
APS = set(CONTEXT.keys()) APS = set(CONTEXT.keys())
@ -29,13 +29,13 @@ def test_f_neg_or_canonical_form_not_implemented():
def test_inline_context_rigid(): def test_inline_context_rigid():
phi = stl.parse('G(AP1)') phi = stl.parse('G ap1')
phi2 = phi.inline_context(CONTEXT) phi2 = phi.inline_context(CONTEXT)
assert phi2 == stl.parse('G(F(x > 4))') assert phi2 == stl.parse('G x')
phi = stl.parse('G(AP2)') phi = stl.parse('G ap5')
phi2 = phi.inline_context(CONTEXT) phi2 = phi.inline_context(CONTEXT)
assert phi2 == stl.parse('G((F(x > 4)) U (F(x > 4)))') assert phi2 == stl.parse('G(x <-> y <-> z)')
@given(SignalTemporalLogicStrategy) @given(SignalTemporalLogicStrategy)
@ -44,19 +44,6 @@ def test_inline_context(phi):
assert not (APS & phi2.atomic_predicates) assert not (APS & phi2.atomic_predicates)
def test_linear_stl_lipschitz_rigid():
phi = stl.parse('(x + 3y - 4z < 3)')
assert stl.utils.linear_stl_lipschitz(phi) == (8)
@given(SignalTemporalLogicStrategy, SignalTemporalLogicStrategy)
def test_linear_stl_lipschitz(phi1, phi2):
lip1 = stl.utils.linear_stl_lipschitz(phi1)
lip2 = stl.utils.linear_stl_lipschitz(phi2)
phi3 = phi1 | phi2
assert stl.utils.linear_stl_lipschitz(phi3) == max(lip1, lip2)
@given(SignalTemporalLogicStrategy, SignalTemporalLogicStrategy) @given(SignalTemporalLogicStrategy, SignalTemporalLogicStrategy)
def test_timed_until_smoke_test(phi1, phi2): def test_timed_until_smoke_test(phi1, phi2):
stl.utils.timed_until(phi1, phi2, lo=2, hi=20) stl.utils.timed_until(phi1, phi2, lo=2, hi=20)
@ -65,34 +52,34 @@ def test_timed_until_smoke_test(phi1, phi2):
def test_discretize(): def test_discretize():
dt = 0.3 dt = 0.3
phi = stl.parse('X(AP1)') phi = stl.parse('@ ap1')
assert stl.utils.is_discretizable(phi, dt) assert stl.utils.is_discretizable(phi, dt)
phi2 = stl.utils.discretize(phi, dt) phi2 = stl.utils.discretize(phi, dt)
phi3 = stl.utils.discretize(phi2, dt) phi3 = stl.utils.discretize(phi2, dt)
assert phi2 == phi3 assert phi2 == phi3
phi = stl.parse('G[0.3, 1.2](F[0.6, 1.5](AP1))') phi = stl.parse('G[0.3, 1.2] F[0.6, 1.5] ap1')
assert stl.utils.is_discretizable(phi, dt) assert stl.utils.is_discretizable(phi, dt)
phi2 = stl.utils.discretize(phi, dt) phi2 = stl.utils.discretize(phi, dt)
phi3 = stl.utils.discretize(phi2, dt) phi3 = stl.utils.discretize(phi2, dt)
assert phi2 == phi3 assert phi2 == phi3
phi = stl.parse('G[0.3, 1.4](F[0.6, 1.5](AP1))') phi = stl.parse('G[0.3, 1.4] F[0.6, 1.5] ap1')
assert not stl.utils.is_discretizable(phi, dt) assert not stl.utils.is_discretizable(phi, dt)
phi = stl.parse('G[0.3, 1.2](F(AP1))') phi = stl.parse('G[0.3, 1.2] F ap1')
assert not stl.utils.is_discretizable(phi, dt) assert not stl.utils.is_discretizable(phi, dt)
phi = stl.parse('G[0.3, 1.2]((AP1) U (AP2))') phi = stl.parse('G[0.3, 1.2] (ap1 U ap2)')
assert not stl.utils.is_discretizable(phi, dt) assert not stl.utils.is_discretizable(phi, dt)
phi = stl.parse('G[0.3, 0.6](~(F[0, 0.3](A)))') phi = stl.parse('G[0.3, 0.6] ~F[0, 0.3] a')
assert stl.utils.is_discretizable(phi, dt) assert stl.utils.is_discretizable(phi, dt)
phi2 = stl.utils.discretize(phi, dt, distribute=True) phi2 = stl.utils.discretize(phi, dt, distribute=True)
phi3 = stl.utils.discretize(phi2, dt, distribute=True) phi3 = stl.utils.discretize(phi2, dt, distribute=True)
assert phi2 == phi3 assert phi2 == phi3
assert phi2 == stl.parse( assert phi2 == stl.parse(
'(~((X(A)) (X(X(A))))) ∧ (~((X(X(A))) (X(X(X(A))))))') '(~(@a | @@a) & ~(@@a | @@@a))')
phi = stl.TOP phi = stl.TOP
assert stl.utils.is_discretizable(phi, dt) assert stl.utils.is_discretizable(phi, dt)
@ -110,17 +97,17 @@ def test_discretize():
def test_scope(): def test_scope():
dt = 0.3 dt = 0.3
phi = stl.parse('X(AP1)') phi = stl.parse('@ap1')
assert stl.utils.scope(phi, dt) == 0.3 assert stl.utils.scope(phi, dt) == 0.3
phi = stl.parse('X((X(AP1)) | (AP2))') phi = stl.parse('(@@ap1 | ap2)')
assert stl.utils.scope(phi, dt) == 0.6 assert stl.utils.scope(phi, dt) == 0.6
phi = stl.parse('G[0.3, 1.2](F[0.6, 1.5](AP1))') phi = stl.parse('G[0.3, 1.2] F[0.6, 1.5] ap1')
assert stl.utils.scope(phi, dt) == 1.2 + 1.5 assert stl.utils.scope(phi, dt) == 1.2 + 1.5
phi = stl.parse('G[0.3, 1.2](F(AP1))') phi = stl.parse('G[0.3, 1.2] F ap1')
assert stl.utils.scope(phi, dt) == float('inf') assert stl.utils.scope(phi, dt) == float('inf')
phi = stl.parse('G[0.3, 1.2]((AP1) U (AP2))') phi = stl.parse('G[0.3, 1.2] (ap1 U ap2)')
assert stl.utils.scope(phi, dt) == float('inf') assert stl.utils.scope(phi, dt) == float('inf')

View file

@ -85,19 +85,6 @@ def eval_lineqs(phi, x):
return {lineq: eval_lineq(lineq, x, domain) for lineq in lineqs} return {lineq: eval_lineq(lineq, x, domain) for lineq in lineqs}
def implicit_validity_domain(phi, trace):
params = {ap.name for ap in phi.params}
order = tuple(params)
def vec_to_dict(theta):
return {k: v for k, v in zip(order, theta)}
def oracle(theta):
return stl.pointwise_sat(phi.set_params(vec_to_dict(theta)))(trace, 0)
return oracle, order
def require_discretizable(func): def require_discretizable(func):
@wraps(func) @wraps(func)
def _func(phi, dt, *args, **kwargs): def _func(phi, dt, *args, **kwargs):