fix(pyargus): address typing issues

This addresses some of the issues with inheritance (internal to the rust
module) for signals, and generally making mypy and flake8 happy.
This commit is contained in:
Anand Balakrishnan 2023-09-01 14:52:35 -07:00
parent ccd87fc22a
commit a25e56f025
No known key found for this signature in database
6 changed files with 192 additions and 136 deletions

View file

@ -5,7 +5,7 @@ from typing import List, Optional, Tuple, Type, Union
from argus import _argus from argus import _argus
from argus._argus import DType as DType from argus._argus import DType as DType
from argus.exprs import ConstBool, ConstFloat, ConstInt, ConstUInt, VarBool, VarFloat, VarInt, VarUInt from argus.exprs import ConstBool, ConstFloat, ConstInt, ConstUInt, VarBool, VarFloat, VarInt, VarUInt
from argus.signals import BoolSignal, FloatSignal, IntSignal, Signal, UnsignedIntSignal from argus.signals import BoolSignal, FloatSignal, IntSignal, UnsignedIntSignal
try: try:
__doc__ = _argus.__doc__ __doc__ = _argus.__doc__
@ -51,26 +51,23 @@ def signal(
dtype: Union[DType, Type[AllowedDtype]], dtype: Union[DType, Type[AllowedDtype]],
*, *,
data: Optional[Union[AllowedDtype, List[Tuple[float, AllowedDtype]]]] = None, data: Optional[Union[AllowedDtype, List[Tuple[float, AllowedDtype]]]] = None,
) -> Signal: ) -> Union[BoolSignal, UnsignedIntSignal, IntSignal, FloatSignal]:
"""Create a signal of the given type """Create a signal of the given type
Parameters Parameters
---------- ----------
dtype:
Type of the signal
data : data :
If a constant scalar is given, a constant signal is created. Otherwise, if a list of sample points are given, a sampled If a constant scalar is given, a constant signal is created. Otherwise, if a list of sample points are given, a sampled
signal is constructed. Otherwise, an empty signal is created. signal is constructed. Otherwise, an empty signal is created.
""" """
if isinstance(dtype, type):
if dtype == bool:
dtype = DType.Bool
elif dtype == int:
dtype = DType.Int
elif dtype == float:
dtype = DType.Float
factory: Type[Union[BoolSignal, UnsignedIntSignal, IntSignal, FloatSignal]] factory: Type[Union[BoolSignal, UnsignedIntSignal, IntSignal, FloatSignal]]
expected_type: Type[AllowedDtype] expected_type: Type[AllowedDtype]
dtype = DType.convert(dtype)
if dtype == DType.Bool: if dtype == DType.Bool:
factory = BoolSignal factory = BoolSignal
expected_type = bool expected_type = bool
@ -89,10 +86,9 @@ def signal(
if data is None: if data is None:
return factory.from_samples([]) return factory.from_samples([])
elif isinstance(data, (list, tuple)): elif isinstance(data, (list, tuple)):
return factory.from_samples(data) # type: ignore return factory.from_samples(data) # type: ignore[arg-type]
else: assert isinstance(data, expected_type)
assert isinstance(data, expected_type) return factory.constant(data) # type: ignore[arg-type]
return factory.constant(data) # type: ignore
__all__ = [ __all__ = [

View file

@ -1,8 +1,8 @@
from abc import ABC from typing import ClassVar, Protocol, TypeVar, final
from enum import Enum, auto
from typing import List, Optional, Tuple, final
class NumExpr(ABC): from typing_extensions import Self
class NumExpr(Protocol):
def __ge__(self, other) -> NumExpr: ... def __ge__(self, other) -> NumExpr: ...
def __gt__(self, other) -> NumExpr: ... def __gt__(self, other) -> NumExpr: ...
def __le__(self, other) -> NumExpr: ... def __le__(self, other) -> NumExpr: ...
@ -50,11 +50,11 @@ class Negate(NumExpr):
@final @final
class Add(NumExpr): class Add(NumExpr):
def __init__(self, args: List[NumExpr]): ... def __init__(self, args: list[NumExpr]): ...
@final @final
class Mul(NumExpr): class Mul(NumExpr):
def __init__(self, args: List[NumExpr]): ... def __init__(self, args: list[NumExpr]): ...
@final @final
class Div(NumExpr): class Div(NumExpr):
@ -64,7 +64,7 @@ class Div(NumExpr):
class Abs(NumExpr): class Abs(NumExpr):
def __init__(self, arg: NumExpr): ... def __init__(self, arg: NumExpr): ...
class BoolExpr(ABC): class BoolExpr(Protocol):
def __and__(self, other) -> BoolExpr: ... def __and__(self, other) -> BoolExpr: ...
def __invert__(self) -> BoolExpr: ... def __invert__(self) -> BoolExpr: ...
def __or__(self, other) -> BoolExpr: ... def __or__(self, other) -> BoolExpr: ...
@ -100,11 +100,11 @@ class Not(BoolExpr):
@final @final
class And(BoolExpr): class And(BoolExpr):
def __init__(self, args: List[BoolExpr]): ... def __init__(self, args: list[BoolExpr]): ...
@final @final
class Or(BoolExpr): class Or(BoolExpr):
def __init__(self, args: List[BoolExpr]): ... def __init__(self, args: list[BoolExpr]): ...
@final @final
class Next(BoolExpr): class Next(BoolExpr):
@ -123,57 +123,63 @@ class Until(BoolExpr):
def __init__(self, lhs: BoolExpr, rhs: BoolExpr): ... def __init__(self, lhs: BoolExpr, rhs: BoolExpr): ...
@final @final
class DType(Enum): class DType:
Bool = auto() Bool: ClassVar[DType] = ...
Int = auto() Float: ClassVar[DType] = ...
UnsignedInt = auto() Int: ClassVar[DType] = ...
Float = auto() UnsignedInt: ClassVar[DType] = ...
class Signal(ABC): ... @classmethod
def convert(cls, dtype: type[bool | int | float] | Self) -> Self: ... # noqa: Y041
def __eq__(self, other) -> bool: ...
def __int__(self) -> int: ...
_SignalKind = TypeVar("_SignalKind", bool, int, float, covariant=True)
class Signal(Protocol[_SignalKind]):
def is_empty(self) -> bool: ...
@property
def start_time(self) -> float | None: ...
@property
def end_time(self) -> float | None: ...
@property
def kind(self) -> type[bool | int | float]: ...
@final @final
class BoolSignal(Signal): class BoolSignal(Signal[bool]):
def __init__(self): ... @classmethod
@staticmethod def constant(cls, value: bool) -> Self: ...
def constant(value: bool) -> BoolSignal: ... @classmethod
@staticmethod def from_samples(cls, samples: list[tuple[float, bool]]) -> Self: ...
def from_samples(samples: List[Tuple[float, bool]]) -> BoolSignal: ...
def push(self, time: float, value: bool): ... def push(self, time: float, value: bool): ...
def is_empty(self) -> bool: ... def at(self, time: float) -> _SignalKind | None: ...
def at(self, time: float) -> Optional[bool]: ...
@final @final
class IntSignal(Signal): class IntSignal(Signal[int]):
def __init__(self): ... @classmethod
@staticmethod def constant(cls, value: int) -> Self: ...
def constant(value: int) -> IntSignal: ... @classmethod
@staticmethod def from_samples(cls, samples: list[tuple[float, int]]) -> Self: ...
def from_samples(samples: List[Tuple[float, int]]) -> IntSignal: ...
def push(self, time: float, value: int): ... def push(self, time: float, value: int): ...
def is_empty(self) -> bool: ... def at(self, time: float) -> int | None: ...
def at(self, time: float) -> Optional[int]: ...
@final @final
class UnsignedIntSignal(Signal): class UnsignedIntSignal(Signal[int]):
def __init__(self): ... @classmethod
@staticmethod def constant(cls, value: int) -> Self: ...
def constant(value: int) -> UnsignedIntSignal: ... @classmethod
@staticmethod def from_samples(cls, samples: list[tuple[float, int]]) -> Self: ...
def from_samples(samples: List[Tuple[float, int]]) -> UnsignedIntSignal: ...
def push(self, time: float, value: int): ... def push(self, time: float, value: int): ...
def is_empty(self) -> bool: ... def at(self, time: float) -> int | None: ...
def at(self, time: float) -> Optional[int]: ...
@final @final
class FloatSignal(Signal): class FloatSignal(Signal[float]):
def __init__(self): ... @classmethod
@staticmethod def constant(cls, value: float) -> Self: ...
def constant(value: float) -> FloatSignal: ... @classmethod
@staticmethod def from_samples(cls, samples: list[tuple[float, float]]) -> Self: ...
def from_samples(samples: List[Tuple[float, float]]) -> FloatSignal: ...
def push(self, time: float, value: float): ... def push(self, time: float, value: float): ...
def is_empty(self) -> bool: ... def at(self, time: float) -> float | None: ...
def at(self, time: float) -> Optional[float]: ...
@final @final
class Trace: ... class Trace: ...

View file

@ -5,6 +5,7 @@ mod signals;
use argus_core::ArgusError; use argus_core::ArgusError;
use pyo3::exceptions::{PyKeyError, PyRuntimeError, PyTypeError, PyValueError}; use pyo3::exceptions::{PyKeyError, PyRuntimeError, PyTypeError, PyValueError};
use pyo3::prelude::*; use pyo3::prelude::*;
use pyo3::types::{PyBool, PyFloat, PyInt, PyType};
#[derive(derive_more::From)] #[derive(derive_more::From)]
struct PyArgusError(ArgusError); struct PyArgusError(ArgusError);
@ -27,7 +28,7 @@ impl From<PyArgusError> for PyErr {
} }
} }
#[pyclass] #[pyclass(module = "argus")]
#[derive(Copy, Clone, Debug)] #[derive(Copy, Clone, Debug)]
pub enum DType { pub enum DType {
Bool, Bool,
@ -36,6 +37,33 @@ pub enum DType {
Float, Float,
} }
#[pymethods]
impl DType {
#[classmethod]
fn convert(_: &PyType, dtype: &PyAny, py: Python<'_>) -> PyResult<Self> {
use DType::*;
if dtype.is_instance_of::<DType>() {
dtype.extract::<DType>()
} else if dtype.is_instance_of::<PyType>() {
let dtype = dtype.downcast_exact::<PyType>()?;
if dtype.is(PyType::new::<PyBool>(py)) {
Ok(Bool)
} else if dtype.is(PyType::new::<PyInt>(py)) {
Ok(Int)
} else if dtype.is(PyType::new::<PyFloat>(py)) {
Ok(Float)
} else {
Err(PyTypeError::new_err(format!("unsupported type {}", dtype)))
}
} else {
Err(PyTypeError::new_err(format!(
"unsupported dtype {}, expected a `type`",
dtype
)))
}
}
}
#[pymodule] #[pymodule]
#[pyo3(name = "_argus")] #[pyo3(name = "_argus")]
fn pyargus(py: Python, m: &PyModule) -> PyResult<()> { fn pyargus(py: Python, m: &PyModule) -> PyResult<()> {

View file

@ -7,17 +7,8 @@ use pyo3::prelude::*;
use pyo3::types::{PyDict, PyString}; use pyo3::types::{PyDict, PyString};
use crate::expr::PyBoolExpr; use crate::expr::PyBoolExpr;
use crate::signals::{BoolSignal, FloatSignal, IntSignal, PySignal, UnsignedIntSignal}; use crate::signals::{BoolSignal, FloatSignal, PySignal, SignalKind};
use crate::{DType, PyArgusError}; use crate::PyArgusError;
#[derive(Debug, Clone, derive_more::From, derive_more::TryInto)]
#[try_into(owned, ref, ref_mut)]
enum SignalKind {
Bool(Signal<bool>),
Int(Signal<i64>),
UnsignedInt(Signal<u64>),
Float(Signal<f64>),
}
#[pyclass(name = "Trace", module = "argus")] #[pyclass(name = "Trace", module = "argus")]
#[derive(Debug, Clone, Default)] #[derive(Debug, Clone, Default)]
@ -40,20 +31,7 @@ impl PyTrace {
key, e key, e
)) ))
})?; })?;
let kind = val.borrow().kind; let signal = val.borrow().signal.clone();
let signal: SignalKind = match kind {
DType::Bool => val.downcast::<PyCell<BoolSignal>>().unwrap().borrow().0.clone().into(),
DType::Int => val.downcast::<PyCell<IntSignal>>().unwrap().borrow().0.clone().into(),
DType::UnsignedInt => val
.downcast::<PyCell<UnsignedIntSignal>>()
.unwrap()
.borrow()
.0
.clone()
.into(),
DType::Float => val.downcast::<PyCell<FloatSignal>>().unwrap().borrow().0.clone().into(),
};
signals.insert(key.to_string(), signal); signals.insert(key.to_string(), signal);
} }
@ -85,12 +63,12 @@ impl Trace for PyTrace {
#[pyfunction] #[pyfunction]
fn eval_bool_semantics(expr: &PyBoolExpr, trace: &PyTrace) -> PyResult<Py<BoolSignal>> { fn eval_bool_semantics(expr: &PyBoolExpr, trace: &PyTrace) -> PyResult<Py<BoolSignal>> {
let sig = BooleanSemantics::eval(&expr.0, trace).map_err(PyArgusError::from)?; let sig = BooleanSemantics::eval(&expr.0, trace).map_err(PyArgusError::from)?;
Python::with_gil(|py| Py::new(py, (BoolSignal::from(sig), BoolSignal::super_type()))) Python::with_gil(|py| Py::new(py, (BoolSignal, BoolSignal::super_type(sig.into()))))
} }
#[pyfunction] #[pyfunction]
fn eval_robust_semantics(expr: &PyBoolExpr, trace: &PyTrace) -> PyResult<Py<FloatSignal>> { fn eval_robust_semantics(expr: &PyBoolExpr, trace: &PyTrace) -> PyResult<Py<FloatSignal>> {
let sig = QuantitativeSemantics::eval(&expr.0, trace).map_err(PyArgusError::from)?; let sig = QuantitativeSemantics::eval(&expr.0, trace).map_err(PyArgusError::from)?;
Python::with_gil(|py| Py::new(py, (FloatSignal::from(sig), FloatSignal::super_type()))) Python::with_gil(|py| Py::new(py, (FloatSignal, FloatSignal::super_type(sig.into()))))
} }
pub fn init(_py: Python, m: &PyModule) -> PyResult<()> { pub fn init(_py: Python, m: &PyModule) -> PyResult<()> {

View file

@ -3,8 +3,9 @@ use std::time::Duration;
use argus_core::signals::interpolation::Linear; use argus_core::signals::interpolation::Linear;
use argus_core::signals::Signal; use argus_core::signals::Signal;
use pyo3::prelude::*; use pyo3::prelude::*;
use pyo3::types::{PyBool, PyFloat, PyInt, PyType};
use crate::{DType, PyArgusError}; use crate::PyArgusError;
#[pyclass(name = "InterpolationMethod", module = "argus")] #[pyclass(name = "InterpolationMethod", module = "argus")]
#[derive(Debug, Clone, Copy, Default)] #[derive(Debug, Clone, Copy, Default)]
@ -13,59 +14,128 @@ pub enum PyInterp {
Linear, Linear,
} }
#[derive(Debug, Clone, derive_more::From, derive_more::TryInto)]
#[try_into(owned, ref, ref_mut)]
pub enum SignalKind {
Bool(Signal<bool>),
Int(Signal<i64>),
UnsignedInt(Signal<u64>),
Float(Signal<f64>),
}
#[pyclass(name = "Signal", subclass, module = "argus")] #[pyclass(name = "Signal", subclass, module = "argus")]
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct PySignal { pub struct PySignal {
pub kind: DType,
pub interpolation: PyInterp, pub interpolation: PyInterp,
pub signal: SignalKind,
}
#[pymethods]
impl PySignal {
#[getter]
fn kind<'py>(&self, py: Python<'py>) -> &'py PyType {
match self.signal {
SignalKind::Bool(_) => PyType::new::<PyBool>(py),
SignalKind::Int(_) | SignalKind::UnsignedInt(_) => PyType::new::<PyInt>(py),
SignalKind::Float(_) => PyType::new::<PyFloat>(py),
}
}
fn __repr__(&self) -> String {
match &self.signal {
SignalKind::Bool(sig) => format!("Signal::<{}>::{:?}", "bool", sig),
SignalKind::Int(sig) => format!("Signal::<{}>::{:?}", "i64", sig),
SignalKind::UnsignedInt(sig) => format!("Signal::<{}>::{:?}", "u64", sig),
SignalKind::Float(sig) => format!("Signal::<{}>::{:?}", "f64", sig),
}
}
/// Check if the signal is empty
fn is_empty(&self) -> bool {
match &self.signal {
SignalKind::Bool(sig) => sig.is_empty(),
SignalKind::Int(sig) => sig.is_empty(),
SignalKind::UnsignedInt(sig) => sig.is_empty(),
SignalKind::Float(sig) => sig.is_empty(),
}
}
/// The start time of the signal
#[getter]
fn start_time(&self) -> Option<f64> {
use core::ops::Bound::*;
let start_time = match &self.signal {
SignalKind::Bool(sig) => sig.start_time()?,
SignalKind::Int(sig) => sig.start_time()?,
SignalKind::UnsignedInt(sig) => sig.start_time()?,
SignalKind::Float(sig) => sig.start_time()?,
};
match start_time {
Included(t) | Excluded(t) => Some(t.as_secs_f64()),
_ => None,
}
}
/// The end time of the signal
#[getter]
fn end_time(&self) -> Option<f64> {
use core::ops::Bound::*;
let end_time = match &self.signal {
SignalKind::Bool(sig) => sig.end_time()?,
SignalKind::Int(sig) => sig.end_time()?,
SignalKind::UnsignedInt(sig) => sig.end_time()?,
SignalKind::Float(sig) => sig.end_time()?,
};
match end_time {
Included(t) | Excluded(t) => Some(t.as_secs_f64()),
_ => None,
}
}
} }
macro_rules! impl_signals { macro_rules! impl_signals {
($ty_name:ident, $ty:ty) => { ($ty_name:ident, $ty:ty) => {
paste::paste! { paste::paste! {
#[pyclass(extends=PySignal, module = "argus")] #[pyclass(extends=PySignal, module = "argus")]
#[derive(Debug, Clone, derive_more::From)] #[derive(Debug, Copy, Clone)]
pub struct [<$ty_name Signal>](pub Signal<$ty>); pub struct [<$ty_name Signal>];
impl [<$ty_name Signal>] { impl [<$ty_name Signal>] {
#[inline] #[inline]
pub fn super_type() -> PySignal { pub fn super_type(signal: SignalKind) -> PySignal {
PySignal { PySignal {
interpolation: PyInterp::Linear, interpolation: PyInterp::Linear,
kind: DType::$ty_name, signal,
} }
} }
} }
#[pymethods] #[pymethods]
impl [<$ty_name Signal>] { impl [<$ty_name Signal>] {
fn __repr__(&self) -> String {
format!("Signal::<{}>::{:?}", stringify!($ty), self.0)
}
/// Create a new empty signal /// Create a new empty signal
#[new] #[new]
#[pyo3(signature = ())] #[pyo3(signature = ())]
fn new() -> (Self, PySignal) { fn new() -> (Self, PySignal) {
(Self(Signal::new()), Self::super_type()) (Self, Self::super_type(Signal::<$ty>::new().into()))
} }
#[pyo3(signature = ())]
fn __init__(self_: PyRef<'_, Self>) -> PyRef<'_, Self> { fn __init__(self_: PyRef<'_, Self>) -> PyRef<'_, Self> {
self_ self_
} }
/// Create a new signal with constant value /// Create a new signal with constant value
#[staticmethod] #[classmethod]
fn constant(py: Python<'_>, value: $ty) -> PyResult<Py<Self>> { fn constant(_: &PyType, py: Python<'_>, value: $ty) -> PyResult<Py<Self>> {
Py::new( Py::new(
py, py,
(Self(Signal::constant(value)), Self::super_type()) (Self, Self::super_type(Signal::constant(value).into()))
) )
} }
/// Create a new signal from some finite number of samples /// Create a new signal from some finite number of samples
#[staticmethod] #[classmethod]
fn from_samples(samples: Vec<(f64, $ty)>) -> PyResult<Py<Self>> { fn from_samples(_: &PyType, samples: Vec<(f64, $ty)>) -> PyResult<Py<Self>> {
let ret: Signal<$ty> = samples let ret: Signal<$ty> = samples
.into_iter() .into_iter()
.map(|(t, v)| (Duration::from_secs_f64(t), v)) .map(|(t, v)| (Duration::from_secs_f64(t), v))
@ -73,43 +143,20 @@ macro_rules! impl_signals {
Python::with_gil(|py| { Python::with_gil(|py| {
Py::new( Py::new(
py, py,
(Self(ret), Self::super_type()) (Self, Self::super_type(ret.into()))
) )
}) })
} }
/// Push a new sample into the given signal. /// Push a new sample into the given signal.
#[pyo3(signature = (time, value))] #[pyo3(signature = (time, value))]
fn push(&mut self, time: f64, value: $ty) -> Result<(), PyArgusError> { fn push(mut self_: PyRefMut<'_, Self>, time: f64, value: $ty) -> Result<(), PyArgusError> {
self.0.push(Duration::from_secs_f64(time), value)?; let super_: &mut PySignal = self_.as_mut();
let signal: &mut Signal<$ty> = (&mut super_.signal).try_into().unwrap();
signal.push(Duration::from_secs_f64(time), value)?;
Ok(()) Ok(())
} }
/// Check if the signal is empty
fn is_empty(&self) -> bool {
self.0.is_empty()
}
/// The start time of the signal
#[getter]
fn start_time(&self) -> Option<f64> {
use core::ops::Bound::*;
match self.0.start_time()? {
Included(t) | Excluded(t) => Some(t.as_secs_f64()),
_ => None,
}
}
/// The end time of the signal
#[getter]
fn end_time(&self) -> Option<f64> {
use core::ops::Bound::*;
match self.0.end_time()? {
Included(t) | Excluded(t) => Some(t.as_secs_f64()),
_ => None,
}
}
/// Get the value of the signal at the given time point. /// Get the value of the signal at the given time point.
/// ///
/// If there exists a sample, then the value is returned, otherwise the value is /// If there exists a sample, then the value is returned, otherwise the value is
@ -117,9 +164,10 @@ macro_rules! impl_signals {
/// is returned. /// is returned.
fn at(self_: PyRef<'_, Self>, time: f64) -> Option<$ty> { fn at(self_: PyRef<'_, Self>, time: f64) -> Option<$ty> {
let super_ = self_.as_ref(); let super_ = self_.as_ref();
let signal: &Signal<$ty> = (&super_.signal).try_into().unwrap();
let time = core::time::Duration::from_secs_f64(time); let time = core::time::Duration::from_secs_f64(time);
match super_.interpolation { match super_.interpolation {
PyInterp::Linear => self_.0.interpolate_at::<Linear>(time), PyInterp::Linear => signal.interpolate_at::<Linear>(time),
} }
} }

View file

@ -66,5 +66,5 @@ def test_correctly_create_signals(data: Tuple[List[Tuple[float, AllowedDtype]],
assert a < len(samples) assert a < len(samples)
assert b < len(samples) assert b < len(samples)
else: else:
assert signal.is_empty() # type: ignore[attr-defined] assert signal.is_empty()
assert signal.at(0) is None # type: ignore[attr-defined] assert signal.at(0) is None