restructing to move repo

This commit is contained in:
Marcell Vazquez-Chanlatte 2016-11-02 17:54:47 -07:00
parent d658907c4c
commit 4c46db7f71
15 changed files with 50 additions and 0 deletions

6
stl/__init__.py Normal file
View file

@ -0,0 +1,6 @@
from stl.utils import terms_lens, lineq_lens, walk, tree, and_or_lens
from stl.ast import dt_sym, t_sym
from stl.ast import LinEq, Interval, NaryOpSTL, Or, And, F, G, ModalOp, Neg, Var
from stl.parser import parse
from stl.synth import lex_param_project
from stl.boolean_eval import pointwise_sat

86
stl/ast.py Normal file
View file

@ -0,0 +1,86 @@
# -*- coding: utf-8 -*-
# TODO: create iso lens between sugar and non-sugar
# TODO: supress + given a + (-b). i.e. want a - b
from collections import namedtuple, deque
from itertools import repeat
from typing import Union
from enum import Enum
from sympy import Symbol
VarKind = Enum("VarKind", ["x", "u", "w"])
str_to_varkind = {"x": VarKind.x, "u": VarKind.u, "w": VarKind.w}
dt_sym = Symbol('dt', positive=True)
t_sym = Symbol('t', positive=True)
class LinEq(namedtuple("LinEquality", ["terms", "op", "const"])):
def __repr__(self):
n = len(self.terms)
rep = "{}"
if n > 1:
rep += " + {}"*(n - 1)
rep += " {op} {c}"
return rep.format(*self.terms, op=self.op, c=self.const)
def children(self):
return []
class Var(namedtuple("Var", ["coeff", "id", "time"])):
def __repr__(self):
time_str = "[{}]".format(self.time)
return "{c}*{i}{t}".format(c=self.coeff, i=self.id, t=time_str)
class Interval(namedtuple('I', ['lower', 'upper'])):
def __repr__(self):
return "[{},{}]".format(self.lower, self.upper)
def children(self):
return [self.lower, self.upper]
class NaryOpSTL(namedtuple('NaryOp', ['args'])):
OP = "?"
def __repr__(self):
n = len(self.args)
if n == 1:
return "{}".format(self.args[0])
elif self.args:
rep = " {op} ".join(["({})"]*(len(self.args)))
return rep.format(*self.args, op=self.OP)
else:
return ""
def children(self):
return self.args
class Or(NaryOpSTL):
OP = ""
class And(NaryOpSTL):
OP = ""
class ModalOp(namedtuple('ModalOp', ['interval', 'arg'])):
def children(self):
return [self.arg]
class F(ModalOp):
def __repr__(self):
return "{}({})".format(self.interval, self.arg)
class G(ModalOp):
def __repr__(self):
return "{}({})".format(self.interval, self.arg)
class Neg(namedtuple('Neg', ['arg'])):
def __repr__(self):
return "¬({})".format(self.arg)
def children(self):
return [self.arg]

68
stl/boolean_eval.py Normal file
View file

@ -0,0 +1,68 @@
# TODO: figure out how to deduplicate this with robustness
# - Abstract as working on distributive lattice
from functools import singledispatch
import operator as op
import numpy as np
from lenses import lens
import stl.ast
@singledispatch
def pointwise_sat(stl):
raise NotImplementedError
@pointwise_sat.register(stl.Or)
def _(stl):
return lambda x, t: any(pointwise_sat(arg)(x, t) for arg in stl.args)
@pointwise_sat.register(stl.And)
def _(stl):
return lambda x, t: all(pointwise_sat(arg)(x, t) for arg in stl.args)
@pointwise_sat.register(stl.F)
def _(stl):
lo, hi = stl.interval
return lambda x, t: any((pointwise_sat(stl.arg)(x, min(t + t2, x.index[-1]))
for t2 in x[lo:hi].index))
@pointwise_sat.register(stl.G)
def _(stl):
lo, hi = stl.interval
return lambda x, t: all((pointwise_sat(stl.arg)(x, min(t + t2, x.index[-1]))
for t2 in x[lo:hi].index))
@pointwise_sat.register(stl.Neg)
def _(stl):
return lambda x, t: not pointwise_sat(arg)(x, t)
op_lookup = {
">": op.gt,
">=": op.ge,
"<": op.lt,
"<=": op.le,
"=": op.eq,
}
@pointwise_sat.register(stl.LinEq)
def _(stl):
op = op_lookup[stl.op]
return lambda x, t: op(eval_terms(stl, x, t), stl.const)
def eval_terms(lineq, x, t):
psi = lens(lineq).terms.each_().modify(eval_term(x, t))
return sum(psi.terms)
def eval_term(x, t):
# TODO(lift interpolation much higher)
return lambda term: term.coeff*np.interp(t, x.index, x[term.id.name])

150
stl/parser.py Normal file
View file

@ -0,0 +1,150 @@
# -*- coding: utf-8 -*-
# TODO: break out into seperate library
# TODO: allow multiplication to be distributive
# TODO: support reference specific time points
# TODO: add Implies and Iff syntactic sugar
# TODO: add support for parsing Until
# TODO: support variables on both sides of ineq
# TODO: Allow -x = -1*x
from functools import partialmethod
from collections import namedtuple
import operator as op
from parsimonious import Grammar, NodeVisitor
from funcy import flatten
from lenses import lens
from sympy import Symbol, Number
from stl import ast
STL_GRAMMAR = Grammar(u'''
phi = (g / f / lineq / or / and / paren_phi)
paren_phi = "(" __ phi __ ")"
or = paren_phi _ ("" / "or") _ (or / paren_phi)
and = paren_phi _ ("" / "and") _ (and / paren_phi)
f = F interval phi
g = G interval phi
F = "F" / ""
G = "G" / ""
interval = "[" __ const_or_unbound __ "," __ const_or_unbound __ "]"
const_or_unbound = unbound / const
lineq = terms _ op _ const_or_unbound
term = coeff? var
coeff = (dt __ "*" __)? const __ "*" __
terms = (term __ pm __ terms) / term
var = id time?
time = prime / time_index
time_index = "[" "t" __ pm __ const "]"
prime = "'"
pm = "+" / "-"
dt = "dt"
unbound = id "?"
id = ~r"[a-zA-z\d]+"
const = ~r"[\+\-]?\d*(\.\d+)?"
op = ">=" / "<=" / "<" / ">" / "="
_ = ~r"\s"+
__ = ~r"\s"*
EOL = "\\n"
''')
class STLVisitor(NodeVisitor):
def generic_visit(self, _, children):
return children
def children_getter(self, _, children, i):
return children[i]
visit_phi = partialmethod(children_getter, i=0)
visit_paren_phi = partialmethod(children_getter, i=2)
def visit_interval(self, _, children):
_, _, left, _, _, _, right, _, _ = children
return ast.Interval(left[0], right[0])
def get_text(self, node, _):
return node.text
def visit_unbound(self, node, _):
return Symbol(node.text)
visit_op = get_text
def unary_temp_op_visitor(self, _, children, op):
_, interval, phi = children
return op(interval, phi)
def binop_visitor(self, _, children, op):
phi1, _, _, _, (phi2,) = children
argL = list(phi1.args) if isinstance(phi1, op) else [phi1]
argR = list(phi2.args) if isinstance(phi2, op) else [phi2]
return op(tuple(argL + argR))
visit_f = partialmethod(unary_temp_op_visitor, op=ast.F)
visit_g = partialmethod(unary_temp_op_visitor, op=ast.G)
visit_or = partialmethod(binop_visitor, op=ast.Or)
visit_and = partialmethod(binop_visitor, op=ast.And)
def visit_id(self, name, _):
return Symbol(name.text)
def visit_var(self, _, children):
iden, time_node = children
time_node = list(flatten(time_node))
time = time_node[0] if len(time_node) > 0 else ast.t_sym
return iden, time
def visit_time_index(self, _, children):
return children[3]* children[5]
def visit_prime(self, *_):
return ast.t_sym - ast.dt_sym
def visit_const(self, const, children):
return float(const.text)
def visit_dt(self, *_):
return ast.dt_sym
def visit_term(self, _, children):
coeffs, (iden, time) = children
c = coeffs[0] if coeffs else Number(1)
return ast.Var(coeff=c, id=iden, time=time)
def visit_coeff(self, _, children):
dt, coeff, *_ = children
dt = dt[0][0] if dt else Number(1)
return dt * coeff
def visit_terms(self, _, children):
if isinstance(children[0], list):
term, _1, sgn ,_2, terms = children[0]
terms = lens(terms)[0].coeff * sgn
return [term] + terms
else:
return children
def visit_lineq(self, _, children):
terms, _1, op, _2, const = children
return ast.LinEq(tuple(terms), op, const[0])
def visit_pm(self, node, _):
return Number(1) if node.text == "+" else Number(-1)
def parse(stl_str:str, rule:str="phi") -> "STL":
return STLVisitor().visit(STL_GRAMMAR[rule].parse(stl_str))

70
stl/robustness.py Normal file
View file

@ -0,0 +1,70 @@
# TODO: technically incorrect on 0 robustness since conflates < and >
from functools import singledispatch
from operator import sub, add
import numpy as np
from lenses import lens
import stl.ast
oo = float('inf')
@singledispatch
def pointwise_robustness(stl):
raise NotImplementedError
@pointwise_robustness.register(stl.Or)
def _(stl):
return lambda x, t: max(pointwise_robustness(arg)(x, t) for arg in stl.args)
@pointwise_robustness.register(stl.And)
def _(stl):
return lambda x, t: min(pointwise_robustness(arg)(x, t) for arg in stl.args)
@pointwise_robustness.register(stl.F)
def _(stl):
lo, hi = stl.interval
return lambda x, t: max((pointwise_robustness(stl.arg)(x, t + t2)
for t2 in x[lo:hi].index), default=-oo)
@pointwise_robustness.register(stl.G)
def _(stl):
lo, hi = stl.interval
return lambda x, t: min((pointwise_robustness(stl.arg)(x, t + t2)
for t2 in x[lo:hi].index), default=oo)
@pointwise_robustness.register(stl.Neg)
def _(stl):
return lambda x, t: -pointwise_robustness(arg)(x, t)
op_lookup = {
">": sub,
">=": sub,
"<": lambda x, y: sub(y, x),
"<=": lambda x, y: sub(y, x),
"=": lambda a, b: -abs(a - b),
}
@pointwise_robustness.register(stl.LinEq)
def _(stl):
op = op_lookup[stl.op]
return lambda x, t: op(eval_terms(stl, x, t), stl.const)
def eval_terms(lineq, x, t):
psi = lens(lineq).terms.each_().modify(eval_term(x, t))
return sum(psi.terms)
def eval_term(x, t):
# TODO(lift interpolation much higher)
return lambda term: term.coeff*np.interp(t, x.index, x[term.id.name])

38
stl/synth.py Normal file
View file

@ -0,0 +1,38 @@
import operator as op
from stl.utils import set_params, param_lens
from stl.boolean_eval import pointwise_sat
from lenses import lens
def binsearch(stleval, *, tol=1e-3, lo, hi, polarity):
"""Only run search if tightest robustness was positive."""
# Early termination via bounds checks
if polarity and stleval(lo):
return lo
elif not polarity and stleval(hi):
return hi
while hi - lo > tol:
mid = lo + (hi - lo) / 2
r = stleval(mid)
lo, hi = (mid, hi) if r ^ polarity else (lo, mid)
# Want satisifiable formula
return hi if polarity else lo
def lex_param_project(stl, x, *, order, polarity, ranges, tol=1e-3):
val = {var: (ranges[var][0] if not polarity[var] else ranges[var][1]) for var in order}
p_lens = param_lens(stl)
def stleval_fact(var, val):
l = lens(val)[var]
return lambda p: pointwise_sat(set_params(stl, l.set(p)))(x, 0)
for var in order:
stleval = stleval_fact(var, val)
lo, hi = ranges[var]
param = binsearch(stleval, lo=lo, hi=hi, tol=tol, polarity=polarity[var])
val[var] = param
return val

22
stl/test_boolean_eval.py Normal file
View file

@ -0,0 +1,22 @@
import stl
import stl.boolean_eval
import pandas as pd
from nose2.tools import params
import unittest
from sympy import Symbol
ex1 = ("2*A > 3", False)
ex2 = ("F[0, 1](2*A > 3)", True)
ex3 = ("F[1, 0](2*A > 3)", False)
ex4 = ("G[1, 0](2*A > 3)", True)
ex5 = ("(A < 0)", False)
ex6 = ("G[0, 0.1](A < 0)", False)
x = pd.DataFrame([[1,2], [1,4], [4,2]], index=[0,0.1,0.2],
columns=["A", "B"])
class TestSTLRobustness(unittest.TestCase):
@params(ex1, ex2, ex3, ex4, ex5, ex6)
def test_stl(self, phi_str, r):
phi = stl.parse(phi_str)
stl_eval = stl.boolean_eval.pointwise_sat(phi)
self.assertEqual(stl_eval(x, 0), r)

31
stl/test_parser.py Normal file
View file

@ -0,0 +1,31 @@
# -*- coding: utf-8 -*-
import stl
from nose2.tools import params
import unittest
from sympy import Symbol
ex1 = ('x1 > 2', stl.LinEq(
(stl.Var(1, Symbol("x1"), stl.ast.t_sym),),
">",
2.0
))
ex1_ = ('x1 > a?', stl.LinEq(
(stl.Var(1, Symbol("x1"), stl.ast.t_sym),),
">",
Symbol("a?")
))
i1 = stl.Interval(0., 1.)
i1_ = stl.Interval(0., Symbol("b?"))
i2 = stl.Interval(2., 3.)
ex2 = ('◇[0,1](x1 > 2)', stl.F(i1, ex1[1]))
ex3 = ('□[2,3]◇[0,1](x1 > 2)', stl.G(i2, ex2[1]))
ex4 = ('(x1 > 2) or ((x1 > 2) or (x1 > 2))',
stl.Or((ex1[1], ex1[1], ex1[1])))
ex5 = ('G[0, b?](x1 > a?)',
stl.G(i1_, ex1_))
class TestSTLParser(unittest.TestCase):
@params(ex1, ex2, ex3, ex4)
def test_stl(self, phi_str, phi):
self.assertEqual(stl.parse(phi_str), phi)

25
stl/test_robustness.py Normal file
View file

@ -0,0 +1,25 @@
import stl
import stl.robustness
import pandas as pd
from nose2.tools import params
import unittest
from sympy import Symbol
oo = float('inf')
ex1 = ("2*A > 3", -1)
ex2 = ("F[0, 1](2*A > 3)", 5)
ex3 = ("F[1, 0](2*A > 3)", -oo)
ex4 = ("G[1, 0](2*A > 3)", oo)
ex5 = ("(A < 0)", -1)
ex6 = ("G[0, 0.1](A < 0)", -1)
x = pd.DataFrame([[1,2], [1,4], [4,2]], index=[0,0.1,0.2],
columns=["A", "B"])
class TestSTLRobustness(unittest.TestCase):
@params(ex1, ex2, ex3, ex4, ex5, ex6)
def test_stl(self, phi_str, r):
phi = stl.parse(phi_str)
stl_eval = stl.robustness.pointwise_robustness(phi)
self.assertEqual(stl_eval(x, 0), r)

45
stl/test_synth.py Normal file
View file

@ -0,0 +1,45 @@
import stl
import stl.robustness
import stl.synth
import pandas as pd
from nose2.tools import params
import unittest
from sympy import Symbol
oo = float('inf')
ex1 = ("A > a?", ("a?",), {"a?": (0, 10)}, {"a?": False}, {"a?": 1})
ex2 = ("F[0, b?](A > a?)", ("a?", "b?"), {"a?": (0, 10), "b?": (0, 5)},
{"a?": False, "b?": True}, {"a?": 4, "b?": 0.2})
ex3 = ("F[0, b?](A < 0)", ("b?",), {"b?": (0, 5)},
{"b?": True}, {"b?": 5})
ex4 = ("G[0, b?](A < 0)", ("b?",), {"b?": (0.1, 5)},
{"b?": False}, {"b?": 0.1})
ex5 = ("F[0, b?](A > 0)", ("b?",), {"b?": (0.1, 5)},
{"b?": True}, {"b?": 0.1})
ex6 = ("(A > a?) or (A > b?)", ("a?", "b?",), {"a?": (0, 2), "b?": (0, 2)},
{"a?": False, "b?": False}, {"a?": 2, "b?": 1})
x = pd.DataFrame([[1,2], [1,4], [4,2]], index=[0,0.1,0.2],
columns=["A", "B"])
class TestSTLRobustness(unittest.TestCase):
@params(ex1, ex2, ex3, ex4, ex5, ex6)
def test_lex_synth(self, phi_str, order, ranges, polarity, val):
phi = stl.parse(phi_str)
val2 = stl.synth.lex_param_project(
phi, x, order=order, ranges=ranges, polarity=polarity)
phi2 = stl.utils.set_params(phi, val2)
phi3 = stl.utils.set_params(phi, val)
stl_eval = stl.robustness.pointwise_robustness(phi2)
stl_eval2 = stl.robustness.pointwise_robustness(phi3)
# check that the robustnesses are almost the same
self.assertAlmostEqual(stl_eval(x, 0), stl_eval2(x, 0), delta=0.01)
# check that the valuations are almost the same
for var in order:
self.assertAlmostEqual(val2[var], val[var], delta=0.01)

30
stl/test_utils.py Normal file
View file

@ -0,0 +1,30 @@
import stl
import stl.utils
import pandas as pd
from nose2.tools import params
import unittest
from sympy import Symbol
ex1 = ("F[b?, 1]G[0, c?](x > a?)", {"a?", "b?", "c?"})
ex2 = ("G[0, c?](x > a?)", {"a?", "c?"})
ex3 = ("F[b?, 1]G[0, c?](x > a?)", {"a?", "b?", "c?"})
ex4 = ("F[b?, 1]G[0, c?](x > a?)", "F[2, 1]G[0, 3](x > 1)")
ex5 = ("G[0, c?](x > a?)", "G[0, 3](x > 1)")
val = {"a?": 1.0, "b?": 2.0, "c?": 3.0}
class TestSTLUtils(unittest.TestCase):
@params(ex1, ex2, ex3)
def test_param_lens(self, phi_str, params):
phi = stl.parse(phi_str)
self.assertEqual(set(map(str, stl.utils.param_lens(phi).get_all())), params)
@params(ex4, ex5)
def test_set_params(self, phi_str, phi2_str):
phi = stl.parse(phi_str)
phi2 = stl.parse(phi2_str)
phi = stl.utils.set_params(phi, val)
self.assertEqual(set(map(str, stl.utils.param_lens(phi).get_all())), set())
self.assertEqual(phi, phi2)

76
stl/utils.py Normal file
View file

@ -0,0 +1,76 @@
from collections import deque
from lenses import lens, Lens
import funcy as fn
import sympy
from stl.ast import LinEq, And, Or, NaryOpSTL, F, G, Interval
def walk(stl, bfs=False):
"""Walks Ast. Defaults to DFS unless BFS flag is set."""
pop = deque.popleft if bfs else deque.pop
children = deque([stl])
while len(children) != 0:
node = pop(children)
yield node
children.extend(node.children())
def tree(stl):
return {x:set(x.children()) for x in walk(stl) if x.children()}
def type_pred(*args):
ast_types = set(args)
return lambda x: type(x) in ast_types
def _child_lens(psi, focus):
if psi is None:
return
if isinstance(psi, NaryOpSTL):
for j, _ in enumerate(psi.args):
yield focus.args[j]
else:
yield focus.arg
def ast_lens(phi:"STL", bind=True, *, pred, focus_lens=None) -> lens:
if focus_lens is None:
focus_lens = lambda x: [lens()]
tls = list(fn.flatten(_ast_lens(phi, pred=pred, focus_lens=focus_lens)))
tl = lens().tuple_(*tls).each_()
return tl.bind(phi) if bind else tl
def _ast_lens(phi, *, pred, focus=lens(), focus_lens):
psi = focus.get(state=phi)
ret_lens = [focus.add_lens(l) for l in focus_lens(psi)] if pred(psi) else []
if isinstance(psi, LinEq):
return ret_lens
child_lenses = list(_child_lens(psi, focus=focus))
ret_lens += [_ast_lens(phi, pred=pred, focus=cl, focus_lens=focus_lens)
for cl in child_lenses]
return ret_lens
lineq_lens = fn.partial(ast_lens, pred=type_pred(LinEq))
and_or_lens = fn.partial(ast_lens, pred=type_pred(And, Or))
def terms_lens(phi:"STL", bind=True) -> lens:
return lineq_lens(phi, bind).terms.each_()
def param_lens(phi):
is_sym = lambda x: isinstance(x, sympy.Symbol)
def focus_lens(leaf):
return [lens().const] if isinstance(leaf, LinEq) else [lens().interval[0], lens().interval[1]]
return ast_lens(phi, pred=type_pred(LinEq, F, G), focus_lens=focus_lens).filter_(is_sym)
def set_params(stl_or_lens, val):
l = stl_or_lens if isinstance(stl_or_lens, Lens) else param_lens(stl_or_lens)
return l.modify(lambda x: val[str(x)] if str(x) in val else x)