refactor(pyargus): data type name

This commit is contained in:
Anand Balakrishnan 2023-09-07 15:43:04 -07:00
parent 8093ab7c9f
commit e2cfe3da56
No known key found for this signature in database
5 changed files with 85 additions and 81 deletions

View file

@ -3,7 +3,7 @@ from __future__ import annotations
from typing import List, Optional, Tuple, Type, Union from typing import List, Optional, Tuple, Type, Union
from argus import _argus from argus import _argus
from argus._argus import DType as DType from argus._argus import dtype
from argus.exprs import ConstBool, ConstFloat, ConstInt, ConstUInt, VarBool, VarFloat, VarInt, VarUInt from argus.exprs import ConstBool, ConstFloat, ConstInt, ConstUInt, VarBool, VarFloat, VarInt, VarUInt
from argus.signals import BoolSignal, FloatSignal, IntSignal, Signal, UnsignedIntSignal from argus.signals import BoolSignal, FloatSignal, IntSignal, Signal, UnsignedIntSignal
@ -15,25 +15,19 @@ except AttributeError:
AllowedDtype = Union[bool, int, float] AllowedDtype = Union[bool, int, float]
def declare_var(name: str, dtype: Union[DType, Type[AllowedDtype]]) -> Union[VarBool, VarInt, VarUInt, VarFloat]: def declare_var(name: str, dtype_: Union[dtype, Type[AllowedDtype]]) -> Union[VarBool, VarInt, VarUInt, VarFloat]:
"""Declare a variable with the given name and type""" """Declare a variable with the given name and type"""
if isinstance(dtype, type): dtype_ = dtype.convert(dtype_)
if dtype == bool:
dtype = DType.Bool
elif dtype == int:
dtype = DType.Int
elif dtype == float:
dtype = DType.Float
if dtype == DType.Bool: if dtype_ == dtype.bool_:
return VarBool(name) return VarBool(name)
elif dtype == DType.Int: elif dtype_ == dtype.int64:
return VarInt(name) return VarInt(name)
elif dtype == DType.UnsignedInt: elif dtype_ == dtype.uint64:
return VarUInt(name) return VarUInt(name)
elif dtype == DType.Float: elif dtype_ == dtype.float64:
return VarFloat(name) return VarFloat(name)
raise TypeError(f"unsupported variable type `{dtype}`") raise TypeError(f"unsupported variable type `{dtype_}`")
def literal(value: AllowedDtype) -> Union[ConstBool, ConstInt, ConstUInt, ConstFloat]: def literal(value: AllowedDtype) -> Union[ConstBool, ConstInt, ConstUInt, ConstFloat]:
@ -48,7 +42,7 @@ def literal(value: AllowedDtype) -> Union[ConstBool, ConstInt, ConstUInt, ConstF
def signal( def signal(
dtype: Union[DType, Type[AllowedDtype]], dtype_: Union[dtype, Type[AllowedDtype]],
*, *,
data: Optional[Union[AllowedDtype, List[Tuple[float, AllowedDtype]]]] = None, data: Optional[Union[AllowedDtype, List[Tuple[float, AllowedDtype]]]] = None,
) -> Union[BoolSignal, UnsignedIntSignal, IntSignal, FloatSignal]: ) -> Union[BoolSignal, UnsignedIntSignal, IntSignal, FloatSignal]:
@ -57,7 +51,7 @@ def signal(
Parameters Parameters
---------- ----------
dtype: dtype_:
Type of the signal Type of the signal
data : data :
@ -67,21 +61,21 @@ def signal(
factory: Type[Union[BoolSignal, UnsignedIntSignal, IntSignal, FloatSignal]] factory: Type[Union[BoolSignal, UnsignedIntSignal, IntSignal, FloatSignal]]
expected_type: Type[AllowedDtype] expected_type: Type[AllowedDtype]
dtype = DType.convert(dtype) dtype_ = dtype.convert(dtype_)
if dtype == DType.Bool: if dtype_ == dtype.bool_:
factory = BoolSignal factory = BoolSignal
expected_type = bool expected_type = bool
elif dtype == DType.UnsignedInt: elif dtype_ == dtype.uint64:
factory = UnsignedIntSignal factory = UnsignedIntSignal
expected_type = int expected_type = int
elif dtype == DType.Int: elif dtype_ == dtype.int64:
factory = IntSignal factory = IntSignal
expected_type = int expected_type = int
elif dtype == DType.Float: elif dtype_ == dtype.float64:
factory = FloatSignal factory = FloatSignal
expected_type = float expected_type = float
else: else:
raise ValueError(f"unsupported dtype {dtype}") raise ValueError(f"unsupported dtype_ {dtype}")
if data is None: if data is None:
return factory.from_samples([]) return factory.from_samples([])
@ -92,7 +86,7 @@ def signal(
__all__ = [ __all__ = [
"DType", "dtype",
"declare_var", "declare_var",
"literal", "literal",
"signal", "signal",

View file

@ -123,11 +123,11 @@ class Until(BoolExpr):
def __init__(self, lhs: BoolExpr, rhs: BoolExpr) -> None: ... def __init__(self, lhs: BoolExpr, rhs: BoolExpr) -> None: ...
@final @final
class DType: class dtype: # noqa: N801
Bool: ClassVar[DType] = ... bool_: ClassVar[dtype] = ...
Float: ClassVar[DType] = ... float64: ClassVar[dtype] = ...
Int: ClassVar[DType] = ... int64: ClassVar[dtype] = ...
UnsignedInt: ClassVar[DType] = ... uint64: ClassVar[dtype] = ...
@classmethod @classmethod
def convert(cls, dtype: type[bool | int | float] | Self) -> Self: ... # noqa: Y041 def convert(cls, dtype: type[bool | int | float] | Self) -> Self: ... # noqa: Y041
@ -143,7 +143,7 @@ class Signal(Generic[_SignalKind], Protocol):
@property @property
def end_time(self) -> float | None: ... def end_time(self) -> float | None: ...
@property @property
def kind(self) -> type[bool | int | float]: ... def kind(self) -> dtype: ...
@final @final
class BoolSignal(Signal[bool]): class BoolSignal(Signal[bool]):

View file

@ -28,12 +28,16 @@ impl From<PyArgusError> for PyErr {
} }
} }
#[pyclass(module = "argus")] #[pyclass(module = "argus", name = "dtype")]
#[derive(Copy, Clone, Debug)] #[derive(Copy, Clone, Debug)]
pub enum DType { pub enum DType {
#[pyo3(name = "bool_")]
Bool, Bool,
#[pyo3(name = "int64")]
Int, Int,
#[pyo3(name = "uint64")]
UnsignedInt, UnsignedInt,
#[pyo3(name = "float64")]
Float, Float,
} }

View file

@ -1,9 +1,9 @@
use argus_core::signals::interpolation::Linear; use argus_core::signals::interpolation::Linear;
use argus_core::signals::Signal; use argus_core::signals::Signal;
use pyo3::prelude::*; use pyo3::prelude::*;
use pyo3::types::{PyBool, PyFloat, PyInt, PyType}; use pyo3::types::PyType;
use crate::PyArgusError; use crate::{DType, PyArgusError};
#[pyclass(name = "InterpolationMethod", module = "argus")] #[pyclass(name = "InterpolationMethod", module = "argus")]
#[derive(Debug, Clone, Copy, Default)] #[derive(Debug, Clone, Copy, Default)]
@ -21,22 +21,30 @@ pub enum SignalKind {
Float(Signal<f64>), Float(Signal<f64>),
} }
impl SignalKind {
/// Get the kind of the signal
pub fn kind(&self) -> DType {
match self {
SignalKind::Bool(_) => DType::Bool,
SignalKind::Int(_) => DType::Int,
SignalKind::UnsignedInt(_) => DType::UnsignedInt,
SignalKind::Float(_) => DType::Float,
}
}
}
#[pyclass(name = "Signal", subclass, module = "argus")] #[pyclass(name = "Signal", subclass, module = "argus")]
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct PySignal { pub struct PySignal {
pub interpolation: PyInterp, pub(crate) interpolation: PyInterp,
pub signal: SignalKind, pub(crate) signal: SignalKind,
} }
#[pymethods] #[pymethods]
impl PySignal { impl PySignal {
#[getter] #[getter]
fn kind<'py>(&self, py: Python<'py>) -> &'py PyType { fn kind(&self) -> DType {
match self.signal { self.signal.kind()
SignalKind::Bool(_) => PyType::new::<PyBool>(py),
SignalKind::Int(_) | SignalKind::UnsignedInt(_) => PyType::new::<PyInt>(py),
SignalKind::Float(_) => PyType::new::<PyFloat>(py),
}
} }
fn __repr__(&self) -> String { fn __repr__(&self) -> String {

View file

@ -6,22 +6,20 @@ from hypothesis import strategies as st
from hypothesis.strategies import SearchStrategy, composite from hypothesis.strategies import SearchStrategy, composite
import argus import argus
from argus import DType from argus import AllowedDtype, dtype
AllowedDtype = Union[bool, int, float]
def gen_element_fn(dtype: Union[Type[AllowedDtype], DType]) -> SearchStrategy[AllowedDtype]: def gen_element_fn(dtype_: Union[Type[AllowedDtype], dtype]) -> SearchStrategy[AllowedDtype]:
new_dtype = DType.convert(dtype) new_dtype = dtype.convert(dtype_)
if new_dtype == DType.Bool: if new_dtype == dtype.bool_:
return st.booleans() return st.booleans()
elif new_dtype == DType.Int: elif new_dtype == dtype.int64:
size = 2**64 size = 2**64
return st.integers(min_value=(-size // 2), max_value=((size - 1) // 2)) return st.integers(min_value=(-size // 2), max_value=((size - 1) // 2))
elif new_dtype == DType.UnsignedInt: elif new_dtype == dtype.uint64:
size = 2**64 size = 2**64
return st.integers(min_value=0, max_value=(size - 1)) return st.integers(min_value=0, max_value=(size - 1))
elif new_dtype == DType.Float: elif new_dtype == dtype.float64:
return st.floats( return st.floats(
width=64, width=64,
allow_nan=False, allow_nan=False,
@ -29,18 +27,18 @@ def gen_element_fn(dtype: Union[Type[AllowedDtype], DType]) -> SearchStrategy[Al
allow_subnormal=False, allow_subnormal=False,
) )
else: else:
raise ValueError(f"invalid dtype {dtype}") raise ValueError(f"invalid dtype {dtype_}")
@composite @composite
def gen_samples( def gen_samples(
draw: st.DrawFn, *, min_size: int, max_size: int, dtype: Union[Type[AllowedDtype], DType] draw: st.DrawFn, min_size: int, max_size: int, dtype_: Union[Type[AllowedDtype], dtype]
) -> List[Tuple[float, AllowedDtype]]: ) -> List[Tuple[float, AllowedDtype]]:
""" """
Generate arbitrary samples for a signal where the time stamps are strictly Generate arbitrary samples for a signal where the time stamps are strictly
monotonically increasing monotonically increasing
""" """
elements = gen_element_fn(dtype) elements = gen_element_fn(dtype_)
values = draw(st.lists(elements, min_size=min_size, max_size=max_size)) values = draw(st.lists(elements, min_size=min_size, max_size=max_size))
xs = draw( xs = draw(
st.lists( st.lists(
@ -55,28 +53,28 @@ def gen_samples(
return xs return xs
def empty_signal(*, dtype: Union[Type[AllowedDtype], DType]) -> SearchStrategy[argus.Signal]: def empty_signal(*, dtype_: Union[Type[AllowedDtype], dtype]) -> SearchStrategy[argus.Signal]:
new_dtype: DType = DType.convert(dtype) new_dtype: dtype = dtype.convert(dtype_)
sig: argus.Signal sig: argus.Signal
if new_dtype == DType.Bool: if new_dtype == dtype.bool_:
sig = argus.BoolSignal() sig = argus.BoolSignal()
assert sig.kind is bool assert sig.kind == dtype.bool_
elif new_dtype == DType.UnsignedInt: elif new_dtype == dtype.uint64:
sig = argus.UnsignedIntSignal() sig = argus.UnsignedIntSignal()
assert sig.kind is int assert sig.kind == dtype.uint64
elif new_dtype == DType.Int: elif new_dtype == dtype.int64:
sig = argus.IntSignal() sig = argus.IntSignal()
assert sig.kind is int assert sig.kind == dtype.int64
elif new_dtype == DType.Float: elif new_dtype == dtype.float64:
sig = argus.FloatSignal() sig = argus.FloatSignal()
assert sig.kind is float assert sig.kind == dtype.float64
else: else:
raise ValueError("unknown dtype") raise ValueError("unknown dtype")
return st.just(sig) return st.just(sig)
def constant_signal(dtype: Union[Type[AllowedDtype], DType]) -> SearchStrategy[argus.Signal]: def constant_signal(dtype_: Union[Type[AllowedDtype], dtype]) -> SearchStrategy[argus.Signal]:
return gen_element_fn(dtype).map(lambda val: argus.signal(dtype, data=val)) return gen_element_fn(dtype_).map(lambda val: argus.signal(dtype_, data=val))
@composite @composite
@ -87,16 +85,16 @@ def draw_index(draw: st.DrawFn, vec: List) -> int:
return draw(st.just(0)) return draw(st.just(0))
def gen_dtype() -> SearchStrategy[Union[Type[AllowedDtype], DType]]: def gen_dtype() -> SearchStrategy[Union[Type[AllowedDtype], dtype]]:
return st.one_of( return st.one_of(
list(map(st.just, [DType.Bool, DType.UnsignedInt, DType.Int, DType.Float, bool, int, float])), # type: ignore[arg-type] list(map(st.just, [dtype.bool_, dtype.uint64, dtype.int64, dtype.float64, bool, int, float])), # type: ignore[arg-type]
) )
@given(st.data()) @given(st.data())
def test_correct_constant_signals(data: st.DataObject) -> None: def test_correct_constant_signals(data: st.DataObject) -> None:
dtype = data.draw(gen_dtype()) dtype_ = data.draw(gen_dtype())
signal = data.draw(constant_signal(dtype)) signal = data.draw(constant_signal(dtype_))
assert not signal.is_empty() assert not signal.is_empty()
assert signal.start_time is None assert signal.start_time is None
@ -105,11 +103,11 @@ def test_correct_constant_signals(data: st.DataObject) -> None:
@given(st.data()) @given(st.data())
def test_correctly_create_signals(data: st.DataObject) -> None: def test_correctly_create_signals(data: st.DataObject) -> None:
dtype = data.draw(gen_dtype()) dtype_ = data.draw(gen_dtype())
xs = data.draw(gen_samples(min_size=0, max_size=100, dtype=dtype)) xs = data.draw(gen_samples(min_size=0, max_size=100, dtype_=dtype_))
note(f"Samples: {gen_samples}") note(f"Samples: {gen_samples}")
signal = argus.signal(dtype, data=xs) signal = argus.signal(dtype_, data=xs)
if len(xs) > 0: if len(xs) > 0:
expected_start_time = xs[0][0] expected_start_time = xs[0][0]
expected_end_time = xs[-1][0] expected_end_time = xs[-1][0]
@ -132,7 +130,7 @@ def test_correctly_create_signals(data: st.DataObject) -> None:
# generate one more sample # generate one more sample
new_time = actual_end_time + 1 new_time = actual_end_time + 1
new_value = data.draw(gen_element_fn(dtype)) new_value = data.draw(gen_element_fn(dtype_))
signal.push(new_time, new_value) # type: ignore[arg-type] signal.push(new_time, new_value) # type: ignore[arg-type]
get_val = signal.at(new_time) get_val = signal.at(new_time)
@ -148,8 +146,8 @@ def test_correctly_create_signals(data: st.DataObject) -> None:
@given(st.data()) @given(st.data())
def test_signal_create_should_fail(data: st.DataObject) -> None: def test_signal_create_should_fail(data: st.DataObject) -> None:
dtype = data.draw(gen_dtype()) dtype_ = data.draw(gen_dtype())
xs = data.draw(gen_samples(min_size=10, max_size=100, dtype=dtype)) xs = data.draw(gen_samples(min_size=10, max_size=100, dtype_=dtype_))
a = data.draw(draw_index(xs)) a = data.draw(draw_index(xs))
b = data.draw(draw_index(xs)) b = data.draw(draw_index(xs))
assume(a != b) assume(a != b)
@ -161,24 +159,24 @@ def test_signal_create_should_fail(data: st.DataObject) -> None:
xs[b], xs[a] = xs[a], xs[b] xs[b], xs[a] = xs[a], xs[b]
with pytest.raises(RuntimeError, match=r"trying to create a non-monotonically signal.+"): with pytest.raises(RuntimeError, match=r"trying to create a non-monotonically signal.+"):
_ = argus.signal(dtype, data=xs) _ = argus.signal(dtype_, data=xs)
@given(st.data()) @given(st.data())
def test_push_to_empty_signal(data: st.DataObject) -> None: def test_push_to_empty_signal(data: st.DataObject) -> None:
dtype = data.draw(gen_dtype()) dtype_ = data.draw(gen_dtype())
sig = data.draw(empty_signal(dtype=dtype)) sig = data.draw(empty_signal(dtype_=dtype_))
assert sig.is_empty() assert sig.is_empty()
element = data.draw(gen_element_fn(dtype)) element = data.draw(gen_element_fn(dtype_))
with pytest.raises(RuntimeError, match="cannot push value to non-sampled signal"): with pytest.raises(RuntimeError, match="cannot push value to non-sampled signal"):
sig.push(0.0, element) # type: ignore[attr-defined] sig.push(0.0, element) # type: ignore[attr-defined]
@given(st.data()) @given(st.data())
def test_push_to_constant_signal(data: st.DataObject) -> None: def test_push_to_constant_signal(data: st.DataObject) -> None:
dtype = data.draw(gen_dtype()) dtype_ = data.draw(gen_dtype())
sig = data.draw(constant_signal(dtype=dtype)) sig = data.draw(constant_signal(dtype_=dtype_))
assert not sig.is_empty() assert not sig.is_empty()
sample = data.draw(gen_samples(min_size=1, max_size=1, dtype=dtype))[0] sample = data.draw(gen_samples(min_size=1, max_size=1, dtype_=dtype_))[0]
with pytest.raises(RuntimeError, match="cannot push value to non-sampled signal"): with pytest.raises(RuntimeError, match="cannot push value to non-sampled signal"):
sig.push(*sample) # type: ignore[attr-defined] sig.push(*sample) # type: ignore[attr-defined]