Use sympy for time variable
This commit is contained in:
parent
2d7e033df0
commit
7a73554525
2 changed files with 28 additions and 8 deletions
30
stl.py
30
stl.py
|
|
@ -1,6 +1,4 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# TODO: create lens convenience functions
|
||||
# x.lens().set()
|
||||
# TODO: create iso lens between sugar and non-sugar
|
||||
# TODO: supress + given a + (-b). i.e. want a - b
|
||||
|
||||
|
|
@ -10,10 +8,12 @@ from typing import Union
|
|||
from enum import Enum
|
||||
from sympy import Symbol
|
||||
|
||||
from lenses import lens
|
||||
|
||||
VarKind = Enum("VarKind", ["x", "u", "w"])
|
||||
str_to_varkind = {"x": VarKind.x, "u": VarKind.u, "w": VarKind.w}
|
||||
dt = Symbol('dt', positive=True)
|
||||
|
||||
dt_sym = Symbol('dt', positive=True)
|
||||
t_sym = Symbol('t', positive=True)
|
||||
|
||||
class LinEq(namedtuple("LinEquality", ["terms", "op", "const"])):
|
||||
def __repr__(self):
|
||||
|
|
@ -30,7 +30,7 @@ class LinEq(namedtuple("LinEquality", ["terms", "op", "const"])):
|
|||
|
||||
class Var(namedtuple("Var", ["kind", "id", "time"])):
|
||||
def __repr__(self):
|
||||
time_str = "[t + {}]".format(self.time)
|
||||
time_str = "[{}]".format(self.time)
|
||||
return "{k}{i}{t}".format(k=self.kind.name, i=self.id, t=time_str)
|
||||
|
||||
|
||||
|
|
@ -106,3 +106,23 @@ def walk(stl, bfs=False):
|
|||
|
||||
def tree(stl):
|
||||
return {x:set(x.children()) for x in walk(stl) if x.children()}
|
||||
|
||||
|
||||
def time_lens(phi:"STL") -> lens:
|
||||
return _time_lens(phi).bind(phi)
|
||||
|
||||
|
||||
def _time_lens(phi):
|
||||
if isinstance(phi, LinEq):
|
||||
return lens().terms.each_().var.time
|
||||
|
||||
if isinstance(phi, NaryOpSTL):
|
||||
child_lens = [lens()[i].add_lens(_time_lens(c)) for i, c
|
||||
in enumerate(phi.children())]
|
||||
return lens().args.tuple_(*child_lens).each_()
|
||||
else:
|
||||
return lens().arg.add_lens(_time_lens(phi.arg))
|
||||
|
||||
|
||||
def set_time(phi, *, t, dt=0.1):
|
||||
return time_lens(phi).call("evalf", subs={t_sym: t, dt_sym: dt})
|
||||
|
|
|
|||
|
|
@ -116,7 +116,7 @@ class STLVisitor(NodeVisitor):
|
|||
(var_kind, iden), time_node = children
|
||||
|
||||
time_node = list(flatten(time_node))
|
||||
time = time_node[0] if len(time_node) > 0 else 0
|
||||
time = time_node[0] if len(time_node) > 0 else stl.t_sym
|
||||
|
||||
return stl.Var(var_kind, iden, time)
|
||||
|
||||
|
|
@ -124,13 +124,13 @@ class STLVisitor(NodeVisitor):
|
|||
return children[3]* children[5]
|
||||
|
||||
def visit_prime(self, *_):
|
||||
return -stl.dt
|
||||
return -stl.dt_sym
|
||||
|
||||
def visit_const(self, const, children):
|
||||
return float(const.text)
|
||||
|
||||
def visit_dt(self, *_):
|
||||
return stl.dt
|
||||
return stl.dt_sym
|
||||
|
||||
def visit_term(self, _, children):
|
||||
coeffs, var = children
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue