restructing to move repo
This commit is contained in:
parent
d658907c4c
commit
4c46db7f71
15 changed files with 50 additions and 0 deletions
6
stl/__init__.py
Normal file
6
stl/__init__.py
Normal file
|
|
@ -0,0 +1,6 @@
|
|||
from stl.utils import terms_lens, lineq_lens, walk, tree, and_or_lens
|
||||
from stl.ast import dt_sym, t_sym
|
||||
from stl.ast import LinEq, Interval, NaryOpSTL, Or, And, F, G, ModalOp, Neg, Var
|
||||
from stl.parser import parse
|
||||
from stl.synth import lex_param_project
|
||||
from stl.boolean_eval import pointwise_sat
|
||||
86
stl/ast.py
Normal file
86
stl/ast.py
Normal file
|
|
@ -0,0 +1,86 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# TODO: create iso lens between sugar and non-sugar
|
||||
# TODO: supress + given a + (-b). i.e. want a - b
|
||||
|
||||
from collections import namedtuple, deque
|
||||
from itertools import repeat
|
||||
from typing import Union
|
||||
from enum import Enum
|
||||
from sympy import Symbol
|
||||
|
||||
VarKind = Enum("VarKind", ["x", "u", "w"])
|
||||
str_to_varkind = {"x": VarKind.x, "u": VarKind.u, "w": VarKind.w}
|
||||
dt_sym = Symbol('dt', positive=True)
|
||||
t_sym = Symbol('t', positive=True)
|
||||
|
||||
class LinEq(namedtuple("LinEquality", ["terms", "op", "const"])):
|
||||
def __repr__(self):
|
||||
n = len(self.terms)
|
||||
rep = "{}"
|
||||
if n > 1:
|
||||
rep += " + {}"*(n - 1)
|
||||
rep += " {op} {c}"
|
||||
return rep.format(*self.terms, op=self.op, c=self.const)
|
||||
|
||||
def children(self):
|
||||
return []
|
||||
|
||||
|
||||
class Var(namedtuple("Var", ["coeff", "id", "time"])):
|
||||
def __repr__(self):
|
||||
time_str = "[{}]".format(self.time)
|
||||
return "{c}*{i}{t}".format(c=self.coeff, i=self.id, t=time_str)
|
||||
|
||||
|
||||
class Interval(namedtuple('I', ['lower', 'upper'])):
|
||||
def __repr__(self):
|
||||
return "[{},{}]".format(self.lower, self.upper)
|
||||
|
||||
def children(self):
|
||||
return [self.lower, self.upper]
|
||||
|
||||
|
||||
class NaryOpSTL(namedtuple('NaryOp', ['args'])):
|
||||
OP = "?"
|
||||
def __repr__(self):
|
||||
n = len(self.args)
|
||||
if n == 1:
|
||||
return "{}".format(self.args[0])
|
||||
elif self.args:
|
||||
rep = " {op} ".join(["({})"]*(len(self.args)))
|
||||
return rep.format(*self.args, op=self.OP)
|
||||
else:
|
||||
return ""
|
||||
|
||||
def children(self):
|
||||
return self.args
|
||||
|
||||
|
||||
class Or(NaryOpSTL):
|
||||
OP = "∨"
|
||||
|
||||
class And(NaryOpSTL):
|
||||
OP = "∧"
|
||||
|
||||
|
||||
class ModalOp(namedtuple('ModalOp', ['interval', 'arg'])):
|
||||
def children(self):
|
||||
return [self.arg]
|
||||
|
||||
|
||||
class F(ModalOp):
|
||||
def __repr__(self):
|
||||
return "◇{}({})".format(self.interval, self.arg)
|
||||
|
||||
|
||||
class G(ModalOp):
|
||||
def __repr__(self):
|
||||
return "□{}({})".format(self.interval, self.arg)
|
||||
|
||||
|
||||
class Neg(namedtuple('Neg', ['arg'])):
|
||||
def __repr__(self):
|
||||
return "¬({})".format(self.arg)
|
||||
|
||||
def children(self):
|
||||
return [self.arg]
|
||||
68
stl/boolean_eval.py
Normal file
68
stl/boolean_eval.py
Normal file
|
|
@ -0,0 +1,68 @@
|
|||
# TODO: figure out how to deduplicate this with robustness
|
||||
# - Abstract as working on distributive lattice
|
||||
|
||||
from functools import singledispatch
|
||||
import operator as op
|
||||
|
||||
import numpy as np
|
||||
from lenses import lens
|
||||
|
||||
import stl.ast
|
||||
|
||||
@singledispatch
|
||||
def pointwise_sat(stl):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@pointwise_sat.register(stl.Or)
|
||||
def _(stl):
|
||||
return lambda x, t: any(pointwise_sat(arg)(x, t) for arg in stl.args)
|
||||
|
||||
|
||||
@pointwise_sat.register(stl.And)
|
||||
def _(stl):
|
||||
return lambda x, t: all(pointwise_sat(arg)(x, t) for arg in stl.args)
|
||||
|
||||
|
||||
@pointwise_sat.register(stl.F)
|
||||
def _(stl):
|
||||
lo, hi = stl.interval
|
||||
return lambda x, t: any((pointwise_sat(stl.arg)(x, min(t + t2, x.index[-1]))
|
||||
for t2 in x[lo:hi].index))
|
||||
|
||||
|
||||
@pointwise_sat.register(stl.G)
|
||||
def _(stl):
|
||||
lo, hi = stl.interval
|
||||
return lambda x, t: all((pointwise_sat(stl.arg)(x, min(t + t2, x.index[-1]))
|
||||
for t2 in x[lo:hi].index))
|
||||
|
||||
|
||||
@pointwise_sat.register(stl.Neg)
|
||||
def _(stl):
|
||||
return lambda x, t: not pointwise_sat(arg)(x, t)
|
||||
|
||||
|
||||
op_lookup = {
|
||||
">": op.gt,
|
||||
">=": op.ge,
|
||||
"<": op.lt,
|
||||
"<=": op.le,
|
||||
"=": op.eq,
|
||||
}
|
||||
|
||||
|
||||
@pointwise_sat.register(stl.LinEq)
|
||||
def _(stl):
|
||||
op = op_lookup[stl.op]
|
||||
return lambda x, t: op(eval_terms(stl, x, t), stl.const)
|
||||
|
||||
|
||||
def eval_terms(lineq, x, t):
|
||||
psi = lens(lineq).terms.each_().modify(eval_term(x, t))
|
||||
return sum(psi.terms)
|
||||
|
||||
|
||||
def eval_term(x, t):
|
||||
# TODO(lift interpolation much higher)
|
||||
return lambda term: term.coeff*np.interp(t, x.index, x[term.id.name])
|
||||
150
stl/parser.py
Normal file
150
stl/parser.py
Normal file
|
|
@ -0,0 +1,150 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
|
||||
# TODO: break out into seperate library
|
||||
# TODO: allow multiplication to be distributive
|
||||
# TODO: support reference specific time points
|
||||
# TODO: add Implies and Iff syntactic sugar
|
||||
# TODO: add support for parsing Until
|
||||
# TODO: support variables on both sides of ineq
|
||||
# TODO: Allow -x = -1*x
|
||||
|
||||
from functools import partialmethod
|
||||
from collections import namedtuple
|
||||
import operator as op
|
||||
|
||||
from parsimonious import Grammar, NodeVisitor
|
||||
from funcy import flatten
|
||||
from lenses import lens
|
||||
|
||||
from sympy import Symbol, Number
|
||||
|
||||
from stl import ast
|
||||
|
||||
STL_GRAMMAR = Grammar(u'''
|
||||
phi = (g / f / lineq / or / and / paren_phi)
|
||||
|
||||
paren_phi = "(" __ phi __ ")"
|
||||
|
||||
or = paren_phi _ ("∨" / "or") _ (or / paren_phi)
|
||||
and = paren_phi _ ("∧" / "and") _ (and / paren_phi)
|
||||
|
||||
f = F interval phi
|
||||
g = G interval phi
|
||||
|
||||
F = "F" / "◇"
|
||||
G = "G" / "□"
|
||||
interval = "[" __ const_or_unbound __ "," __ const_or_unbound __ "]"
|
||||
|
||||
const_or_unbound = unbound / const
|
||||
|
||||
lineq = terms _ op _ const_or_unbound
|
||||
term = coeff? var
|
||||
coeff = (dt __ "*" __)? const __ "*" __
|
||||
terms = (term __ pm __ terms) / term
|
||||
|
||||
var = id time?
|
||||
time = prime / time_index
|
||||
time_index = "[" "t" __ pm __ const "]"
|
||||
prime = "'"
|
||||
|
||||
pm = "+" / "-"
|
||||
dt = "dt"
|
||||
unbound = id "?"
|
||||
id = ~r"[a-zA-z\d]+"
|
||||
const = ~r"[\+\-]?\d*(\.\d+)?"
|
||||
op = ">=" / "<=" / "<" / ">" / "="
|
||||
_ = ~r"\s"+
|
||||
__ = ~r"\s"*
|
||||
EOL = "\\n"
|
||||
''')
|
||||
|
||||
|
||||
class STLVisitor(NodeVisitor):
|
||||
def generic_visit(self, _, children):
|
||||
return children
|
||||
|
||||
def children_getter(self, _, children, i):
|
||||
return children[i]
|
||||
|
||||
visit_phi = partialmethod(children_getter, i=0)
|
||||
visit_paren_phi = partialmethod(children_getter, i=2)
|
||||
|
||||
def visit_interval(self, _, children):
|
||||
_, _, left, _, _, _, right, _, _ = children
|
||||
return ast.Interval(left[0], right[0])
|
||||
|
||||
def get_text(self, node, _):
|
||||
return node.text
|
||||
|
||||
def visit_unbound(self, node, _):
|
||||
return Symbol(node.text)
|
||||
|
||||
visit_op = get_text
|
||||
|
||||
def unary_temp_op_visitor(self, _, children, op):
|
||||
_, interval, phi = children
|
||||
return op(interval, phi)
|
||||
|
||||
def binop_visitor(self, _, children, op):
|
||||
phi1, _, _, _, (phi2,) = children
|
||||
argL = list(phi1.args) if isinstance(phi1, op) else [phi1]
|
||||
argR = list(phi2.args) if isinstance(phi2, op) else [phi2]
|
||||
return op(tuple(argL + argR))
|
||||
|
||||
visit_f = partialmethod(unary_temp_op_visitor, op=ast.F)
|
||||
visit_g = partialmethod(unary_temp_op_visitor, op=ast.G)
|
||||
visit_or = partialmethod(binop_visitor, op=ast.Or)
|
||||
visit_and = partialmethod(binop_visitor, op=ast.And)
|
||||
|
||||
def visit_id(self, name, _):
|
||||
return Symbol(name.text)
|
||||
|
||||
def visit_var(self, _, children):
|
||||
iden, time_node = children
|
||||
|
||||
time_node = list(flatten(time_node))
|
||||
time = time_node[0] if len(time_node) > 0 else ast.t_sym
|
||||
|
||||
return iden, time
|
||||
|
||||
def visit_time_index(self, _, children):
|
||||
return children[3]* children[5]
|
||||
|
||||
def visit_prime(self, *_):
|
||||
return ast.t_sym - ast.dt_sym
|
||||
|
||||
def visit_const(self, const, children):
|
||||
return float(const.text)
|
||||
|
||||
def visit_dt(self, *_):
|
||||
return ast.dt_sym
|
||||
|
||||
def visit_term(self, _, children):
|
||||
coeffs, (iden, time) = children
|
||||
c = coeffs[0] if coeffs else Number(1)
|
||||
return ast.Var(coeff=c, id=iden, time=time)
|
||||
|
||||
|
||||
def visit_coeff(self, _, children):
|
||||
dt, coeff, *_ = children
|
||||
dt = dt[0][0] if dt else Number(1)
|
||||
return dt * coeff
|
||||
|
||||
def visit_terms(self, _, children):
|
||||
if isinstance(children[0], list):
|
||||
term, _1, sgn ,_2, terms = children[0]
|
||||
terms = lens(terms)[0].coeff * sgn
|
||||
return [term] + terms
|
||||
else:
|
||||
return children
|
||||
|
||||
def visit_lineq(self, _, children):
|
||||
terms, _1, op, _2, const = children
|
||||
return ast.LinEq(tuple(terms), op, const[0])
|
||||
|
||||
def visit_pm(self, node, _):
|
||||
return Number(1) if node.text == "+" else Number(-1)
|
||||
|
||||
|
||||
def parse(stl_str:str, rule:str="phi") -> "STL":
|
||||
return STLVisitor().visit(STL_GRAMMAR[rule].parse(stl_str))
|
||||
70
stl/robustness.py
Normal file
70
stl/robustness.py
Normal file
|
|
@ -0,0 +1,70 @@
|
|||
# TODO: technically incorrect on 0 robustness since conflates < and >
|
||||
|
||||
from functools import singledispatch
|
||||
from operator import sub, add
|
||||
|
||||
import numpy as np
|
||||
from lenses import lens
|
||||
|
||||
import stl.ast
|
||||
|
||||
oo = float('inf')
|
||||
|
||||
@singledispatch
|
||||
def pointwise_robustness(stl):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@pointwise_robustness.register(stl.Or)
|
||||
def _(stl):
|
||||
return lambda x, t: max(pointwise_robustness(arg)(x, t) for arg in stl.args)
|
||||
|
||||
|
||||
@pointwise_robustness.register(stl.And)
|
||||
def _(stl):
|
||||
return lambda x, t: min(pointwise_robustness(arg)(x, t) for arg in stl.args)
|
||||
|
||||
|
||||
@pointwise_robustness.register(stl.F)
|
||||
def _(stl):
|
||||
lo, hi = stl.interval
|
||||
return lambda x, t: max((pointwise_robustness(stl.arg)(x, t + t2)
|
||||
for t2 in x[lo:hi].index), default=-oo)
|
||||
|
||||
|
||||
@pointwise_robustness.register(stl.G)
|
||||
def _(stl):
|
||||
lo, hi = stl.interval
|
||||
return lambda x, t: min((pointwise_robustness(stl.arg)(x, t + t2)
|
||||
for t2 in x[lo:hi].index), default=oo)
|
||||
|
||||
|
||||
@pointwise_robustness.register(stl.Neg)
|
||||
def _(stl):
|
||||
return lambda x, t: -pointwise_robustness(arg)(x, t)
|
||||
|
||||
|
||||
op_lookup = {
|
||||
">": sub,
|
||||
">=": sub,
|
||||
"<": lambda x, y: sub(y, x),
|
||||
"<=": lambda x, y: sub(y, x),
|
||||
"=": lambda a, b: -abs(a - b),
|
||||
}
|
||||
|
||||
|
||||
|
||||
@pointwise_robustness.register(stl.LinEq)
|
||||
def _(stl):
|
||||
op = op_lookup[stl.op]
|
||||
return lambda x, t: op(eval_terms(stl, x, t), stl.const)
|
||||
|
||||
|
||||
def eval_terms(lineq, x, t):
|
||||
psi = lens(lineq).terms.each_().modify(eval_term(x, t))
|
||||
return sum(psi.terms)
|
||||
|
||||
|
||||
def eval_term(x, t):
|
||||
# TODO(lift interpolation much higher)
|
||||
return lambda term: term.coeff*np.interp(t, x.index, x[term.id.name])
|
||||
38
stl/synth.py
Normal file
38
stl/synth.py
Normal file
|
|
@ -0,0 +1,38 @@
|
|||
import operator as op
|
||||
|
||||
from stl.utils import set_params, param_lens
|
||||
from stl.boolean_eval import pointwise_sat
|
||||
|
||||
from lenses import lens
|
||||
|
||||
def binsearch(stleval, *, tol=1e-3, lo, hi, polarity):
|
||||
"""Only run search if tightest robustness was positive."""
|
||||
# Early termination via bounds checks
|
||||
if polarity and stleval(lo):
|
||||
return lo
|
||||
elif not polarity and stleval(hi):
|
||||
return hi
|
||||
|
||||
while hi - lo > tol:
|
||||
mid = lo + (hi - lo) / 2
|
||||
r = stleval(mid)
|
||||
lo, hi = (mid, hi) if r ^ polarity else (lo, mid)
|
||||
|
||||
# Want satisifiable formula
|
||||
return hi if polarity else lo
|
||||
|
||||
|
||||
def lex_param_project(stl, x, *, order, polarity, ranges, tol=1e-3):
|
||||
val = {var: (ranges[var][0] if not polarity[var] else ranges[var][1]) for var in order}
|
||||
p_lens = param_lens(stl)
|
||||
def stleval_fact(var, val):
|
||||
l = lens(val)[var]
|
||||
return lambda p: pointwise_sat(set_params(stl, l.set(p)))(x, 0)
|
||||
|
||||
for var in order:
|
||||
stleval = stleval_fact(var, val)
|
||||
lo, hi = ranges[var]
|
||||
param = binsearch(stleval, lo=lo, hi=hi, tol=tol, polarity=polarity[var])
|
||||
val[var] = param
|
||||
|
||||
return val
|
||||
22
stl/test_boolean_eval.py
Normal file
22
stl/test_boolean_eval.py
Normal file
|
|
@ -0,0 +1,22 @@
|
|||
import stl
|
||||
import stl.boolean_eval
|
||||
import pandas as pd
|
||||
from nose2.tools import params
|
||||
import unittest
|
||||
from sympy import Symbol
|
||||
|
||||
ex1 = ("2*A > 3", False)
|
||||
ex2 = ("F[0, 1](2*A > 3)", True)
|
||||
ex3 = ("F[1, 0](2*A > 3)", False)
|
||||
ex4 = ("G[1, 0](2*A > 3)", True)
|
||||
ex5 = ("(A < 0)", False)
|
||||
ex6 = ("G[0, 0.1](A < 0)", False)
|
||||
x = pd.DataFrame([[1,2], [1,4], [4,2]], index=[0,0.1,0.2],
|
||||
columns=["A", "B"])
|
||||
|
||||
class TestSTLRobustness(unittest.TestCase):
|
||||
@params(ex1, ex2, ex3, ex4, ex5, ex6)
|
||||
def test_stl(self, phi_str, r):
|
||||
phi = stl.parse(phi_str)
|
||||
stl_eval = stl.boolean_eval.pointwise_sat(phi)
|
||||
self.assertEqual(stl_eval(x, 0), r)
|
||||
31
stl/test_parser.py
Normal file
31
stl/test_parser.py
Normal file
|
|
@ -0,0 +1,31 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
import stl
|
||||
from nose2.tools import params
|
||||
import unittest
|
||||
from sympy import Symbol
|
||||
|
||||
ex1 = ('x1 > 2', stl.LinEq(
|
||||
(stl.Var(1, Symbol("x1"), stl.ast.t_sym),),
|
||||
">",
|
||||
2.0
|
||||
))
|
||||
ex1_ = ('x1 > a?', stl.LinEq(
|
||||
(stl.Var(1, Symbol("x1"), stl.ast.t_sym),),
|
||||
">",
|
||||
Symbol("a?")
|
||||
))
|
||||
|
||||
i1 = stl.Interval(0., 1.)
|
||||
i1_ = stl.Interval(0., Symbol("b?"))
|
||||
i2 = stl.Interval(2., 3.)
|
||||
ex2 = ('◇[0,1](x1 > 2)', stl.F(i1, ex1[1]))
|
||||
ex3 = ('□[2,3]◇[0,1](x1 > 2)', stl.G(i2, ex2[1]))
|
||||
ex4 = ('(x1 > 2) or ((x1 > 2) or (x1 > 2))',
|
||||
stl.Or((ex1[1], ex1[1], ex1[1])))
|
||||
ex5 = ('G[0, b?](x1 > a?)',
|
||||
stl.G(i1_, ex1_))
|
||||
|
||||
class TestSTLParser(unittest.TestCase):
|
||||
@params(ex1, ex2, ex3, ex4)
|
||||
def test_stl(self, phi_str, phi):
|
||||
self.assertEqual(stl.parse(phi_str), phi)
|
||||
25
stl/test_robustness.py
Normal file
25
stl/test_robustness.py
Normal file
|
|
@ -0,0 +1,25 @@
|
|||
import stl
|
||||
import stl.robustness
|
||||
import pandas as pd
|
||||
from nose2.tools import params
|
||||
import unittest
|
||||
from sympy import Symbol
|
||||
|
||||
oo = float('inf')
|
||||
|
||||
ex1 = ("2*A > 3", -1)
|
||||
ex2 = ("F[0, 1](2*A > 3)", 5)
|
||||
ex3 = ("F[1, 0](2*A > 3)", -oo)
|
||||
ex4 = ("G[1, 0](2*A > 3)", oo)
|
||||
ex5 = ("(A < 0)", -1)
|
||||
ex6 = ("G[0, 0.1](A < 0)", -1)
|
||||
x = pd.DataFrame([[1,2], [1,4], [4,2]], index=[0,0.1,0.2],
|
||||
columns=["A", "B"])
|
||||
|
||||
|
||||
class TestSTLRobustness(unittest.TestCase):
|
||||
@params(ex1, ex2, ex3, ex4, ex5, ex6)
|
||||
def test_stl(self, phi_str, r):
|
||||
phi = stl.parse(phi_str)
|
||||
stl_eval = stl.robustness.pointwise_robustness(phi)
|
||||
self.assertEqual(stl_eval(x, 0), r)
|
||||
45
stl/test_synth.py
Normal file
45
stl/test_synth.py
Normal file
|
|
@ -0,0 +1,45 @@
|
|||
import stl
|
||||
import stl.robustness
|
||||
import stl.synth
|
||||
import pandas as pd
|
||||
from nose2.tools import params
|
||||
import unittest
|
||||
from sympy import Symbol
|
||||
|
||||
oo = float('inf')
|
||||
|
||||
ex1 = ("A > a?", ("a?",), {"a?": (0, 10)}, {"a?": False}, {"a?": 1})
|
||||
ex2 = ("F[0, b?](A > a?)", ("a?", "b?"), {"a?": (0, 10), "b?": (0, 5)},
|
||||
{"a?": False, "b?": True}, {"a?": 4, "b?": 0.2})
|
||||
ex3 = ("F[0, b?](A < 0)", ("b?",), {"b?": (0, 5)},
|
||||
{"b?": True}, {"b?": 5})
|
||||
ex4 = ("G[0, b?](A < 0)", ("b?",), {"b?": (0.1, 5)},
|
||||
{"b?": False}, {"b?": 0.1})
|
||||
ex5 = ("F[0, b?](A > 0)", ("b?",), {"b?": (0.1, 5)},
|
||||
{"b?": True}, {"b?": 0.1})
|
||||
ex6 = ("(A > a?) or (A > b?)", ("a?", "b?",), {"a?": (0, 2), "b?": (0, 2)},
|
||||
{"a?": False, "b?": False}, {"a?": 2, "b?": 1})
|
||||
|
||||
x = pd.DataFrame([[1,2], [1,4], [4,2]], index=[0,0.1,0.2],
|
||||
columns=["A", "B"])
|
||||
|
||||
|
||||
class TestSTLRobustness(unittest.TestCase):
|
||||
@params(ex1, ex2, ex3, ex4, ex5, ex6)
|
||||
def test_lex_synth(self, phi_str, order, ranges, polarity, val):
|
||||
phi = stl.parse(phi_str)
|
||||
val2 = stl.synth.lex_param_project(
|
||||
phi, x, order=order, ranges=ranges, polarity=polarity)
|
||||
|
||||
phi2 = stl.utils.set_params(phi, val2)
|
||||
phi3 = stl.utils.set_params(phi, val)
|
||||
|
||||
stl_eval = stl.robustness.pointwise_robustness(phi2)
|
||||
stl_eval2 = stl.robustness.pointwise_robustness(phi3)
|
||||
|
||||
# check that the robustnesses are almost the same
|
||||
self.assertAlmostEqual(stl_eval(x, 0), stl_eval2(x, 0), delta=0.01)
|
||||
|
||||
# check that the valuations are almost the same
|
||||
for var in order:
|
||||
self.assertAlmostEqual(val2[var], val[var], delta=0.01)
|
||||
30
stl/test_utils.py
Normal file
30
stl/test_utils.py
Normal file
|
|
@ -0,0 +1,30 @@
|
|||
import stl
|
||||
import stl.utils
|
||||
import pandas as pd
|
||||
from nose2.tools import params
|
||||
import unittest
|
||||
from sympy import Symbol
|
||||
|
||||
ex1 = ("F[b?, 1]G[0, c?](x > a?)", {"a?", "b?", "c?"})
|
||||
ex2 = ("G[0, c?](x > a?)", {"a?", "c?"})
|
||||
ex3 = ("F[b?, 1]G[0, c?](x > a?)", {"a?", "b?", "c?"})
|
||||
|
||||
ex4 = ("F[b?, 1]G[0, c?](x > a?)", "F[2, 1]G[0, 3](x > 1)")
|
||||
ex5 = ("G[0, c?](x > a?)", "G[0, 3](x > 1)")
|
||||
|
||||
val = {"a?": 1.0, "b?": 2.0, "c?": 3.0}
|
||||
|
||||
class TestSTLUtils(unittest.TestCase):
|
||||
@params(ex1, ex2, ex3)
|
||||
def test_param_lens(self, phi_str, params):
|
||||
phi = stl.parse(phi_str)
|
||||
self.assertEqual(set(map(str, stl.utils.param_lens(phi).get_all())), params)
|
||||
|
||||
@params(ex4, ex5)
|
||||
def test_set_params(self, phi_str, phi2_str):
|
||||
phi = stl.parse(phi_str)
|
||||
phi2 = stl.parse(phi2_str)
|
||||
phi = stl.utils.set_params(phi, val)
|
||||
|
||||
self.assertEqual(set(map(str, stl.utils.param_lens(phi).get_all())), set())
|
||||
self.assertEqual(phi, phi2)
|
||||
76
stl/utils.py
Normal file
76
stl/utils.py
Normal file
|
|
@ -0,0 +1,76 @@
|
|||
from collections import deque
|
||||
|
||||
from lenses import lens, Lens
|
||||
import funcy as fn
|
||||
import sympy
|
||||
|
||||
from stl.ast import LinEq, And, Or, NaryOpSTL, F, G, Interval
|
||||
|
||||
def walk(stl, bfs=False):
|
||||
"""Walks Ast. Defaults to DFS unless BFS flag is set."""
|
||||
pop = deque.popleft if bfs else deque.pop
|
||||
children = deque([stl])
|
||||
while len(children) != 0:
|
||||
node = pop(children)
|
||||
yield node
|
||||
children.extend(node.children())
|
||||
|
||||
|
||||
def tree(stl):
|
||||
return {x:set(x.children()) for x in walk(stl) if x.children()}
|
||||
|
||||
|
||||
def type_pred(*args):
|
||||
ast_types = set(args)
|
||||
return lambda x: type(x) in ast_types
|
||||
|
||||
|
||||
def _child_lens(psi, focus):
|
||||
if psi is None:
|
||||
return
|
||||
if isinstance(psi, NaryOpSTL):
|
||||
for j, _ in enumerate(psi.args):
|
||||
yield focus.args[j]
|
||||
else:
|
||||
yield focus.arg
|
||||
|
||||
|
||||
def ast_lens(phi:"STL", bind=True, *, pred, focus_lens=None) -> lens:
|
||||
if focus_lens is None:
|
||||
focus_lens = lambda x: [lens()]
|
||||
tls = list(fn.flatten(_ast_lens(phi, pred=pred, focus_lens=focus_lens)))
|
||||
tl = lens().tuple_(*tls).each_()
|
||||
return tl.bind(phi) if bind else tl
|
||||
|
||||
|
||||
def _ast_lens(phi, *, pred, focus=lens(), focus_lens):
|
||||
psi = focus.get(state=phi)
|
||||
ret_lens = [focus.add_lens(l) for l in focus_lens(psi)] if pred(psi) else []
|
||||
|
||||
if isinstance(psi, LinEq):
|
||||
return ret_lens
|
||||
|
||||
child_lenses = list(_child_lens(psi, focus=focus))
|
||||
ret_lens += [_ast_lens(phi, pred=pred, focus=cl, focus_lens=focus_lens)
|
||||
for cl in child_lenses]
|
||||
return ret_lens
|
||||
|
||||
|
||||
lineq_lens = fn.partial(ast_lens, pred=type_pred(LinEq))
|
||||
and_or_lens = fn.partial(ast_lens, pred=type_pred(And, Or))
|
||||
|
||||
def terms_lens(phi:"STL", bind=True) -> lens:
|
||||
return lineq_lens(phi, bind).terms.each_()
|
||||
|
||||
|
||||
def param_lens(phi):
|
||||
is_sym = lambda x: isinstance(x, sympy.Symbol)
|
||||
def focus_lens(leaf):
|
||||
return [lens().const] if isinstance(leaf, LinEq) else [lens().interval[0], lens().interval[1]]
|
||||
|
||||
return ast_lens(phi, pred=type_pred(LinEq, F, G), focus_lens=focus_lens).filter_(is_sym)
|
||||
|
||||
|
||||
def set_params(stl_or_lens, val):
|
||||
l = stl_or_lens if isinstance(stl_or_lens, Lens) else param_lens(stl_or_lens)
|
||||
return l.modify(lambda x: val[str(x)] if str(x) in val else x)
|
||||
Loading…
Add table
Add a link
Reference in a new issue