feat!: make interpolation method explicit

All methods that need to perform interpolation of some sort need an
explicit interpolation method. In Rust, this manifests as a generic
parameter, while in Python, this is a string parameter.
This commit is contained in:
Anand Balakrishnan 2023-10-04 14:42:51 -07:00
parent e2cff9449e
commit 50d5a0a78a
8 changed files with 221 additions and 296 deletions

View file

@ -1,5 +1,7 @@
use std::collections::HashMap;
use std::str::FromStr;
use argus::signals::interpolation::{Constant, Linear};
use argus::{AnySignal, BooleanSemantics, QuantitativeSemantics, Signal, Trace};
use pyo3::exceptions::PyTypeError;
use pyo3::prelude::*;
@ -60,14 +62,26 @@ impl Trace for PyTrace {
}
#[pyfunction]
fn eval_bool_semantics(expr: &PyBoolExpr, trace: &PyTrace) -> PyResult<Py<BoolSignal>> {
let sig = BooleanSemantics::eval(&expr.0, trace).map_err(PyArgusError::from)?;
Python::with_gil(|py| Py::new(py, (BoolSignal, PySignal::new(sig, PyInterp::Linear))))
#[pyo3(signature = (expr, trace, *, interpolation_method = "linear"))]
fn eval_bool_semantics(expr: &PyBoolExpr, trace: &PyTrace, interpolation_method: &str) -> PyResult<Py<BoolSignal>> {
let interp = PyInterp::from_str(interpolation_method)?;
let sig = match interp {
PyInterp::Linear => BooleanSemantics::eval::<Linear, Linear>(&expr.0, trace).map_err(PyArgusError::from)?,
PyInterp::Constant => {
BooleanSemantics::eval::<Constant, Constant>(&expr.0, trace).map_err(PyArgusError::from)?
}
};
Python::with_gil(|py| Py::new(py, (BoolSignal, PySignal::new(sig, interp))))
}
#[pyfunction]
fn eval_robust_semantics(expr: &PyBoolExpr, trace: &PyTrace) -> PyResult<Py<FloatSignal>> {
let sig = QuantitativeSemantics::eval(&expr.0, trace).map_err(PyArgusError::from)?;
Python::with_gil(|py| Py::new(py, (FloatSignal, PySignal::new(sig, PyInterp::Linear))))
#[pyo3(signature = (expr, trace, *, interpolation_method = "linear"))]
fn eval_robust_semantics(expr: &PyBoolExpr, trace: &PyTrace, interpolation_method: &str) -> PyResult<Py<FloatSignal>> {
let interp = PyInterp::from_str(interpolation_method)?;
let sig = match interp {
PyInterp::Linear => QuantitativeSemantics::eval::<Linear>(&expr.0, trace).map_err(PyArgusError::from)?,
PyInterp::Constant => QuantitativeSemantics::eval::<Constant>(&expr.0, trace).map_err(PyArgusError::from)?,
};
Python::with_gil(|py| Py::new(py, (FloatSignal, PySignal::new(sig, interp))))
}
pub fn init(_py: Python, m: &PyModule) -> PyResult<()> {

View file

@ -1,3 +1,5 @@
use std::str::FromStr;
use argus::signals::interpolation::{Constant, Linear};
use argus::signals::Signal;
use pyo3::exceptions::PyValueError;
@ -14,6 +16,21 @@ pub enum PyInterp {
Constant,
}
impl FromStr for PyInterp {
type Err = PyErr;
fn from_str(method: &str) -> Result<Self, Self::Err> {
match method {
"linear" => Ok(PyInterp::Linear),
"constant" => Ok(PyInterp::Constant),
_ => Err(PyValueError::new_err(format!(
"unsupported interpolation method `{}`",
method
))),
}
}
}
#[derive(Debug, Clone, derive_more::From, derive_more::TryInto)]
#[try_into(owned, ref, ref_mut)]
pub enum SignalKind {
@ -126,11 +143,7 @@ macro_rules! impl_signals {
#[new]
#[pyo3(signature = (*, interpolation_method = "linear"))]
fn new(interpolation_method: &str) -> PyResult<(Self, PySignal)> {
let interp = match interpolation_method {
"linear" => PyInterp::Linear,
"constant" => PyInterp::Constant,
_ => return Err(PyValueError::new_err(format!("unsupported interpolation method `{}`", interpolation_method))),
};
let interp = PyInterp::from_str(interpolation_method)?;
Ok((Self, PySignal::new(Signal::<$ty>::new(), interp)))
}
@ -138,11 +151,7 @@ macro_rules! impl_signals {
#[classmethod]
#[pyo3(signature = (value, *, interpolation_method = "linear"))]
fn constant(_: &PyType, py: Python<'_>, value: $ty, interpolation_method: &str) -> PyResult<Py<Self>> {
let interp = match interpolation_method {
"linear" => PyInterp::Linear,
"constant" => PyInterp::Constant,
_ => return Err(PyValueError::new_err(format!("unsupported interpolation method `{}`", interpolation_method))),
};
let interp = PyInterp::from_str(interpolation_method)?;
Py::new(
py,
(Self, PySignal::new(Signal::constant(value), interp))
@ -158,11 +167,7 @@ macro_rules! impl_signals {
.map(|(t, v)| (core::time::Duration::try_from_secs_f64(t).unwrap_or_else(|err| panic!("Value = {}, {}", t, err)), v))
).map_err(PyArgusError::from)?;
let interp = match interpolation_method {
"linear" => PyInterp::Linear,
"constant" => PyInterp::Constant,
_ => return Err(PyValueError::new_err(format!("unsupported interpolation method `{}`", interpolation_method))),
};
let interp = PyInterp::from_str(interpolation_method)?;
Python::with_gil(|py| {
Py::new(
py,