formatting + delete dead code

This commit is contained in:
Marcell Vazquez-Chanlatte 2017-10-15 01:05:19 -07:00
parent d78037816b
commit 8dbfc83083
3 changed files with 71 additions and 50 deletions

View file

@ -11,6 +11,7 @@ from sympy import Symbol
dt_sym = Symbol('dt', positive=True) dt_sym = Symbol('dt', positive=True)
t_sym = Symbol('t', positive=True) t_sym = Symbol('t', positive=True)
def flatten_binary(phi, op, dropT, shortT): def flatten_binary(phi, op, dropT, shortT):
f = lambda x: x.args if isinstance(x, op) else [x] f = lambda x: x.args if isinstance(x, op) else [x]
args = [arg for arg in phi.args if arg is not dropT] args = [arg for arg in phi.args if arg is not dropT]
@ -23,7 +24,7 @@ def flatten_binary(phi, op, dropT, shortT):
return args[0] return args[0]
else: else:
return op(tuple(fn.mapcat(f, phi.args))) return op(tuple(fn.mapcat(f, phi.args)))
class AST(object): class AST(object):
__slots__ = () __slots__ = ()
@ -44,7 +45,7 @@ class AST(object):
class _Top(AST): class _Top(AST):
__slots__ = () __slots__ = ()
def __repr__(self): def __repr__(self):
return "" return ""
@ -61,6 +62,7 @@ class _Bot(AST):
def __invert__(self): def __invert__(self):
return TOP return TOP
TOP = _Top() TOP = _Top()
BOT = _Bot() BOT = _Bot()
@ -70,7 +72,7 @@ class AtomicPred(namedtuple("AP", ["id"]), AST):
def __repr__(self): def __repr__(self):
return f"{self.id}" return f"{self.id}"
@property @property
def children(self): def children(self):
return set() return set()
@ -81,7 +83,7 @@ class LinEq(namedtuple("LinEquality", ["terms", "op", "const"]), AST):
def __repr__(self): def __repr__(self):
return " + ".join(map(str, self.terms)) + f" {self.op} {self.const}" return " + ".join(map(str, self.terms)) + f" {self.op} {self.const}"
@property @property
def children(self): def children(self):
return set() return set()
@ -103,7 +105,7 @@ class Interval(namedtuple('I', ['lower', 'upper'])):
def __repr__(self): def __repr__(self):
return f"[{self.lower},{self.upper}]" return f"[{self.lower},{self.upper}]"
@property @property
def children(self): def children(self):
return {self.lower, self.upper} return {self.lower, self.upper}
@ -113,9 +115,10 @@ class NaryOpSTL(namedtuple('NaryOp', ['args']), AST):
__slots__ = () __slots__ = ()
OP = "?" OP = "?"
def __repr__(self): def __repr__(self):
return f" {self.OP} ".join(f"({x})" for x in self.args) return f" {self.OP} ".join(f"({x})" for x in self.args)
@property @property
def children(self): def children(self):
return set(self.args) return set(self.args)
@ -125,10 +128,12 @@ class Or(NaryOpSTL):
__slots__ = () __slots__ = ()
OP = "" OP = ""
def __hash__(self): def __hash__(self):
# TODO: compute hash based on contents # TODO: compute hash based on contents
return hash(repr(self)) return hash(repr(self))
class And(NaryOpSTL): class And(NaryOpSTL):
__slots__ = () __slots__ = ()
@ -144,7 +149,7 @@ class ModalOp(namedtuple('ModalOp', ['interval', 'arg']), AST):
def __repr__(self): def __repr__(self):
return f"{self.OP}{self.interval}({self.arg})" return f"{self.OP}{self.interval}({self.arg})"
@property @property
def children(self): def children(self):
return {self.arg} return {self.arg}
@ -158,6 +163,7 @@ class F(ModalOp):
# TODO: compute hash based on contents # TODO: compute hash based on contents
return hash(repr(self)) return hash(repr(self))
class G(ModalOp): class G(ModalOp):
__slots__ = () __slots__ = ()
OP = "" OP = ""
@ -172,7 +178,7 @@ class Until(namedtuple('ModalOp', ['arg1', 'arg2']), AST):
def __repr__(self): def __repr__(self):
return f"({self.arg1}) U ({self.arg2})" return f"({self.arg1}) U ({self.arg2})"
@property @property
def children(self): def children(self):
return {self.arg1, self.arg2} return {self.arg1, self.arg2}
@ -187,7 +193,7 @@ class Neg(namedtuple('Neg', ['arg']), AST):
def __repr__(self): def __repr__(self):
return f"¬({self.arg})" return f"¬({self.arg})"
@property @property
def children(self): def children(self):
return {self.arg} return {self.arg}
@ -202,7 +208,7 @@ class Next(namedtuple('Next', ['arg']), AST):
def __repr__(self): def __repr__(self):
return f"X({self.arg})" return f"X({self.arg})"
@property @property
def children(self): def children(self):
return {self.arg} return {self.arg}

View file

@ -62,7 +62,8 @@ _ = ~r"\s"+
__ = ~r"\s"* __ = ~r"\s"*
EOL = "\\n" EOL = "\\n"
''') ''')
class STLVisitor(NodeVisitor): class STLVisitor(NodeVisitor):
def __init__(self, H=float('inf')): def __init__(self, H=float('inf')):
super().__init__() super().__init__()
@ -78,7 +79,7 @@ class STLVisitor(NodeVisitor):
visit_paren_phi = partialmethod(children_getter, i=2) visit_paren_phi = partialmethod(children_getter, i=2)
def visit_interval(self, _, children): def visit_interval(self, _, children):
_, _, (left,), _, _, _, (right,), _, _ = children _, _, (left, ), _, _, _, (right, ), _, _ = children
left = left if left != [] else float("inf") left = left if left != [] else float("inf")
right = right if right != [] else float("inf") right = right if right != [] else float("inf")
if isinstance(left, int): if isinstance(left, int):
@ -92,7 +93,7 @@ class STLVisitor(NodeVisitor):
def visit_unbound(self, node, _): def visit_unbound(self, node, _):
return Symbol(node.text) return Symbol(node.text)
visit_op = get_text visit_op = get_text
def unary_temp_op_visitor(self, _, children, op): def unary_temp_op_visitor(self, _, children, op):
@ -101,13 +102,13 @@ class STLVisitor(NodeVisitor):
return op(i, phi) return op(i, phi)
def binop_visitor(self, _, children, op): def binop_visitor(self, _, children, op):
phi1, _, _, _, (phi2,) = children phi1, _, _, _, (phi2, ) = children
argL = list(phi1.args) if isinstance(phi1, op) else [phi1] argL = list(phi1.args) if isinstance(phi1, op) else [phi1]
argR = list(phi2.args) if isinstance(phi2, op) else [phi2] argR = list(phi2.args) if isinstance(phi2, op) else [phi2]
return op(tuple(argL + argR)) return op(tuple(argL + argR))
def sugar_binop_visitor(self, _, children, op): def sugar_binop_visitor(self, _, children, op):
phi1, _, _, _, (phi2,) = children phi1, _, _, _, (phi2, ) = children
return op(phi1, phi2) return op(phi1, phi2)
visit_f = partialmethod(unary_temp_op_visitor, op=ast.F) visit_f = partialmethod(unary_temp_op_visitor, op=ast.F)
@ -136,7 +137,7 @@ class STLVisitor(NodeVisitor):
coeffs, (iden, time) = children coeffs, (iden, time) = children
c = coeffs[0] if coeffs else Number(1) c = coeffs[0] if coeffs else Number(1)
return ast.Var(coeff=c, id=iden, time=time) return ast.Var(coeff=c, id=iden, time=time)
def visit_coeff(self, _, children): def visit_coeff(self, _, children):
dt, coeff, *_ = children[0] dt, coeff, *_ = children[0]
if not isinstance(dt, Symbol): if not isinstance(dt, Symbol):
@ -147,7 +148,7 @@ class STLVisitor(NodeVisitor):
def visit_terms(self, _, children): def visit_terms(self, _, children):
if isinstance(children[0], list): if isinstance(children[0], list):
term, _1, sgn ,_2, terms = children[0] term, _1, sgn, _2, terms = children[0]
terms = lens(terms)[0].coeff * sgn terms = lens(terms)[0].coeff * sgn
return [term] + terms return [term] + terms
else: else:
@ -170,5 +171,5 @@ class STLVisitor(NodeVisitor):
return ast.Next(children[1]) return ast.Next(children[1])
def parse(stl_str:str, rule:str="phi", H=float('inf')) -> "STL": def parse(stl_str: str, rule: str = "phi", H=float('inf')) -> "STL":
return STLVisitor(H).visit(STL_GRAMMAR[rule].parse(stl_str)) return STLVisitor(H).visit(STL_GRAMMAR[rule].parse(stl_str))

View file

@ -16,7 +16,8 @@ from stl.types import STL, STL_Generator, MTL
Lens = TypeVar('Lens') Lens = TypeVar('Lens')
def walk(phi:STL) -> STL_Generator:
def walk(phi: STL) -> STL_Generator:
"""Walk of the AST.""" """Walk of the AST."""
pop = deque.pop pop = deque.pop
children = deque([phi]) children = deque([phi])
@ -25,16 +26,18 @@ def walk(phi:STL) -> STL_Generator:
yield node yield node
children.extend(node.children) children.extend(node.children)
def vars_in_phi(phi): def vars_in_phi(phi):
focus = stl.terms_lens(phi) focus = stl.terms_lens(phi)
return set(focus.tuple_(lens.id, lens.time).get_all()) return set(focus.tuple_(lens.id, lens.time).get_all())
def type_pred(*args:List[Type]) -> Mapping[Type, bool]:
def type_pred(*args: List[Type]) -> Mapping[Type, bool]:
ast_types = set(args) ast_types = set(args)
return lambda x: type(x) in ast_types return lambda x: type(x) in ast_types
def ast_lens(phi:STL, bind=True, *, pred=None, focus_lens=None) -> Lens: def ast_lens(phi: STL, bind=True, *, pred=None, focus_lens=None) -> Lens:
if focus_lens is None: if focus_lens is None:
focus_lens = lambda _: [lens] focus_lens = lambda _: [lens]
if pred is None: if pred is None:
@ -42,10 +45,11 @@ def ast_lens(phi:STL, bind=True, *, pred=None, focus_lens=None) -> Lens:
l = lenses.bind(phi) if bind else lens l = lenses.bind(phi) if bind else lens
return l.Tuple(*_ast_lens(phi, pred=pred, focus_lens=focus_lens)) return l.Tuple(*_ast_lens(phi, pred=pred, focus_lens=focus_lens))
def _ast_lens(phi:STL, pred, focus_lens) -> Lens:
def _ast_lens(phi: STL, pred, focus_lens) -> Lens:
if pred(phi): if pred(phi):
yield from focus_lens(phi) yield from focus_lens(phi)
if phi is None or not phi.children: if phi is None or not phi.children:
return return
@ -54,36 +58,45 @@ def _ast_lens(phi:STL, pred, focus_lens) -> Lens:
elif isinstance(phi, stl.ast.Until): elif isinstance(phi, stl.ast.Until):
child_lenses = [lens.GetAttr('arg1'), lens.GetAttr('arg2')] child_lenses = [lens.GetAttr('arg1'), lens.GetAttr('arg2')]
elif isinstance(phi, NaryOpSTL): elif isinstance(phi, NaryOpSTL):
child_lenses = [lens.GetAttr('args')[j] for j, _ in enumerate(phi.args)] child_lenses = [
lens.GetAttr('args')[j] for j, _ in enumerate(phi.args)
]
else: else:
child_lenses = [lens.GetAttr('arg')] child_lenses = [lens.GetAttr('arg')]
for l in child_lenses: for l in child_lenses:
yield from [l & cl for cl in _ast_lens(l.get()(phi), pred, focus_lens)] yield from [l & cl for cl in _ast_lens(l.get()(phi), pred, focus_lens)]
lineq_lens = fn.partial(ast_lens, pred=type_pred(LinEq)) lineq_lens = fn.partial(ast_lens, pred=type_pred(LinEq))
AP_lens = fn.partial(ast_lens, pred=type_pred(stl.ast.AtomicPred)) AP_lens = fn.partial(ast_lens, pred=type_pred(stl.ast.AtomicPred))
and_or_lens = fn.partial(ast_lens, pred=type_pred(And, Or)) and_or_lens = fn.partial(ast_lens, pred=type_pred(And, Or))
def terms_lens(phi:STL, bind:bool=True) -> Lens:
return lineq_lens(phi, bind).terms.each_() def terms_lens(phi: STL, bind: bool = True) -> Lens:
return lineq_lens(phi, bind).Each().terms.Each()
def param_lens(phi:STL) -> Lens: def param_lens(phi: STL) -> Lens:
is_sym = lambda x: isinstance(x, sympy.Symbol) is_sym = lambda x: isinstance(x, sympy.Symbol)
def focus_lens(leaf):
return [lens.const] if isinstance(leaf, LinEq) else [lens().interval[0], lens().interval[1]]
return ast_lens(phi, pred=type_pred(LinEq, F, G), def focus_lens(leaf):
focus_lens=focus_lens).filter_(is_sym) return [lens.const] if isinstance(leaf, LinEq) else [
lens.GetAttr('interval')[0],
lens.GetAttr('interval')[1]
]
return ast_lens(
phi, pred=type_pred(LinEq, F, G),
focus_lens=focus_lens).filter_(is_sym)
def set_params(stl_or_lens, val) -> STL: def set_params(stl_or_lens, val) -> STL:
l = stl_or_lens if isinstance(stl_or_lens, Lens) else param_lens(stl_or_lens) l = stl_or_lens if isinstance(stl_or_lens,
Lens) else param_lens(stl_or_lens)
return l.modify(lambda x: float(val.get(x, val.get(str(x), x)))) return l.modify(lambda x: float(val.get(x, val.get(str(x), x))))
def f_neg_or_canonical_form(phi:STL) -> STL: def f_neg_or_canonical_form(phi: STL) -> STL:
if isinstance(phi, LinEq): if isinstance(phi, LinEq):
return phi return phi
@ -106,24 +119,14 @@ def f_neg_or_canonical_form(phi:STL) -> STL:
raise NotImplementedError raise NotImplementedError
def to_mtl(phi:STL) -> MTL:
focus = lineq_lens(phi)
to_ap = lambda i: stl.ast.AtomicPred("AP{}".format(i))
ap_map = {to_ap(i): leq for i, leq in enumerate(focus.get_all())}
lineq_map = {v:k for k,v in ap_map.items()}
return focus.modify(lineq_map.get), ap_map
def from_mtl(phi:MTL, ap_map:Dict[AtomicPred, LinEq]) -> STL:
focus = AP_lens(phi)
return focus.modify(ap_map.get)
def _lineq_lipschitz(lineq): def _lineq_lipschitz(lineq):
return sum(map(abs, lens(lineq).terms.each_().coeff.get_all())) return sum(map(abs, lens(lineq).Each().terms.Each().coeff.collect()))
def linear_stl_lipschitz(phi): def linear_stl_lipschitz(phi):
"""Infinity norm lipschitz bound of linear inequality predicate.""" """Infinity norm lipschitz bound of linear inequality predicate."""
return float(max(map(_lineq_lipschitz, lineq_lens(phi).get_all()))) return float(max(map(_lineq_lipschitz, lineq_lens(phi).Each().collect())))
def inline_context(phi, context): def inline_context(phi, context):
phi2 = None phi2 = None
@ -133,6 +136,7 @@ def inline_context(phi, context):
# TODO: this is hack to flatten the AST. Fix! # TODO: this is hack to flatten the AST. Fix!
return stl.parse(str(phi)) return stl.parse(str(phi))
op_lookup = { op_lookup = {
">": op.gt, ">": op.gt,
">=": op.ge, ">=": op.ge,
@ -141,6 +145,7 @@ op_lookup = {
"=": op.eq, "=": op.eq,
} }
def get_times(x): def get_times(x):
times = set.union(*({t for t, _ in v.items()} for v in x.values())) times = set.union(*({t for t, _ in v.items()} for v in x.values()))
return sorted(times) return sorted(times)
@ -151,18 +156,19 @@ def eval_lineq(lineq, x, times=None, compact=True):
times = get_times(x) times = get_times(x)
def eval_term(term, t): def eval_term(term, t):
return float(term.coeff)*x[term.id.name][t] return float(term.coeff) * x[term.id.name][t]
output = traces.TimeSeries(domain=traces.Domain(times[0], times[-1])) output = traces.TimeSeries(domain=traces.Domain(times[0], times[-1]))
terms = lens(lineq).Each().terms.Each().collect() terms = lens(lineq).Each().terms.Each().collect()
for t in times: for t in times:
lhs = sum(eval_term(term, t) for term in terms) lhs = sum(eval_term(term, t) for term in terms)
output[t] = op_lookup[lineq.op](lhs, lineq.const) output[t] = op_lookup[lineq.op](lhs, lineq.const)
if compact: if compact:
output.compact() output.compact()
return output return output
def eval_lineqs(phi, x, times=None): def eval_lineqs(phi, x, times=None):
if times is None: if times is None:
times = get_times(x) times = get_times(x)
@ -172,26 +178,34 @@ def eval_lineqs(phi, x, times=None):
# EDSL # EDSL
def alw(phi, *, lo, hi): def alw(phi, *, lo, hi):
return G(Interval(lo, hi), phi) return G(Interval(lo, hi), phi)
def env(phi, *, lo, hi): def env(phi, *, lo, hi):
return F(Interval(lo, hi), phi) return F(Interval(lo, hi), phi)
def until(phi1, phi2, *, lo, hi): def until(phi1, phi2, *, lo, hi):
return stl.ast.Until(Interval(lo, hi), phi1, phi2) return stl.ast.Until(Interval(lo, hi), phi1, phi2)
def andf(*args): def andf(*args):
return reduce(op.and_, args) if args else stl.TOP return reduce(op.and_, args) if args else stl.TOP
def orf(*args): def orf(*args):
return reduce(op.or_, args) if args else stl.TOP return reduce(op.or_, args) if args else stl.TOP
def implies(x, y): def implies(x, y):
return ~x | y return ~x | y
def xor(x, y): def xor(x, y):
return (x | y) & ~(x & y) return (x | y) & ~(x & y)
def iff(x, y): def iff(x, y):
return (x & y) | (~x & ~y) return (x & y) | (~x & ~y)