fix(pyargus): address typing issues

This addresses some of the issues with inheritance (internal to the rust
module) for signals, and generally making mypy and flake8 happy.
This commit is contained in:
Anand Balakrishnan 2023-09-01 14:52:35 -07:00
parent ccd87fc22a
commit a25e56f025
No known key found for this signature in database
6 changed files with 192 additions and 136 deletions

View file

@ -5,6 +5,7 @@ mod signals;
use argus_core::ArgusError;
use pyo3::exceptions::{PyKeyError, PyRuntimeError, PyTypeError, PyValueError};
use pyo3::prelude::*;
use pyo3::types::{PyBool, PyFloat, PyInt, PyType};
#[derive(derive_more::From)]
struct PyArgusError(ArgusError);
@ -27,7 +28,7 @@ impl From<PyArgusError> for PyErr {
}
}
#[pyclass]
#[pyclass(module = "argus")]
#[derive(Copy, Clone, Debug)]
pub enum DType {
Bool,
@ -36,6 +37,33 @@ pub enum DType {
Float,
}
#[pymethods]
impl DType {
#[classmethod]
fn convert(_: &PyType, dtype: &PyAny, py: Python<'_>) -> PyResult<Self> {
use DType::*;
if dtype.is_instance_of::<DType>() {
dtype.extract::<DType>()
} else if dtype.is_instance_of::<PyType>() {
let dtype = dtype.downcast_exact::<PyType>()?;
if dtype.is(PyType::new::<PyBool>(py)) {
Ok(Bool)
} else if dtype.is(PyType::new::<PyInt>(py)) {
Ok(Int)
} else if dtype.is(PyType::new::<PyFloat>(py)) {
Ok(Float)
} else {
Err(PyTypeError::new_err(format!("unsupported type {}", dtype)))
}
} else {
Err(PyTypeError::new_err(format!(
"unsupported dtype {}, expected a `type`",
dtype
)))
}
}
}
#[pymodule]
#[pyo3(name = "_argus")]
fn pyargus(py: Python, m: &PyModule) -> PyResult<()> {

View file

@ -7,17 +7,8 @@ use pyo3::prelude::*;
use pyo3::types::{PyDict, PyString};
use crate::expr::PyBoolExpr;
use crate::signals::{BoolSignal, FloatSignal, IntSignal, PySignal, UnsignedIntSignal};
use crate::{DType, PyArgusError};
#[derive(Debug, Clone, derive_more::From, derive_more::TryInto)]
#[try_into(owned, ref, ref_mut)]
enum SignalKind {
Bool(Signal<bool>),
Int(Signal<i64>),
UnsignedInt(Signal<u64>),
Float(Signal<f64>),
}
use crate::signals::{BoolSignal, FloatSignal, PySignal, SignalKind};
use crate::PyArgusError;
#[pyclass(name = "Trace", module = "argus")]
#[derive(Debug, Clone, Default)]
@ -40,20 +31,7 @@ impl PyTrace {
key, e
))
})?;
let kind = val.borrow().kind;
let signal: SignalKind = match kind {
DType::Bool => val.downcast::<PyCell<BoolSignal>>().unwrap().borrow().0.clone().into(),
DType::Int => val.downcast::<PyCell<IntSignal>>().unwrap().borrow().0.clone().into(),
DType::UnsignedInt => val
.downcast::<PyCell<UnsignedIntSignal>>()
.unwrap()
.borrow()
.0
.clone()
.into(),
DType::Float => val.downcast::<PyCell<FloatSignal>>().unwrap().borrow().0.clone().into(),
};
let signal = val.borrow().signal.clone();
signals.insert(key.to_string(), signal);
}
@ -85,12 +63,12 @@ impl Trace for PyTrace {
#[pyfunction]
fn eval_bool_semantics(expr: &PyBoolExpr, trace: &PyTrace) -> PyResult<Py<BoolSignal>> {
let sig = BooleanSemantics::eval(&expr.0, trace).map_err(PyArgusError::from)?;
Python::with_gil(|py| Py::new(py, (BoolSignal::from(sig), BoolSignal::super_type())))
Python::with_gil(|py| Py::new(py, (BoolSignal, BoolSignal::super_type(sig.into()))))
}
#[pyfunction]
fn eval_robust_semantics(expr: &PyBoolExpr, trace: &PyTrace) -> PyResult<Py<FloatSignal>> {
let sig = QuantitativeSemantics::eval(&expr.0, trace).map_err(PyArgusError::from)?;
Python::with_gil(|py| Py::new(py, (FloatSignal::from(sig), FloatSignal::super_type())))
Python::with_gil(|py| Py::new(py, (FloatSignal, FloatSignal::super_type(sig.into()))))
}
pub fn init(_py: Python, m: &PyModule) -> PyResult<()> {

View file

@ -3,8 +3,9 @@ use std::time::Duration;
use argus_core::signals::interpolation::Linear;
use argus_core::signals::Signal;
use pyo3::prelude::*;
use pyo3::types::{PyBool, PyFloat, PyInt, PyType};
use crate::{DType, PyArgusError};
use crate::PyArgusError;
#[pyclass(name = "InterpolationMethod", module = "argus")]
#[derive(Debug, Clone, Copy, Default)]
@ -13,59 +14,128 @@ pub enum PyInterp {
Linear,
}
#[derive(Debug, Clone, derive_more::From, derive_more::TryInto)]
#[try_into(owned, ref, ref_mut)]
pub enum SignalKind {
Bool(Signal<bool>),
Int(Signal<i64>),
UnsignedInt(Signal<u64>),
Float(Signal<f64>),
}
#[pyclass(name = "Signal", subclass, module = "argus")]
#[derive(Debug, Clone)]
pub struct PySignal {
pub kind: DType,
pub interpolation: PyInterp,
pub signal: SignalKind,
}
#[pymethods]
impl PySignal {
#[getter]
fn kind<'py>(&self, py: Python<'py>) -> &'py PyType {
match self.signal {
SignalKind::Bool(_) => PyType::new::<PyBool>(py),
SignalKind::Int(_) | SignalKind::UnsignedInt(_) => PyType::new::<PyInt>(py),
SignalKind::Float(_) => PyType::new::<PyFloat>(py),
}
}
fn __repr__(&self) -> String {
match &self.signal {
SignalKind::Bool(sig) => format!("Signal::<{}>::{:?}", "bool", sig),
SignalKind::Int(sig) => format!("Signal::<{}>::{:?}", "i64", sig),
SignalKind::UnsignedInt(sig) => format!("Signal::<{}>::{:?}", "u64", sig),
SignalKind::Float(sig) => format!("Signal::<{}>::{:?}", "f64", sig),
}
}
/// Check if the signal is empty
fn is_empty(&self) -> bool {
match &self.signal {
SignalKind::Bool(sig) => sig.is_empty(),
SignalKind::Int(sig) => sig.is_empty(),
SignalKind::UnsignedInt(sig) => sig.is_empty(),
SignalKind::Float(sig) => sig.is_empty(),
}
}
/// The start time of the signal
#[getter]
fn start_time(&self) -> Option<f64> {
use core::ops::Bound::*;
let start_time = match &self.signal {
SignalKind::Bool(sig) => sig.start_time()?,
SignalKind::Int(sig) => sig.start_time()?,
SignalKind::UnsignedInt(sig) => sig.start_time()?,
SignalKind::Float(sig) => sig.start_time()?,
};
match start_time {
Included(t) | Excluded(t) => Some(t.as_secs_f64()),
_ => None,
}
}
/// The end time of the signal
#[getter]
fn end_time(&self) -> Option<f64> {
use core::ops::Bound::*;
let end_time = match &self.signal {
SignalKind::Bool(sig) => sig.end_time()?,
SignalKind::Int(sig) => sig.end_time()?,
SignalKind::UnsignedInt(sig) => sig.end_time()?,
SignalKind::Float(sig) => sig.end_time()?,
};
match end_time {
Included(t) | Excluded(t) => Some(t.as_secs_f64()),
_ => None,
}
}
}
macro_rules! impl_signals {
($ty_name:ident, $ty:ty) => {
paste::paste! {
#[pyclass(extends=PySignal, module = "argus")]
#[derive(Debug, Clone, derive_more::From)]
pub struct [<$ty_name Signal>](pub Signal<$ty>);
#[derive(Debug, Copy, Clone)]
pub struct [<$ty_name Signal>];
impl [<$ty_name Signal>] {
#[inline]
pub fn super_type() -> PySignal {
pub fn super_type(signal: SignalKind) -> PySignal {
PySignal {
interpolation: PyInterp::Linear,
kind: DType::$ty_name,
signal,
}
}
}
#[pymethods]
impl [<$ty_name Signal>] {
fn __repr__(&self) -> String {
format!("Signal::<{}>::{:?}", stringify!($ty), self.0)
}
/// Create a new empty signal
#[new]
#[pyo3(signature = ())]
fn new() -> (Self, PySignal) {
(Self(Signal::new()), Self::super_type())
(Self, Self::super_type(Signal::<$ty>::new().into()))
}
#[pyo3(signature = ())]
fn __init__(self_: PyRef<'_, Self>) -> PyRef<'_, Self> {
self_
}
/// Create a new signal with constant value
#[staticmethod]
fn constant(py: Python<'_>, value: $ty) -> PyResult<Py<Self>> {
#[classmethod]
fn constant(_: &PyType, py: Python<'_>, value: $ty) -> PyResult<Py<Self>> {
Py::new(
py,
(Self(Signal::constant(value)), Self::super_type())
(Self, Self::super_type(Signal::constant(value).into()))
)
}
/// Create a new signal from some finite number of samples
#[staticmethod]
fn from_samples(samples: Vec<(f64, $ty)>) -> PyResult<Py<Self>> {
#[classmethod]
fn from_samples(_: &PyType, samples: Vec<(f64, $ty)>) -> PyResult<Py<Self>> {
let ret: Signal<$ty> = samples
.into_iter()
.map(|(t, v)| (Duration::from_secs_f64(t), v))
@ -73,43 +143,20 @@ macro_rules! impl_signals {
Python::with_gil(|py| {
Py::new(
py,
(Self(ret), Self::super_type())
(Self, Self::super_type(ret.into()))
)
})
}
/// Push a new sample into the given signal.
#[pyo3(signature = (time, value))]
fn push(&mut self, time: f64, value: $ty) -> Result<(), PyArgusError> {
self.0.push(Duration::from_secs_f64(time), value)?;
fn push(mut self_: PyRefMut<'_, Self>, time: f64, value: $ty) -> Result<(), PyArgusError> {
let super_: &mut PySignal = self_.as_mut();
let signal: &mut Signal<$ty> = (&mut super_.signal).try_into().unwrap();
signal.push(Duration::from_secs_f64(time), value)?;
Ok(())
}
/// Check if the signal is empty
fn is_empty(&self) -> bool {
self.0.is_empty()
}
/// The start time of the signal
#[getter]
fn start_time(&self) -> Option<f64> {
use core::ops::Bound::*;
match self.0.start_time()? {
Included(t) | Excluded(t) => Some(t.as_secs_f64()),
_ => None,
}
}
/// The end time of the signal
#[getter]
fn end_time(&self) -> Option<f64> {
use core::ops::Bound::*;
match self.0.end_time()? {
Included(t) | Excluded(t) => Some(t.as_secs_f64()),
_ => None,
}
}
/// Get the value of the signal at the given time point.
///
/// If there exists a sample, then the value is returned, otherwise the value is
@ -117,9 +164,10 @@ macro_rules! impl_signals {
/// is returned.
fn at(self_: PyRef<'_, Self>, time: f64) -> Option<$ty> {
let super_ = self_.as_ref();
let signal: &Signal<$ty> = (&super_.signal).try_into().unwrap();
let time = core::time::Duration::from_secs_f64(time);
match super_.interpolation {
PyInterp::Linear => self_.0.interpolate_at::<Linear>(time),
PyInterp::Linear => signal.interpolate_at::<Linear>(time),
}
}