fix parsing +/- and added symbolic dt

This commit is contained in:
Marcell Vazquez-Chanlatte 2016-07-05 00:51:23 -07:00
parent 3646a0e2cb
commit 2d7e033df0
2 changed files with 28 additions and 18 deletions

10
stl.py
View file

@ -8,9 +8,12 @@ from collections import namedtuple, deque
from itertools import repeat from itertools import repeat
from typing import Union from typing import Union
from enum import Enum from enum import Enum
from sympy import Symbol
VarKind = Enum("VarKind", ["x", "u", "w"]) VarKind = Enum("VarKind", ["x", "u", "w"])
str_to_varkind = {"x": VarKind.x, "u": VarKind.u, "w": VarKind.w} str_to_varkind = {"x": VarKind.x, "u": VarKind.u, "w": VarKind.w}
dt = Symbol('dt', positive=True)
class LinEq(namedtuple("LinEquality", ["terms", "op", "const"])): class LinEq(namedtuple("LinEquality", ["terms", "op", "const"])):
def __repr__(self): def __repr__(self):
@ -27,15 +30,14 @@ class LinEq(namedtuple("LinEquality", ["terms", "op", "const"])):
class Var(namedtuple("Var", ["kind", "id", "time"])): class Var(namedtuple("Var", ["kind", "id", "time"])):
def __repr__(self): def __repr__(self):
time_str = "'" if self.time == -1 else "[t+{}]".format(self.time) time_str = "[t + {}]".format(self.time)
return "{k}{i}{t}".format(k=self.kind.name, i=self.id, t=time_str) return "{k}{i}{t}".format(k=self.kind.name, i=self.id, t=time_str)
class Term(namedtuple("Term", ["dt", "coeff", "var"])): class Term(namedtuple("Term", ["coeff", "var"])):
def __repr__(self): def __repr__(self):
dt = "dt*" if self.dt else ""
coeff = str(self.coeff) + "*" if self.coeff != 1 else "" coeff = str(self.coeff) + "*" if self.coeff != 1 else ""
return "{dt}{c}{v}".format(dt=dt, c=coeff, v=self.var) return "{c}{v}".format(c=coeff, v=self.var)
class Interval(namedtuple('I', ['lower', 'upper'])): class Interval(namedtuple('I', ['lower', 'upper'])):

View file

@ -1,13 +1,14 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# TODO: consider using sympy to interpret stuff
# TODO: break out into seperate library # TODO: break out into seperate library
# TODO: allow matrix to be symbolically parsed in STL_GRAMMAR # TODO: allow matrix to be symbolically parsed in STL_GRAMMAR
# TODO: allow multiplication to be distributive # TODO: allow multiplication to be distributive
# TODO: support reference specific time points # TODO: support reference specific time points
# TODO: add Implies and Iff syntactic sugar # TODO: add Implies and Iff syntactic sugar
# TODO: add support for parsing Until # TODO: add support for parsing Until
# TODO: properly handle pm when parsing
# TODO: support variables on both sides of ineq # TODO: support variables on both sides of ineq
# TODO: Allow -x = -1*x
# TODO: change way of parsing dt # TODO: change way of parsing dt
# - Allow inside of time index # - Allow inside of time index
# - Allow dt*x rather than dt*1*x # - Allow dt*x rather than dt*1*x
@ -19,6 +20,7 @@ import operator as op
from parsimonious import Grammar, NodeVisitor from parsimonious import Grammar, NodeVisitor
from funcy import flatten from funcy import flatten
import numpy as np import numpy as np
from lenses import lens
from stl import stl from stl import stl
@ -39,19 +41,14 @@ interval = "[" __ const_or_unbound __ "," __ const_or_unbound __ "]"
const_or_unbound = unbound / const const_or_unbound = unbound / const
symlineq = matvec (__ "+" __ dt "*" matvec)? _ op _ const_or_unbound
matvec = mat "*" vec
mat = "A" ~r"\d+"
vec = ("X" / "U" / "W") ("'")?
lineq = terms _ op _ const_or_unbound lineq = terms _ op _ const_or_unbound
term = coeff? var term = coeff? var
coeff = (dt __ "*" __)? const __ "*" __ coeff = (dt __ "*" __)? const __ "*" __
terms = (term __ ("+"/"-") __ terms) / term terms = (term __ pm __ terms) / term
var = id time? var = id time?
time = (prime / ("[" "t" __ pm __ const "]")) time = prime / time_index
time_index = "[" "t" __ pm __ const "]"
prime = "'" prime = "'"
pm = "+" / "-" pm = "+" / "-"
@ -123,24 +120,32 @@ class STLVisitor(NodeVisitor):
return stl.Var(var_kind, iden, time) return stl.Var(var_kind, iden, time)
def visit_time_index(self, _, children):
return children[3]* children[5]
def visit_prime(self, *_): def visit_prime(self, *_):
return -1 return -stl.dt
def visit_const(self, const, children): def visit_const(self, const, children):
return float(const.text) return float(const.text)
def visit_dt(self, *_):
return stl.dt
def visit_term(self, _, children): def visit_term(self, _, children):
coeffs, var = children coeffs, var = children
(dt, c) = coeffs[0] if coeffs else (False, 1) c = coeffs[0] if coeffs else 1
return stl.Term(dt, c, var) return stl.Term(c, var)
def visit_coeff(self, _, children): def visit_coeff(self, _, children):
dt, coeff, *_ = children dt, coeff, *_ = children
return bool(dt), coeff dt = dt[0][0] if dt else 1
return dt * coeff
def visit_terms(self, _, children): def visit_terms(self, _, children):
if isinstance(children[0], list): if isinstance(children[0], list):
term, *_, terms = children[0] term, _1, sgn ,_2, terms = children[0]
terms = lens(terms)[0].coeff * sgn
return [term] + terms return [term] + terms
else: else:
return children return children
@ -149,6 +154,9 @@ class STLVisitor(NodeVisitor):
terms, _1, op, _2, const = children terms, _1, op, _2, const = children
return stl.LinEq(tuple(terms), op, const[0]) return stl.LinEq(tuple(terms), op, const[0])
def visit_pm(self, node, _):
return 1 if node.text == "+" else -1
class MatrixVisitor(NodeVisitor): class MatrixVisitor(NodeVisitor):
def generic_visit(self, _, children): def generic_visit(self, _, children):