feat(pyargus): add interpolation method parameter for Signal construction

This commit is contained in:
Anand Balakrishnan 2023-10-03 19:46:17 -07:00
parent ad0e62eee5
commit 9ca6748c50
No known key found for this signature in database
3 changed files with 52 additions and 21 deletions

View file

@ -2,10 +2,10 @@ from __future__ import annotations
from typing import List, Optional, Tuple, Type, Union from typing import List, Optional, Tuple, Type, Union
from argus import _argus from . import _argus
from argus._argus import dtype from ._argus import dtype
from argus.exprs import ConstBool, ConstFloat, ConstInt, ConstUInt, VarBool, VarFloat, VarInt, VarUInt from .exprs import ConstBool, ConstFloat, ConstInt, ConstUInt, VarBool, VarFloat, VarInt, VarUInt
from argus.signals import BoolSignal, FloatSignal, IntSignal, Signal, UnsignedIntSignal from .signals import BoolSignal, FloatSignal, IntSignal, Signal, UnsignedIntSignal
try: try:
__doc__ = _argus.__doc__ __doc__ = _argus.__doc__

View file

@ -1,4 +1,4 @@
from typing import ClassVar, Protocol, final from typing import ClassVar, Literal, Protocol, TypeAlias, final
from typing_extensions import Self from typing_extensions import Self
@ -146,36 +146,46 @@ class Signal:
@final @final
class BoolSignal(Signal): class BoolSignal(Signal):
@classmethod @classmethod
def constant(cls, value: bool) -> Self: ... def constant(cls, value: bool, *, interpolation_method: _InterpolationMethod = "linear") -> Self: ...
@classmethod @classmethod
def from_samples(cls, samples: list[tuple[float, bool]]) -> Self: ... def from_samples(
cls, samples: list[tuple[float, bool]], *, interpolation_method: _InterpolationMethod = "linear"
) -> Self: ...
def push(self, time: float, value: bool) -> None: ... def push(self, time: float, value: bool) -> None: ...
def at(self, time: float) -> bool | None: ... def at(self, time: float) -> bool | None: ...
_InterpolationMethod: TypeAlias = Literal["linear", "constant"]
@final @final
class IntSignal(Signal): class IntSignal(Signal):
@classmethod @classmethod
def constant(cls, value: int) -> Self: ... def constant(cls, value: int, *, interpolation_method: _InterpolationMethod = "linear") -> Self: ...
@classmethod @classmethod
def from_samples(cls, samples: list[tuple[float, int]]) -> Self: ... def from_samples(
cls, samples: list[tuple[float, int]], *, interpolation_method: _InterpolationMethod = "linear"
) -> Self: ...
def push(self, time: float, value: int) -> None: ... def push(self, time: float, value: int) -> None: ...
def at(self, time: float) -> int | None: ... def at(self, time: float) -> int | None: ...
@final @final
class UnsignedIntSignal(Signal): class UnsignedIntSignal(Signal):
@classmethod @classmethod
def constant(cls, value: int) -> Self: ... def constant(cls, value: int, *, interpolation_method: _InterpolationMethod = "linear") -> Self: ...
@classmethod @classmethod
def from_samples(cls, samples: list[tuple[float, int]]) -> Self: ... def from_samples(
cls, samples: list[tuple[float, int]], *, interpolation_method: _InterpolationMethod = "linear"
) -> Self: ...
def push(self, time: float, value: int) -> None: ... def push(self, time: float, value: int) -> None: ...
def at(self, time: float) -> int | None: ... def at(self, time: float) -> int | None: ...
@final @final
class FloatSignal(Signal): class FloatSignal(Signal):
@classmethod @classmethod
def constant(cls, value: float) -> Self: ... def constant(cls, value: float, *, interpolation_method: _InterpolationMethod = "linear") -> Self: ...
@classmethod @classmethod
def from_samples(cls, samples: list[tuple[float, float]]) -> Self: ... def from_samples(
cls, samples: list[tuple[float, float]], *, interpolation_method: _InterpolationMethod = "linear"
) -> Self: ...
def push(self, time: float, value: float) -> None: ... def push(self, time: float, value: float) -> None: ...
def at(self, time: float) -> float | None: ... def at(self, time: float) -> float | None: ...

View file

@ -1,5 +1,6 @@
use argus::signals::interpolation::Linear; use argus::signals::interpolation::{Constant, Linear};
use argus::signals::Signal; use argus::signals::Signal;
use pyo3::exceptions::PyValueError;
use pyo3::prelude::*; use pyo3::prelude::*;
use pyo3::types::PyType; use pyo3::types::PyType;
@ -10,6 +11,7 @@ use crate::{DType, PyArgusError};
pub enum PyInterp { pub enum PyInterp {
#[default] #[default]
Linear, Linear,
Constant,
} }
#[derive(Debug, Clone, derive_more::From, derive_more::TryInto)] #[derive(Debug, Clone, derive_more::From, derive_more::TryInto)]
@ -122,31 +124,49 @@ macro_rules! impl_signals {
impl [<$ty_name Signal>] { impl [<$ty_name Signal>] {
/// Create a new empty signal /// Create a new empty signal
#[new] #[new]
#[pyo3(signature = ())] #[pyo3(signature = (*, interpolation_method = "linear"))]
fn new() -> (Self, PySignal) { fn new(interpolation_method: &str) -> PyResult<(Self, PySignal)> {
(Self, PySignal::new(Signal::<$ty>::new(), PyInterp::Linear)) let interp = match interpolation_method {
"linear" => PyInterp::Linear,
"constant" => PyInterp::Constant,
_ => return Err(PyValueError::new_err(format!("unsupported interpolation method `{}`", interpolation_method))),
};
Ok((Self, PySignal::new(Signal::<$ty>::new(), interp)))
} }
/// Create a new signal with constant value /// Create a new signal with constant value
#[classmethod] #[classmethod]
fn constant(_: &PyType, py: Python<'_>, value: $ty) -> PyResult<Py<Self>> { #[pyo3(signature = (value, *, interpolation_method = "linear"))]
fn constant(_: &PyType, py: Python<'_>, value: $ty, interpolation_method: &str) -> PyResult<Py<Self>> {
let interp = match interpolation_method {
"linear" => PyInterp::Linear,
"constant" => PyInterp::Constant,
_ => return Err(PyValueError::new_err(format!("unsupported interpolation method `{}`", interpolation_method))),
};
Py::new( Py::new(
py, py,
(Self, PySignal::new(Signal::constant(value), PyInterp::Linear)) (Self, PySignal::new(Signal::constant(value), interp))
) )
} }
/// Create a new signal from some finite number of samples /// Create a new signal from some finite number of samples
#[classmethod] #[classmethod]
fn from_samples(_: &PyType, samples: Vec<(f64, $ty)>) -> PyResult<Py<Self>> { #[pyo3(signature = (samples, *, interpolation_method = "linear"))]
fn from_samples(_: &PyType, samples: Vec<(f64, $ty)>, interpolation_method: &str) -> PyResult<Py<Self>> {
let ret: Signal::<$ty> = Signal::<$ty>::try_from_iter(samples let ret: Signal::<$ty> = Signal::<$ty>::try_from_iter(samples
.into_iter() .into_iter()
.map(|(t, v)| (core::time::Duration::try_from_secs_f64(t).unwrap_or_else(|err| panic!("Value = {}, {}", t, err)), v)) .map(|(t, v)| (core::time::Duration::try_from_secs_f64(t).unwrap_or_else(|err| panic!("Value = {}, {}", t, err)), v))
).map_err(PyArgusError::from)?; ).map_err(PyArgusError::from)?;
let interp = match interpolation_method {
"linear" => PyInterp::Linear,
"constant" => PyInterp::Constant,
_ => return Err(PyValueError::new_err(format!("unsupported interpolation method `{}`", interpolation_method))),
};
Python::with_gil(|py| { Python::with_gil(|py| {
Py::new( Py::new(
py, py,
(Self, PySignal::new(ret, PyInterp::Linear)) (Self, PySignal::new(ret, interp))
) )
}) })
} }
@ -171,6 +191,7 @@ macro_rules! impl_signals {
let time = core::time::Duration::from_secs_f64(time); let time = core::time::Duration::from_secs_f64(time);
match super_.interpolation { match super_.interpolation {
PyInterp::Linear => signal.interpolate_at::<Linear>(time), PyInterp::Linear => signal.interpolate_at::<Linear>(time),
PyInterp::Constant => signal.interpolate_at::<Constant>(time),
} }
} }