Fixed PSTL construction

This commit is contained in:
Marcell Vazquez-Chanlatte 2017-10-26 16:14:45 -07:00
parent d5985406ad
commit eda63fd6f0
3 changed files with 38 additions and 30 deletions

View file

@ -97,7 +97,13 @@ class Var(namedtuple("Var", ["coeff", "id"])):
__slots__ = () __slots__ = ()
def __repr__(self): def __repr__(self):
return f"{self.coeff}*{self.id}" if self.coeff == -1:
coeff_str = "-"
elif self.coeff == +1:
coeff_str = ""
else:
coeff_str = f"{self.coeff}*"
return f"{coeff_str}{self.id}"
class Interval(namedtuple('I', ['lower', 'upper'])): class Interval(namedtuple('I', ['lower', 'upper'])):
@ -217,3 +223,15 @@ class Next(namedtuple('Next', ['arg']), AST):
def __hash__(self): def __hash__(self):
# TODO: compute hash based on contents # TODO: compute hash based on contents
return hash(repr(self)) return hash(repr(self))
class Param(namedtuple('Param', ['name']), AST):
__slots__ = ()
def __repr__(self):
return self.name
def __hash__(self):
# TODO: compute hash based on contents
return hash(repr(self))

View file

@ -12,8 +12,6 @@ from parsimonious import Grammar, NodeVisitor
from funcy import flatten from funcy import flatten
from lenses import lens from lenses import lens
from sympy import Symbol, Number
from stl import ast from stl import ast
from stl.utils import implies, xor, iff, env, alw from stl.utils import implies, xor, iff, env, alw
@ -45,8 +43,7 @@ interval = "[" __ const_or_unbound __ "," __ const_or_unbound __ "]"
const_or_unbound = unbound / "inf" / const const_or_unbound = unbound / "inf" / const
lineq = terms _ op _ const_or_unbound lineq = terms _ op _ const_or_unbound
term = coeff? var term = const? var
coeff = ((dt __ "*" __)? const __ "*" __) / (dt __ "*")
terms = (term __ pm __ terms) / term terms = (term __ pm __ terms) / term
var = id var = id
@ -56,7 +53,7 @@ pm = "+" / "-"
dt = "dt" dt = "dt"
unbound = id "?" unbound = id "?"
id = ~r"[a-zA-z\d]+" id = ~r"[a-zA-z\d]+"
const = ~r"[\+\-]?\d*(\.\d+)?" const = ~r"[-+]?\d*\.\d+|\d+"
op = ">=" / "<=" / "<" / ">" / "=" op = ">=" / "<=" / "<" / ">" / "="
_ = ~r"\s"+ _ = ~r"\s"+
__ = ~r"\s"* __ = ~r"\s"*
@ -64,8 +61,10 @@ EOL = "\\n"
''') ''')
oo = float('inf')
class STLVisitor(NodeVisitor): class STLVisitor(NodeVisitor):
def __init__(self, H=float('inf')): def __init__(self, H=oo):
super().__init__() super().__init__()
self.default_interval = ast.Interval(0.0, H) self.default_interval = ast.Interval(0.0, H)
@ -92,7 +91,7 @@ class STLVisitor(NodeVisitor):
return node.text return node.text
def visit_unbound(self, node, _): def visit_unbound(self, node, _):
return Symbol(node.text) return ast.Param(node.text)
visit_op = get_text visit_op = get_text
@ -128,23 +127,16 @@ class STLVisitor(NodeVisitor):
return env(psi, lo=lo, hi=hi) & alw(ast.Until(phi, psi), lo=0, hi=lo) return env(psi, lo=lo, hi=hi) & alw(ast.Until(phi, psi), lo=0, hi=lo)
def visit_id(self, name, _): def visit_id(self, name, _):
return Symbol(name.text) return name.text
def visit_const(self, const, children): def visit_const(self, const, children):
return float(const.text) return float(const.text)
def visit_term(self, _, children): def visit_term(self, _, children):
coeffs, (iden, time) = children coeffs, iden = children
c = coeffs[0] if coeffs else Number(1) c = coeffs[0] if coeffs else 1
return ast.Var(coeff=c, id=iden, time=time) return ast.Var(coeff=c, id=iden)
def visit_coeff(self, _, children):
dt, coeff, *_ = children[0]
if not isinstance(dt, Symbol):
dt = dt[0][0] if dt else Number(1)
return dt * coeff
else:
return dt
def visit_terms(self, _, children): def visit_terms(self, _, children):
if isinstance(children[0], list): if isinstance(children[0], list):
@ -159,7 +151,7 @@ class STLVisitor(NodeVisitor):
return ast.LinEq(tuple(terms), op, const[0]) return ast.LinEq(tuple(terms), op, const[0])
def visit_pm(self, node, _): def visit_pm(self, node, _):
return Number(1) if node.text == "+" else Number(-1) return 1 if node.text == "+" else -1
def visit_AP(self, *args): def visit_AP(self, *args):
return ast.AtomicPred(self.visit_id(*args)) return ast.AtomicPred(self.visit_id(*args))
@ -171,5 +163,5 @@ class STLVisitor(NodeVisitor):
return ast.Next(children[1]) return ast.Next(children[1])
def parse(stl_str: str, rule: str = "phi", H=float('inf')) -> "STL": def parse(stl_str: str, rule: str = "phi", H=oo) -> "STL":
return STLVisitor(H).visit(STL_GRAMMAR[rule].parse(stl_str)) return STLVisitor(H).visit(STL_GRAMMAR[rule].parse(stl_str))

View file

@ -11,7 +11,7 @@ import traces
import stl.ast import stl.ast
from stl.ast import (LinEq, And, Or, NaryOpSTL, F, G, Interval, Neg, from stl.ast import (LinEq, And, Or, NaryOpSTL, F, G, Interval, Neg,
AtomicPred) AtomicPred, Param, AST)
from stl.types import STL, STL_Generator, MTL from stl.types import STL, STL_Generator, MTL
Lens = TypeVar('Lens') Lens = TypeVar('Lens')
@ -43,7 +43,7 @@ def ast_lens(phi: STL, bind=True, *, pred=None, focus_lens=None) -> Lens:
if pred is None: if pred is None:
pred = lambda _: False pred = lambda _: False
l = lenses.bind(phi) if bind else lens l = lenses.bind(phi) if bind else lens
return l.Tuple(*_ast_lens(phi, pred=pred, focus_lens=focus_lens)) return l.Fork(*_ast_lens(phi, pred=pred, focus_lens=focus_lens))
def _ast_lens(phi: STL, pred, focus_lens) -> Lens: def _ast_lens(phi: STL, pred, focus_lens) -> Lens:
@ -77,22 +77,20 @@ def terms_lens(phi: STL, bind: bool = True) -> Lens:
def param_lens(phi: STL) -> Lens: def param_lens(phi: STL) -> Lens:
is_sym = lambda x: isinstance(x, sympy.Symbol)
def focus_lens(leaf): def focus_lens(leaf):
return [lens.const] if isinstance(leaf, LinEq) else [ candidates = [lens.const] if isinstance(leaf, LinEq) else [
lens.GetAttr('interval')[0], lens.GetAttr('interval')[0],
lens.GetAttr('interval')[1] lens.GetAttr('interval')[1]
] ]
return (x for x in candidates if isinstance(x.get()(leaf), Param))
return ast_lens( return ast_lens(
phi, pred=type_pred(LinEq, F, G), phi, pred=type_pred(LinEq, F, G),
focus_lens=focus_lens).filter_(is_sym) focus_lens=focus_lens)
def set_params(stl_or_lens, val) -> STL: def set_params(phi, val) -> STL:
l = stl_or_lens if isinstance(stl_or_lens, l = param_lens(phi) if isinstance(phi, AST) else phi
Lens) else param_lens(stl_or_lens)
return l.modify(lambda x: float(val.get(x, val.get(str(x), x)))) return l.modify(lambda x: float(val.get(x, val.get(str(x), x))))