From 7a73554525e2e450aa3a9d82cf4c246538e115eb Mon Sep 17 00:00:00 2001 From: Marcell Vazquez-Chanlatte Date: Fri, 8 Jul 2016 20:28:28 -0700 Subject: [PATCH] Use sympy for time variable --- stl.py | 30 +++++++++++++++++++++++++----- stl_parser.py | 6 +++--- 2 files changed, 28 insertions(+), 8 deletions(-) diff --git a/stl.py b/stl.py index e73cb60..553a6c8 100644 --- a/stl.py +++ b/stl.py @@ -1,6 +1,4 @@ # -*- coding: utf-8 -*- -# TODO: create lens convenience functions -# x.lens().set() # TODO: create iso lens between sugar and non-sugar # TODO: supress + given a + (-b). i.e. want a - b @@ -10,10 +8,12 @@ from typing import Union from enum import Enum from sympy import Symbol +from lenses import lens + VarKind = Enum("VarKind", ["x", "u", "w"]) str_to_varkind = {"x": VarKind.x, "u": VarKind.u, "w": VarKind.w} -dt = Symbol('dt', positive=True) - +dt_sym = Symbol('dt', positive=True) +t_sym = Symbol('t', positive=True) class LinEq(namedtuple("LinEquality", ["terms", "op", "const"])): def __repr__(self): @@ -30,7 +30,7 @@ class LinEq(namedtuple("LinEquality", ["terms", "op", "const"])): class Var(namedtuple("Var", ["kind", "id", "time"])): def __repr__(self): - time_str = "[t + {}]".format(self.time) + time_str = "[{}]".format(self.time) return "{k}{i}{t}".format(k=self.kind.name, i=self.id, t=time_str) @@ -106,3 +106,23 @@ def walk(stl, bfs=False): def tree(stl): return {x:set(x.children()) for x in walk(stl) if x.children()} + + +def time_lens(phi:"STL") -> lens: + return _time_lens(phi).bind(phi) + + +def _time_lens(phi): + if isinstance(phi, LinEq): + return lens().terms.each_().var.time + + if isinstance(phi, NaryOpSTL): + child_lens = [lens()[i].add_lens(_time_lens(c)) for i, c + in enumerate(phi.children())] + return lens().args.tuple_(*child_lens).each_() + else: + return lens().arg.add_lens(_time_lens(phi.arg)) + + +def set_time(phi, *, t, dt=0.1): + return time_lens(phi).call("evalf", subs={t_sym: t, dt_sym: dt}) diff --git a/stl_parser.py b/stl_parser.py index e328852..c02cf10 100644 --- a/stl_parser.py +++ b/stl_parser.py @@ -116,7 +116,7 @@ class STLVisitor(NodeVisitor): (var_kind, iden), time_node = children time_node = list(flatten(time_node)) - time = time_node[0] if len(time_node) > 0 else 0 + time = time_node[0] if len(time_node) > 0 else stl.t_sym return stl.Var(var_kind, iden, time) @@ -124,13 +124,13 @@ class STLVisitor(NodeVisitor): return children[3]* children[5] def visit_prime(self, *_): - return -stl.dt + return -stl.dt_sym def visit_const(self, const, children): return float(const.text) def visit_dt(self, *_): - return stl.dt + return stl.dt_sym def visit_term(self, _, children): coeffs, var = children