fix(pyargus): address typing issues

This addresses some of the issues with inheritance (internal to the rust
module) for signals, and generally making mypy and flake8 happy.
This commit is contained in:
Anand Balakrishnan 2023-09-01 14:52:35 -07:00
parent ccd87fc22a
commit a25e56f025
No known key found for this signature in database
6 changed files with 192 additions and 136 deletions

View file

@ -5,7 +5,7 @@ from typing import List, Optional, Tuple, Type, Union
from argus import _argus
from argus._argus import DType as DType
from argus.exprs import ConstBool, ConstFloat, ConstInt, ConstUInt, VarBool, VarFloat, VarInt, VarUInt
from argus.signals import BoolSignal, FloatSignal, IntSignal, Signal, UnsignedIntSignal
from argus.signals import BoolSignal, FloatSignal, IntSignal, UnsignedIntSignal
try:
__doc__ = _argus.__doc__
@ -51,26 +51,23 @@ def signal(
dtype: Union[DType, Type[AllowedDtype]],
*,
data: Optional[Union[AllowedDtype, List[Tuple[float, AllowedDtype]]]] = None,
) -> Signal:
) -> Union[BoolSignal, UnsignedIntSignal, IntSignal, FloatSignal]:
"""Create a signal of the given type
Parameters
----------
dtype:
Type of the signal
data :
If a constant scalar is given, a constant signal is created. Otherwise, if a list of sample points are given, a sampled
signal is constructed. Otherwise, an empty signal is created.
"""
if isinstance(dtype, type):
if dtype == bool:
dtype = DType.Bool
elif dtype == int:
dtype = DType.Int
elif dtype == float:
dtype = DType.Float
factory: Type[Union[BoolSignal, UnsignedIntSignal, IntSignal, FloatSignal]]
expected_type: Type[AllowedDtype]
dtype = DType.convert(dtype)
if dtype == DType.Bool:
factory = BoolSignal
expected_type = bool
@ -89,10 +86,9 @@ def signal(
if data is None:
return factory.from_samples([])
elif isinstance(data, (list, tuple)):
return factory.from_samples(data) # type: ignore
else:
return factory.from_samples(data) # type: ignore[arg-type]
assert isinstance(data, expected_type)
return factory.constant(data) # type: ignore
return factory.constant(data) # type: ignore[arg-type]
__all__ = [

View file

@ -1,8 +1,8 @@
from abc import ABC
from enum import Enum, auto
from typing import List, Optional, Tuple, final
from typing import ClassVar, Protocol, TypeVar, final
class NumExpr(ABC):
from typing_extensions import Self
class NumExpr(Protocol):
def __ge__(self, other) -> NumExpr: ...
def __gt__(self, other) -> NumExpr: ...
def __le__(self, other) -> NumExpr: ...
@ -50,11 +50,11 @@ class Negate(NumExpr):
@final
class Add(NumExpr):
def __init__(self, args: List[NumExpr]): ...
def __init__(self, args: list[NumExpr]): ...
@final
class Mul(NumExpr):
def __init__(self, args: List[NumExpr]): ...
def __init__(self, args: list[NumExpr]): ...
@final
class Div(NumExpr):
@ -64,7 +64,7 @@ class Div(NumExpr):
class Abs(NumExpr):
def __init__(self, arg: NumExpr): ...
class BoolExpr(ABC):
class BoolExpr(Protocol):
def __and__(self, other) -> BoolExpr: ...
def __invert__(self) -> BoolExpr: ...
def __or__(self, other) -> BoolExpr: ...
@ -100,11 +100,11 @@ class Not(BoolExpr):
@final
class And(BoolExpr):
def __init__(self, args: List[BoolExpr]): ...
def __init__(self, args: list[BoolExpr]): ...
@final
class Or(BoolExpr):
def __init__(self, args: List[BoolExpr]): ...
def __init__(self, args: list[BoolExpr]): ...
@final
class Next(BoolExpr):
@ -123,57 +123,63 @@ class Until(BoolExpr):
def __init__(self, lhs: BoolExpr, rhs: BoolExpr): ...
@final
class DType(Enum):
Bool = auto()
Int = auto()
UnsignedInt = auto()
Float = auto()
class DType:
Bool: ClassVar[DType] = ...
Float: ClassVar[DType] = ...
Int: ClassVar[DType] = ...
UnsignedInt: ClassVar[DType] = ...
class Signal(ABC): ...
@classmethod
def convert(cls, dtype: type[bool | int | float] | Self) -> Self: ... # noqa: Y041
def __eq__(self, other) -> bool: ...
def __int__(self) -> int: ...
_SignalKind = TypeVar("_SignalKind", bool, int, float, covariant=True)
class Signal(Protocol[_SignalKind]):
def is_empty(self) -> bool: ...
@property
def start_time(self) -> float | None: ...
@property
def end_time(self) -> float | None: ...
@property
def kind(self) -> type[bool | int | float]: ...
@final
class BoolSignal(Signal):
def __init__(self): ...
@staticmethod
def constant(value: bool) -> BoolSignal: ...
@staticmethod
def from_samples(samples: List[Tuple[float, bool]]) -> BoolSignal: ...
class BoolSignal(Signal[bool]):
@classmethod
def constant(cls, value: bool) -> Self: ...
@classmethod
def from_samples(cls, samples: list[tuple[float, bool]]) -> Self: ...
def push(self, time: float, value: bool): ...
def is_empty(self) -> bool: ...
def at(self, time: float) -> Optional[bool]: ...
def at(self, time: float) -> _SignalKind | None: ...
@final
class IntSignal(Signal):
def __init__(self): ...
@staticmethod
def constant(value: int) -> IntSignal: ...
@staticmethod
def from_samples(samples: List[Tuple[float, int]]) -> IntSignal: ...
class IntSignal(Signal[int]):
@classmethod
def constant(cls, value: int) -> Self: ...
@classmethod
def from_samples(cls, samples: list[tuple[float, int]]) -> Self: ...
def push(self, time: float, value: int): ...
def is_empty(self) -> bool: ...
def at(self, time: float) -> Optional[int]: ...
def at(self, time: float) -> int | None: ...
@final
class UnsignedIntSignal(Signal):
def __init__(self): ...
@staticmethod
def constant(value: int) -> UnsignedIntSignal: ...
@staticmethod
def from_samples(samples: List[Tuple[float, int]]) -> UnsignedIntSignal: ...
class UnsignedIntSignal(Signal[int]):
@classmethod
def constant(cls, value: int) -> Self: ...
@classmethod
def from_samples(cls, samples: list[tuple[float, int]]) -> Self: ...
def push(self, time: float, value: int): ...
def is_empty(self) -> bool: ...
def at(self, time: float) -> Optional[int]: ...
def at(self, time: float) -> int | None: ...
@final
class FloatSignal(Signal):
def __init__(self): ...
@staticmethod
def constant(value: float) -> FloatSignal: ...
@staticmethod
def from_samples(samples: List[Tuple[float, float]]) -> FloatSignal: ...
class FloatSignal(Signal[float]):
@classmethod
def constant(cls, value: float) -> Self: ...
@classmethod
def from_samples(cls, samples: list[tuple[float, float]]) -> Self: ...
def push(self, time: float, value: float): ...
def is_empty(self) -> bool: ...
def at(self, time: float) -> Optional[float]: ...
def at(self, time: float) -> float | None: ...
@final
class Trace: ...

View file

@ -5,6 +5,7 @@ mod signals;
use argus_core::ArgusError;
use pyo3::exceptions::{PyKeyError, PyRuntimeError, PyTypeError, PyValueError};
use pyo3::prelude::*;
use pyo3::types::{PyBool, PyFloat, PyInt, PyType};
#[derive(derive_more::From)]
struct PyArgusError(ArgusError);
@ -27,7 +28,7 @@ impl From<PyArgusError> for PyErr {
}
}
#[pyclass]
#[pyclass(module = "argus")]
#[derive(Copy, Clone, Debug)]
pub enum DType {
Bool,
@ -36,6 +37,33 @@ pub enum DType {
Float,
}
#[pymethods]
impl DType {
#[classmethod]
fn convert(_: &PyType, dtype: &PyAny, py: Python<'_>) -> PyResult<Self> {
use DType::*;
if dtype.is_instance_of::<DType>() {
dtype.extract::<DType>()
} else if dtype.is_instance_of::<PyType>() {
let dtype = dtype.downcast_exact::<PyType>()?;
if dtype.is(PyType::new::<PyBool>(py)) {
Ok(Bool)
} else if dtype.is(PyType::new::<PyInt>(py)) {
Ok(Int)
} else if dtype.is(PyType::new::<PyFloat>(py)) {
Ok(Float)
} else {
Err(PyTypeError::new_err(format!("unsupported type {}", dtype)))
}
} else {
Err(PyTypeError::new_err(format!(
"unsupported dtype {}, expected a `type`",
dtype
)))
}
}
}
#[pymodule]
#[pyo3(name = "_argus")]
fn pyargus(py: Python, m: &PyModule) -> PyResult<()> {

View file

@ -7,17 +7,8 @@ use pyo3::prelude::*;
use pyo3::types::{PyDict, PyString};
use crate::expr::PyBoolExpr;
use crate::signals::{BoolSignal, FloatSignal, IntSignal, PySignal, UnsignedIntSignal};
use crate::{DType, PyArgusError};
#[derive(Debug, Clone, derive_more::From, derive_more::TryInto)]
#[try_into(owned, ref, ref_mut)]
enum SignalKind {
Bool(Signal<bool>),
Int(Signal<i64>),
UnsignedInt(Signal<u64>),
Float(Signal<f64>),
}
use crate::signals::{BoolSignal, FloatSignal, PySignal, SignalKind};
use crate::PyArgusError;
#[pyclass(name = "Trace", module = "argus")]
#[derive(Debug, Clone, Default)]
@ -40,20 +31,7 @@ impl PyTrace {
key, e
))
})?;
let kind = val.borrow().kind;
let signal: SignalKind = match kind {
DType::Bool => val.downcast::<PyCell<BoolSignal>>().unwrap().borrow().0.clone().into(),
DType::Int => val.downcast::<PyCell<IntSignal>>().unwrap().borrow().0.clone().into(),
DType::UnsignedInt => val
.downcast::<PyCell<UnsignedIntSignal>>()
.unwrap()
.borrow()
.0
.clone()
.into(),
DType::Float => val.downcast::<PyCell<FloatSignal>>().unwrap().borrow().0.clone().into(),
};
let signal = val.borrow().signal.clone();
signals.insert(key.to_string(), signal);
}
@ -85,12 +63,12 @@ impl Trace for PyTrace {
#[pyfunction]
fn eval_bool_semantics(expr: &PyBoolExpr, trace: &PyTrace) -> PyResult<Py<BoolSignal>> {
let sig = BooleanSemantics::eval(&expr.0, trace).map_err(PyArgusError::from)?;
Python::with_gil(|py| Py::new(py, (BoolSignal::from(sig), BoolSignal::super_type())))
Python::with_gil(|py| Py::new(py, (BoolSignal, BoolSignal::super_type(sig.into()))))
}
#[pyfunction]
fn eval_robust_semantics(expr: &PyBoolExpr, trace: &PyTrace) -> PyResult<Py<FloatSignal>> {
let sig = QuantitativeSemantics::eval(&expr.0, trace).map_err(PyArgusError::from)?;
Python::with_gil(|py| Py::new(py, (FloatSignal::from(sig), FloatSignal::super_type())))
Python::with_gil(|py| Py::new(py, (FloatSignal, FloatSignal::super_type(sig.into()))))
}
pub fn init(_py: Python, m: &PyModule) -> PyResult<()> {

View file

@ -3,8 +3,9 @@ use std::time::Duration;
use argus_core::signals::interpolation::Linear;
use argus_core::signals::Signal;
use pyo3::prelude::*;
use pyo3::types::{PyBool, PyFloat, PyInt, PyType};
use crate::{DType, PyArgusError};
use crate::PyArgusError;
#[pyclass(name = "InterpolationMethod", module = "argus")]
#[derive(Debug, Clone, Copy, Default)]
@ -13,88 +14,63 @@ pub enum PyInterp {
Linear,
}
#[derive(Debug, Clone, derive_more::From, derive_more::TryInto)]
#[try_into(owned, ref, ref_mut)]
pub enum SignalKind {
Bool(Signal<bool>),
Int(Signal<i64>),
UnsignedInt(Signal<u64>),
Float(Signal<f64>),
}
#[pyclass(name = "Signal", subclass, module = "argus")]
#[derive(Debug, Clone)]
pub struct PySignal {
pub kind: DType,
pub interpolation: PyInterp,
pub signal: SignalKind,
}
macro_rules! impl_signals {
($ty_name:ident, $ty:ty) => {
paste::paste! {
#[pyclass(extends=PySignal, module = "argus")]
#[derive(Debug, Clone, derive_more::From)]
pub struct [<$ty_name Signal>](pub Signal<$ty>);
impl [<$ty_name Signal>] {
#[inline]
pub fn super_type() -> PySignal {
PySignal {
interpolation: PyInterp::Linear,
kind: DType::$ty_name,
}
#[pymethods]
impl PySignal {
#[getter]
fn kind<'py>(&self, py: Python<'py>) -> &'py PyType {
match self.signal {
SignalKind::Bool(_) => PyType::new::<PyBool>(py),
SignalKind::Int(_) | SignalKind::UnsignedInt(_) => PyType::new::<PyInt>(py),
SignalKind::Float(_) => PyType::new::<PyFloat>(py),
}
}
#[pymethods]
impl [<$ty_name Signal>] {
fn __repr__(&self) -> String {
format!("Signal::<{}>::{:?}", stringify!($ty), self.0)
match &self.signal {
SignalKind::Bool(sig) => format!("Signal::<{}>::{:?}", "bool", sig),
SignalKind::Int(sig) => format!("Signal::<{}>::{:?}", "i64", sig),
SignalKind::UnsignedInt(sig) => format!("Signal::<{}>::{:?}", "u64", sig),
SignalKind::Float(sig) => format!("Signal::<{}>::{:?}", "f64", sig),
}
/// Create a new empty signal
#[new]
#[pyo3(signature = ())]
fn new() -> (Self, PySignal) {
(Self(Signal::new()), Self::super_type())
}
fn __init__(self_: PyRef<'_, Self>) -> PyRef<'_, Self> {
self_
}
/// Create a new signal with constant value
#[staticmethod]
fn constant(py: Python<'_>, value: $ty) -> PyResult<Py<Self>> {
Py::new(
py,
(Self(Signal::constant(value)), Self::super_type())
)
}
/// Create a new signal from some finite number of samples
#[staticmethod]
fn from_samples(samples: Vec<(f64, $ty)>) -> PyResult<Py<Self>> {
let ret: Signal<$ty> = samples
.into_iter()
.map(|(t, v)| (Duration::from_secs_f64(t), v))
.collect();
Python::with_gil(|py| {
Py::new(
py,
(Self(ret), Self::super_type())
)
})
}
/// Push a new sample into the given signal.
#[pyo3(signature = (time, value))]
fn push(&mut self, time: f64, value: $ty) -> Result<(), PyArgusError> {
self.0.push(Duration::from_secs_f64(time), value)?;
Ok(())
}
/// Check if the signal is empty
fn is_empty(&self) -> bool {
self.0.is_empty()
match &self.signal {
SignalKind::Bool(sig) => sig.is_empty(),
SignalKind::Int(sig) => sig.is_empty(),
SignalKind::UnsignedInt(sig) => sig.is_empty(),
SignalKind::Float(sig) => sig.is_empty(),
}
}
/// The start time of the signal
#[getter]
fn start_time(&self) -> Option<f64> {
use core::ops::Bound::*;
match self.0.start_time()? {
let start_time = match &self.signal {
SignalKind::Bool(sig) => sig.start_time()?,
SignalKind::Int(sig) => sig.start_time()?,
SignalKind::UnsignedInt(sig) => sig.start_time()?,
SignalKind::Float(sig) => sig.start_time()?,
};
match start_time {
Included(t) | Excluded(t) => Some(t.as_secs_f64()),
_ => None,
}
@ -104,11 +80,82 @@ macro_rules! impl_signals {
#[getter]
fn end_time(&self) -> Option<f64> {
use core::ops::Bound::*;
match self.0.end_time()? {
let end_time = match &self.signal {
SignalKind::Bool(sig) => sig.end_time()?,
SignalKind::Int(sig) => sig.end_time()?,
SignalKind::UnsignedInt(sig) => sig.end_time()?,
SignalKind::Float(sig) => sig.end_time()?,
};
match end_time {
Included(t) | Excluded(t) => Some(t.as_secs_f64()),
_ => None,
}
}
}
macro_rules! impl_signals {
($ty_name:ident, $ty:ty) => {
paste::paste! {
#[pyclass(extends=PySignal, module = "argus")]
#[derive(Debug, Copy, Clone)]
pub struct [<$ty_name Signal>];
impl [<$ty_name Signal>] {
#[inline]
pub fn super_type(signal: SignalKind) -> PySignal {
PySignal {
interpolation: PyInterp::Linear,
signal,
}
}
}
#[pymethods]
impl [<$ty_name Signal>] {
/// Create a new empty signal
#[new]
#[pyo3(signature = ())]
fn new() -> (Self, PySignal) {
(Self, Self::super_type(Signal::<$ty>::new().into()))
}
#[pyo3(signature = ())]
fn __init__(self_: PyRef<'_, Self>) -> PyRef<'_, Self> {
self_
}
/// Create a new signal with constant value
#[classmethod]
fn constant(_: &PyType, py: Python<'_>, value: $ty) -> PyResult<Py<Self>> {
Py::new(
py,
(Self, Self::super_type(Signal::constant(value).into()))
)
}
/// Create a new signal from some finite number of samples
#[classmethod]
fn from_samples(_: &PyType, samples: Vec<(f64, $ty)>) -> PyResult<Py<Self>> {
let ret: Signal<$ty> = samples
.into_iter()
.map(|(t, v)| (Duration::from_secs_f64(t), v))
.collect();
Python::with_gil(|py| {
Py::new(
py,
(Self, Self::super_type(ret.into()))
)
})
}
/// Push a new sample into the given signal.
#[pyo3(signature = (time, value))]
fn push(mut self_: PyRefMut<'_, Self>, time: f64, value: $ty) -> Result<(), PyArgusError> {
let super_: &mut PySignal = self_.as_mut();
let signal: &mut Signal<$ty> = (&mut super_.signal).try_into().unwrap();
signal.push(Duration::from_secs_f64(time), value)?;
Ok(())
}
/// Get the value of the signal at the given time point.
///
@ -117,9 +164,10 @@ macro_rules! impl_signals {
/// is returned.
fn at(self_: PyRef<'_, Self>, time: f64) -> Option<$ty> {
let super_ = self_.as_ref();
let signal: &Signal<$ty> = (&super_.signal).try_into().unwrap();
let time = core::time::Duration::from_secs_f64(time);
match super_.interpolation {
PyInterp::Linear => self_.0.interpolate_at::<Linear>(time),
PyInterp::Linear => signal.interpolate_at::<Linear>(time),
}
}

View file

@ -66,5 +66,5 @@ def test_correctly_create_signals(data: Tuple[List[Tuple[float, AllowedDtype]],
assert a < len(samples)
assert b < len(samples)
else:
assert signal.is_empty() # type: ignore[attr-defined]
assert signal.at(0) is None # type: ignore[attr-defined]
assert signal.is_empty()
assert signal.at(0) is None