formatting + delete dead code
This commit is contained in:
parent
d78037816b
commit
8dbfc83083
3 changed files with 71 additions and 50 deletions
|
|
@ -11,6 +11,7 @@ from sympy import Symbol
|
||||||
dt_sym = Symbol('dt', positive=True)
|
dt_sym = Symbol('dt', positive=True)
|
||||||
t_sym = Symbol('t', positive=True)
|
t_sym = Symbol('t', positive=True)
|
||||||
|
|
||||||
|
|
||||||
def flatten_binary(phi, op, dropT, shortT):
|
def flatten_binary(phi, op, dropT, shortT):
|
||||||
f = lambda x: x.args if isinstance(x, op) else [x]
|
f = lambda x: x.args if isinstance(x, op) else [x]
|
||||||
args = [arg for arg in phi.args if arg is not dropT]
|
args = [arg for arg in phi.args if arg is not dropT]
|
||||||
|
|
@ -61,6 +62,7 @@ class _Bot(AST):
|
||||||
def __invert__(self):
|
def __invert__(self):
|
||||||
return TOP
|
return TOP
|
||||||
|
|
||||||
|
|
||||||
TOP = _Top()
|
TOP = _Top()
|
||||||
BOT = _Bot()
|
BOT = _Bot()
|
||||||
|
|
||||||
|
|
@ -113,6 +115,7 @@ class NaryOpSTL(namedtuple('NaryOp', ['args']), AST):
|
||||||
__slots__ = ()
|
__slots__ = ()
|
||||||
|
|
||||||
OP = "?"
|
OP = "?"
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return f" {self.OP} ".join(f"({x})" for x in self.args)
|
return f" {self.OP} ".join(f"({x})" for x in self.args)
|
||||||
|
|
||||||
|
|
@ -125,10 +128,12 @@ class Or(NaryOpSTL):
|
||||||
__slots__ = ()
|
__slots__ = ()
|
||||||
|
|
||||||
OP = "∨"
|
OP = "∨"
|
||||||
|
|
||||||
def __hash__(self):
|
def __hash__(self):
|
||||||
# TODO: compute hash based on contents
|
# TODO: compute hash based on contents
|
||||||
return hash(repr(self))
|
return hash(repr(self))
|
||||||
|
|
||||||
|
|
||||||
class And(NaryOpSTL):
|
class And(NaryOpSTL):
|
||||||
__slots__ = ()
|
__slots__ = ()
|
||||||
|
|
||||||
|
|
@ -158,6 +163,7 @@ class F(ModalOp):
|
||||||
# TODO: compute hash based on contents
|
# TODO: compute hash based on contents
|
||||||
return hash(repr(self))
|
return hash(repr(self))
|
||||||
|
|
||||||
|
|
||||||
class G(ModalOp):
|
class G(ModalOp):
|
||||||
__slots__ = ()
|
__slots__ = ()
|
||||||
OP = "□"
|
OP = "□"
|
||||||
|
|
|
||||||
|
|
@ -63,6 +63,7 @@ __ = ~r"\s"*
|
||||||
EOL = "\\n"
|
EOL = "\\n"
|
||||||
''')
|
''')
|
||||||
|
|
||||||
|
|
||||||
class STLVisitor(NodeVisitor):
|
class STLVisitor(NodeVisitor):
|
||||||
def __init__(self, H=float('inf')):
|
def __init__(self, H=float('inf')):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
@ -78,7 +79,7 @@ class STLVisitor(NodeVisitor):
|
||||||
visit_paren_phi = partialmethod(children_getter, i=2)
|
visit_paren_phi = partialmethod(children_getter, i=2)
|
||||||
|
|
||||||
def visit_interval(self, _, children):
|
def visit_interval(self, _, children):
|
||||||
_, _, (left,), _, _, _, (right,), _, _ = children
|
_, _, (left, ), _, _, _, (right, ), _, _ = children
|
||||||
left = left if left != [] else float("inf")
|
left = left if left != [] else float("inf")
|
||||||
right = right if right != [] else float("inf")
|
right = right if right != [] else float("inf")
|
||||||
if isinstance(left, int):
|
if isinstance(left, int):
|
||||||
|
|
@ -101,13 +102,13 @@ class STLVisitor(NodeVisitor):
|
||||||
return op(i, phi)
|
return op(i, phi)
|
||||||
|
|
||||||
def binop_visitor(self, _, children, op):
|
def binop_visitor(self, _, children, op):
|
||||||
phi1, _, _, _, (phi2,) = children
|
phi1, _, _, _, (phi2, ) = children
|
||||||
argL = list(phi1.args) if isinstance(phi1, op) else [phi1]
|
argL = list(phi1.args) if isinstance(phi1, op) else [phi1]
|
||||||
argR = list(phi2.args) if isinstance(phi2, op) else [phi2]
|
argR = list(phi2.args) if isinstance(phi2, op) else [phi2]
|
||||||
return op(tuple(argL + argR))
|
return op(tuple(argL + argR))
|
||||||
|
|
||||||
def sugar_binop_visitor(self, _, children, op):
|
def sugar_binop_visitor(self, _, children, op):
|
||||||
phi1, _, _, _, (phi2,) = children
|
phi1, _, _, _, (phi2, ) = children
|
||||||
return op(phi1, phi2)
|
return op(phi1, phi2)
|
||||||
|
|
||||||
visit_f = partialmethod(unary_temp_op_visitor, op=ast.F)
|
visit_f = partialmethod(unary_temp_op_visitor, op=ast.F)
|
||||||
|
|
@ -147,7 +148,7 @@ class STLVisitor(NodeVisitor):
|
||||||
|
|
||||||
def visit_terms(self, _, children):
|
def visit_terms(self, _, children):
|
||||||
if isinstance(children[0], list):
|
if isinstance(children[0], list):
|
||||||
term, _1, sgn ,_2, terms = children[0]
|
term, _1, sgn, _2, terms = children[0]
|
||||||
terms = lens(terms)[0].coeff * sgn
|
terms = lens(terms)[0].coeff * sgn
|
||||||
return [term] + terms
|
return [term] + terms
|
||||||
else:
|
else:
|
||||||
|
|
@ -170,5 +171,5 @@ class STLVisitor(NodeVisitor):
|
||||||
return ast.Next(children[1])
|
return ast.Next(children[1])
|
||||||
|
|
||||||
|
|
||||||
def parse(stl_str:str, rule:str="phi", H=float('inf')) -> "STL":
|
def parse(stl_str: str, rule: str = "phi", H=float('inf')) -> "STL":
|
||||||
return STLVisitor(H).visit(STL_GRAMMAR[rule].parse(stl_str))
|
return STLVisitor(H).visit(STL_GRAMMAR[rule].parse(stl_str))
|
||||||
|
|
|
||||||
70
stl/utils.py
70
stl/utils.py
|
|
@ -16,7 +16,8 @@ from stl.types import STL, STL_Generator, MTL
|
||||||
|
|
||||||
Lens = TypeVar('Lens')
|
Lens = TypeVar('Lens')
|
||||||
|
|
||||||
def walk(phi:STL) -> STL_Generator:
|
|
||||||
|
def walk(phi: STL) -> STL_Generator:
|
||||||
"""Walk of the AST."""
|
"""Walk of the AST."""
|
||||||
pop = deque.pop
|
pop = deque.pop
|
||||||
children = deque([phi])
|
children = deque([phi])
|
||||||
|
|
@ -25,16 +26,18 @@ def walk(phi:STL) -> STL_Generator:
|
||||||
yield node
|
yield node
|
||||||
children.extend(node.children)
|
children.extend(node.children)
|
||||||
|
|
||||||
|
|
||||||
def vars_in_phi(phi):
|
def vars_in_phi(phi):
|
||||||
focus = stl.terms_lens(phi)
|
focus = stl.terms_lens(phi)
|
||||||
return set(focus.tuple_(lens.id, lens.time).get_all())
|
return set(focus.tuple_(lens.id, lens.time).get_all())
|
||||||
|
|
||||||
def type_pred(*args:List[Type]) -> Mapping[Type, bool]:
|
|
||||||
|
def type_pred(*args: List[Type]) -> Mapping[Type, bool]:
|
||||||
ast_types = set(args)
|
ast_types = set(args)
|
||||||
return lambda x: type(x) in ast_types
|
return lambda x: type(x) in ast_types
|
||||||
|
|
||||||
|
|
||||||
def ast_lens(phi:STL, bind=True, *, pred=None, focus_lens=None) -> Lens:
|
def ast_lens(phi: STL, bind=True, *, pred=None, focus_lens=None) -> Lens:
|
||||||
if focus_lens is None:
|
if focus_lens is None:
|
||||||
focus_lens = lambda _: [lens]
|
focus_lens = lambda _: [lens]
|
||||||
if pred is None:
|
if pred is None:
|
||||||
|
|
@ -42,7 +45,8 @@ def ast_lens(phi:STL, bind=True, *, pred=None, focus_lens=None) -> Lens:
|
||||||
l = lenses.bind(phi) if bind else lens
|
l = lenses.bind(phi) if bind else lens
|
||||||
return l.Tuple(*_ast_lens(phi, pred=pred, focus_lens=focus_lens))
|
return l.Tuple(*_ast_lens(phi, pred=pred, focus_lens=focus_lens))
|
||||||
|
|
||||||
def _ast_lens(phi:STL, pred, focus_lens) -> Lens:
|
|
||||||
|
def _ast_lens(phi: STL, pred, focus_lens) -> Lens:
|
||||||
if pred(phi):
|
if pred(phi):
|
||||||
yield from focus_lens(phi)
|
yield from focus_lens(phi)
|
||||||
|
|
||||||
|
|
@ -54,7 +58,9 @@ def _ast_lens(phi:STL, pred, focus_lens) -> Lens:
|
||||||
elif isinstance(phi, stl.ast.Until):
|
elif isinstance(phi, stl.ast.Until):
|
||||||
child_lenses = [lens.GetAttr('arg1'), lens.GetAttr('arg2')]
|
child_lenses = [lens.GetAttr('arg1'), lens.GetAttr('arg2')]
|
||||||
elif isinstance(phi, NaryOpSTL):
|
elif isinstance(phi, NaryOpSTL):
|
||||||
child_lenses = [lens.GetAttr('args')[j] for j, _ in enumerate(phi.args)]
|
child_lenses = [
|
||||||
|
lens.GetAttr('args')[j] for j, _ in enumerate(phi.args)
|
||||||
|
]
|
||||||
else:
|
else:
|
||||||
child_lenses = [lens.GetAttr('arg')]
|
child_lenses = [lens.GetAttr('arg')]
|
||||||
for l in child_lenses:
|
for l in child_lenses:
|
||||||
|
|
@ -65,25 +71,32 @@ lineq_lens = fn.partial(ast_lens, pred=type_pred(LinEq))
|
||||||
AP_lens = fn.partial(ast_lens, pred=type_pred(stl.ast.AtomicPred))
|
AP_lens = fn.partial(ast_lens, pred=type_pred(stl.ast.AtomicPred))
|
||||||
and_or_lens = fn.partial(ast_lens, pred=type_pred(And, Or))
|
and_or_lens = fn.partial(ast_lens, pred=type_pred(And, Or))
|
||||||
|
|
||||||
def terms_lens(phi:STL, bind:bool=True) -> Lens:
|
|
||||||
return lineq_lens(phi, bind).terms.each_()
|
def terms_lens(phi: STL, bind: bool = True) -> Lens:
|
||||||
|
return lineq_lens(phi, bind).Each().terms.Each()
|
||||||
|
|
||||||
|
|
||||||
def param_lens(phi:STL) -> Lens:
|
def param_lens(phi: STL) -> Lens:
|
||||||
is_sym = lambda x: isinstance(x, sympy.Symbol)
|
is_sym = lambda x: isinstance(x, sympy.Symbol)
|
||||||
def focus_lens(leaf):
|
|
||||||
return [lens.const] if isinstance(leaf, LinEq) else [lens().interval[0], lens().interval[1]]
|
|
||||||
|
|
||||||
return ast_lens(phi, pred=type_pred(LinEq, F, G),
|
def focus_lens(leaf):
|
||||||
|
return [lens.const] if isinstance(leaf, LinEq) else [
|
||||||
|
lens.GetAttr('interval')[0],
|
||||||
|
lens.GetAttr('interval')[1]
|
||||||
|
]
|
||||||
|
|
||||||
|
return ast_lens(
|
||||||
|
phi, pred=type_pred(LinEq, F, G),
|
||||||
focus_lens=focus_lens).filter_(is_sym)
|
focus_lens=focus_lens).filter_(is_sym)
|
||||||
|
|
||||||
|
|
||||||
def set_params(stl_or_lens, val) -> STL:
|
def set_params(stl_or_lens, val) -> STL:
|
||||||
l = stl_or_lens if isinstance(stl_or_lens, Lens) else param_lens(stl_or_lens)
|
l = stl_or_lens if isinstance(stl_or_lens,
|
||||||
|
Lens) else param_lens(stl_or_lens)
|
||||||
return l.modify(lambda x: float(val.get(x, val.get(str(x), x))))
|
return l.modify(lambda x: float(val.get(x, val.get(str(x), x))))
|
||||||
|
|
||||||
|
|
||||||
def f_neg_or_canonical_form(phi:STL) -> STL:
|
def f_neg_or_canonical_form(phi: STL) -> STL:
|
||||||
if isinstance(phi, LinEq):
|
if isinstance(phi, LinEq):
|
||||||
return phi
|
return phi
|
||||||
|
|
||||||
|
|
@ -106,24 +119,14 @@ def f_neg_or_canonical_form(phi:STL) -> STL:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
def to_mtl(phi:STL) -> MTL:
|
|
||||||
focus = lineq_lens(phi)
|
|
||||||
to_ap = lambda i: stl.ast.AtomicPred("AP{}".format(i))
|
|
||||||
ap_map = {to_ap(i): leq for i, leq in enumerate(focus.get_all())}
|
|
||||||
lineq_map = {v:k for k,v in ap_map.items()}
|
|
||||||
return focus.modify(lineq_map.get), ap_map
|
|
||||||
|
|
||||||
|
|
||||||
def from_mtl(phi:MTL, ap_map:Dict[AtomicPred, LinEq]) -> STL:
|
|
||||||
focus = AP_lens(phi)
|
|
||||||
return focus.modify(ap_map.get)
|
|
||||||
|
|
||||||
def _lineq_lipschitz(lineq):
|
def _lineq_lipschitz(lineq):
|
||||||
return sum(map(abs, lens(lineq).terms.each_().coeff.get_all()))
|
return sum(map(abs, lens(lineq).Each().terms.Each().coeff.collect()))
|
||||||
|
|
||||||
|
|
||||||
def linear_stl_lipschitz(phi):
|
def linear_stl_lipschitz(phi):
|
||||||
"""Infinity norm lipschitz bound of linear inequality predicate."""
|
"""Infinity norm lipschitz bound of linear inequality predicate."""
|
||||||
return float(max(map(_lineq_lipschitz, lineq_lens(phi).get_all())))
|
return float(max(map(_lineq_lipschitz, lineq_lens(phi).Each().collect())))
|
||||||
|
|
||||||
|
|
||||||
def inline_context(phi, context):
|
def inline_context(phi, context):
|
||||||
phi2 = None
|
phi2 = None
|
||||||
|
|
@ -133,6 +136,7 @@ def inline_context(phi, context):
|
||||||
# TODO: this is hack to flatten the AST. Fix!
|
# TODO: this is hack to flatten the AST. Fix!
|
||||||
return stl.parse(str(phi))
|
return stl.parse(str(phi))
|
||||||
|
|
||||||
|
|
||||||
op_lookup = {
|
op_lookup = {
|
||||||
">": op.gt,
|
">": op.gt,
|
||||||
">=": op.ge,
|
">=": op.ge,
|
||||||
|
|
@ -141,6 +145,7 @@ op_lookup = {
|
||||||
"=": op.eq,
|
"=": op.eq,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def get_times(x):
|
def get_times(x):
|
||||||
times = set.union(*({t for t, _ in v.items()} for v in x.values()))
|
times = set.union(*({t for t, _ in v.items()} for v in x.values()))
|
||||||
return sorted(times)
|
return sorted(times)
|
||||||
|
|
@ -151,7 +156,7 @@ def eval_lineq(lineq, x, times=None, compact=True):
|
||||||
times = get_times(x)
|
times = get_times(x)
|
||||||
|
|
||||||
def eval_term(term, t):
|
def eval_term(term, t):
|
||||||
return float(term.coeff)*x[term.id.name][t]
|
return float(term.coeff) * x[term.id.name][t]
|
||||||
|
|
||||||
output = traces.TimeSeries(domain=traces.Domain(times[0], times[-1]))
|
output = traces.TimeSeries(domain=traces.Domain(times[0], times[-1]))
|
||||||
terms = lens(lineq).Each().terms.Each().collect()
|
terms = lens(lineq).Each().terms.Each().collect()
|
||||||
|
|
@ -163,6 +168,7 @@ def eval_lineq(lineq, x, times=None, compact=True):
|
||||||
output.compact()
|
output.compact()
|
||||||
return output
|
return output
|
||||||
|
|
||||||
|
|
||||||
def eval_lineqs(phi, x, times=None):
|
def eval_lineqs(phi, x, times=None):
|
||||||
if times is None:
|
if times is None:
|
||||||
times = get_times(x)
|
times = get_times(x)
|
||||||
|
|
@ -172,26 +178,34 @@ def eval_lineqs(phi, x, times=None):
|
||||||
|
|
||||||
# EDSL
|
# EDSL
|
||||||
|
|
||||||
|
|
||||||
def alw(phi, *, lo, hi):
|
def alw(phi, *, lo, hi):
|
||||||
return G(Interval(lo, hi), phi)
|
return G(Interval(lo, hi), phi)
|
||||||
|
|
||||||
|
|
||||||
def env(phi, *, lo, hi):
|
def env(phi, *, lo, hi):
|
||||||
return F(Interval(lo, hi), phi)
|
return F(Interval(lo, hi), phi)
|
||||||
|
|
||||||
|
|
||||||
def until(phi1, phi2, *, lo, hi):
|
def until(phi1, phi2, *, lo, hi):
|
||||||
return stl.ast.Until(Interval(lo, hi), phi1, phi2)
|
return stl.ast.Until(Interval(lo, hi), phi1, phi2)
|
||||||
|
|
||||||
|
|
||||||
def andf(*args):
|
def andf(*args):
|
||||||
return reduce(op.and_, args) if args else stl.TOP
|
return reduce(op.and_, args) if args else stl.TOP
|
||||||
|
|
||||||
|
|
||||||
def orf(*args):
|
def orf(*args):
|
||||||
return reduce(op.or_, args) if args else stl.TOP
|
return reduce(op.or_, args) if args else stl.TOP
|
||||||
|
|
||||||
|
|
||||||
def implies(x, y):
|
def implies(x, y):
|
||||||
return ~x | y
|
return ~x | y
|
||||||
|
|
||||||
|
|
||||||
def xor(x, y):
|
def xor(x, y):
|
||||||
return (x | y) & ~(x & y)
|
return (x | y) & ~(x & y)
|
||||||
|
|
||||||
|
|
||||||
def iff(x, y):
|
def iff(x, y):
|
||||||
return (x & y) | (~x & ~y)
|
return (x & y) | (~x & ~y)
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue