refactor(pyargus): data type name
This commit is contained in:
parent
8093ab7c9f
commit
e2cfe3da56
5 changed files with 85 additions and 81 deletions
|
|
@ -3,7 +3,7 @@ from __future__ import annotations
|
|||
from typing import List, Optional, Tuple, Type, Union
|
||||
|
||||
from argus import _argus
|
||||
from argus._argus import DType as DType
|
||||
from argus._argus import dtype
|
||||
from argus.exprs import ConstBool, ConstFloat, ConstInt, ConstUInt, VarBool, VarFloat, VarInt, VarUInt
|
||||
from argus.signals import BoolSignal, FloatSignal, IntSignal, Signal, UnsignedIntSignal
|
||||
|
||||
|
|
@ -15,25 +15,19 @@ except AttributeError:
|
|||
AllowedDtype = Union[bool, int, float]
|
||||
|
||||
|
||||
def declare_var(name: str, dtype: Union[DType, Type[AllowedDtype]]) -> Union[VarBool, VarInt, VarUInt, VarFloat]:
|
||||
def declare_var(name: str, dtype_: Union[dtype, Type[AllowedDtype]]) -> Union[VarBool, VarInt, VarUInt, VarFloat]:
|
||||
"""Declare a variable with the given name and type"""
|
||||
if isinstance(dtype, type):
|
||||
if dtype == bool:
|
||||
dtype = DType.Bool
|
||||
elif dtype == int:
|
||||
dtype = DType.Int
|
||||
elif dtype == float:
|
||||
dtype = DType.Float
|
||||
dtype_ = dtype.convert(dtype_)
|
||||
|
||||
if dtype == DType.Bool:
|
||||
if dtype_ == dtype.bool_:
|
||||
return VarBool(name)
|
||||
elif dtype == DType.Int:
|
||||
elif dtype_ == dtype.int64:
|
||||
return VarInt(name)
|
||||
elif dtype == DType.UnsignedInt:
|
||||
elif dtype_ == dtype.uint64:
|
||||
return VarUInt(name)
|
||||
elif dtype == DType.Float:
|
||||
elif dtype_ == dtype.float64:
|
||||
return VarFloat(name)
|
||||
raise TypeError(f"unsupported variable type `{dtype}`")
|
||||
raise TypeError(f"unsupported variable type `{dtype_}`")
|
||||
|
||||
|
||||
def literal(value: AllowedDtype) -> Union[ConstBool, ConstInt, ConstUInt, ConstFloat]:
|
||||
|
|
@ -48,7 +42,7 @@ def literal(value: AllowedDtype) -> Union[ConstBool, ConstInt, ConstUInt, ConstF
|
|||
|
||||
|
||||
def signal(
|
||||
dtype: Union[DType, Type[AllowedDtype]],
|
||||
dtype_: Union[dtype, Type[AllowedDtype]],
|
||||
*,
|
||||
data: Optional[Union[AllowedDtype, List[Tuple[float, AllowedDtype]]]] = None,
|
||||
) -> Union[BoolSignal, UnsignedIntSignal, IntSignal, FloatSignal]:
|
||||
|
|
@ -57,7 +51,7 @@ def signal(
|
|||
Parameters
|
||||
----------
|
||||
|
||||
dtype:
|
||||
dtype_:
|
||||
Type of the signal
|
||||
|
||||
data :
|
||||
|
|
@ -67,21 +61,21 @@ def signal(
|
|||
factory: Type[Union[BoolSignal, UnsignedIntSignal, IntSignal, FloatSignal]]
|
||||
expected_type: Type[AllowedDtype]
|
||||
|
||||
dtype = DType.convert(dtype)
|
||||
if dtype == DType.Bool:
|
||||
dtype_ = dtype.convert(dtype_)
|
||||
if dtype_ == dtype.bool_:
|
||||
factory = BoolSignal
|
||||
expected_type = bool
|
||||
elif dtype == DType.UnsignedInt:
|
||||
elif dtype_ == dtype.uint64:
|
||||
factory = UnsignedIntSignal
|
||||
expected_type = int
|
||||
elif dtype == DType.Int:
|
||||
elif dtype_ == dtype.int64:
|
||||
factory = IntSignal
|
||||
expected_type = int
|
||||
elif dtype == DType.Float:
|
||||
elif dtype_ == dtype.float64:
|
||||
factory = FloatSignal
|
||||
expected_type = float
|
||||
else:
|
||||
raise ValueError(f"unsupported dtype {dtype}")
|
||||
raise ValueError(f"unsupported dtype_ {dtype}")
|
||||
|
||||
if data is None:
|
||||
return factory.from_samples([])
|
||||
|
|
@ -92,7 +86,7 @@ def signal(
|
|||
|
||||
|
||||
__all__ = [
|
||||
"DType",
|
||||
"dtype",
|
||||
"declare_var",
|
||||
"literal",
|
||||
"signal",
|
||||
|
|
|
|||
|
|
@ -123,11 +123,11 @@ class Until(BoolExpr):
|
|||
def __init__(self, lhs: BoolExpr, rhs: BoolExpr) -> None: ...
|
||||
|
||||
@final
|
||||
class DType:
|
||||
Bool: ClassVar[DType] = ...
|
||||
Float: ClassVar[DType] = ...
|
||||
Int: ClassVar[DType] = ...
|
||||
UnsignedInt: ClassVar[DType] = ...
|
||||
class dtype: # noqa: N801
|
||||
bool_: ClassVar[dtype] = ...
|
||||
float64: ClassVar[dtype] = ...
|
||||
int64: ClassVar[dtype] = ...
|
||||
uint64: ClassVar[dtype] = ...
|
||||
|
||||
@classmethod
|
||||
def convert(cls, dtype: type[bool | int | float] | Self) -> Self: ... # noqa: Y041
|
||||
|
|
@ -143,7 +143,7 @@ class Signal(Generic[_SignalKind], Protocol):
|
|||
@property
|
||||
def end_time(self) -> float | None: ...
|
||||
@property
|
||||
def kind(self) -> type[bool | int | float]: ...
|
||||
def kind(self) -> dtype: ...
|
||||
|
||||
@final
|
||||
class BoolSignal(Signal[bool]):
|
||||
|
|
|
|||
|
|
@ -28,12 +28,16 @@ impl From<PyArgusError> for PyErr {
|
|||
}
|
||||
}
|
||||
|
||||
#[pyclass(module = "argus")]
|
||||
#[pyclass(module = "argus", name = "dtype")]
|
||||
#[derive(Copy, Clone, Debug)]
|
||||
pub enum DType {
|
||||
#[pyo3(name = "bool_")]
|
||||
Bool,
|
||||
#[pyo3(name = "int64")]
|
||||
Int,
|
||||
#[pyo3(name = "uint64")]
|
||||
UnsignedInt,
|
||||
#[pyo3(name = "float64")]
|
||||
Float,
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -1,9 +1,9 @@
|
|||
use argus_core::signals::interpolation::Linear;
|
||||
use argus_core::signals::Signal;
|
||||
use pyo3::prelude::*;
|
||||
use pyo3::types::{PyBool, PyFloat, PyInt, PyType};
|
||||
use pyo3::types::PyType;
|
||||
|
||||
use crate::PyArgusError;
|
||||
use crate::{DType, PyArgusError};
|
||||
|
||||
#[pyclass(name = "InterpolationMethod", module = "argus")]
|
||||
#[derive(Debug, Clone, Copy, Default)]
|
||||
|
|
@ -21,22 +21,30 @@ pub enum SignalKind {
|
|||
Float(Signal<f64>),
|
||||
}
|
||||
|
||||
impl SignalKind {
|
||||
/// Get the kind of the signal
|
||||
pub fn kind(&self) -> DType {
|
||||
match self {
|
||||
SignalKind::Bool(_) => DType::Bool,
|
||||
SignalKind::Int(_) => DType::Int,
|
||||
SignalKind::UnsignedInt(_) => DType::UnsignedInt,
|
||||
SignalKind::Float(_) => DType::Float,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[pyclass(name = "Signal", subclass, module = "argus")]
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct PySignal {
|
||||
pub interpolation: PyInterp,
|
||||
pub signal: SignalKind,
|
||||
pub(crate) interpolation: PyInterp,
|
||||
pub(crate) signal: SignalKind,
|
||||
}
|
||||
|
||||
#[pymethods]
|
||||
impl PySignal {
|
||||
#[getter]
|
||||
fn kind<'py>(&self, py: Python<'py>) -> &'py PyType {
|
||||
match self.signal {
|
||||
SignalKind::Bool(_) => PyType::new::<PyBool>(py),
|
||||
SignalKind::Int(_) | SignalKind::UnsignedInt(_) => PyType::new::<PyInt>(py),
|
||||
SignalKind::Float(_) => PyType::new::<PyFloat>(py),
|
||||
}
|
||||
fn kind(&self) -> DType {
|
||||
self.signal.kind()
|
||||
}
|
||||
|
||||
fn __repr__(&self) -> String {
|
||||
|
|
|
|||
|
|
@ -6,22 +6,20 @@ from hypothesis import strategies as st
|
|||
from hypothesis.strategies import SearchStrategy, composite
|
||||
|
||||
import argus
|
||||
from argus import DType
|
||||
|
||||
AllowedDtype = Union[bool, int, float]
|
||||
from argus import AllowedDtype, dtype
|
||||
|
||||
|
||||
def gen_element_fn(dtype: Union[Type[AllowedDtype], DType]) -> SearchStrategy[AllowedDtype]:
|
||||
new_dtype = DType.convert(dtype)
|
||||
if new_dtype == DType.Bool:
|
||||
def gen_element_fn(dtype_: Union[Type[AllowedDtype], dtype]) -> SearchStrategy[AllowedDtype]:
|
||||
new_dtype = dtype.convert(dtype_)
|
||||
if new_dtype == dtype.bool_:
|
||||
return st.booleans()
|
||||
elif new_dtype == DType.Int:
|
||||
elif new_dtype == dtype.int64:
|
||||
size = 2**64
|
||||
return st.integers(min_value=(-size // 2), max_value=((size - 1) // 2))
|
||||
elif new_dtype == DType.UnsignedInt:
|
||||
elif new_dtype == dtype.uint64:
|
||||
size = 2**64
|
||||
return st.integers(min_value=0, max_value=(size - 1))
|
||||
elif new_dtype == DType.Float:
|
||||
elif new_dtype == dtype.float64:
|
||||
return st.floats(
|
||||
width=64,
|
||||
allow_nan=False,
|
||||
|
|
@ -29,18 +27,18 @@ def gen_element_fn(dtype: Union[Type[AllowedDtype], DType]) -> SearchStrategy[Al
|
|||
allow_subnormal=False,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"invalid dtype {dtype}")
|
||||
raise ValueError(f"invalid dtype {dtype_}")
|
||||
|
||||
|
||||
@composite
|
||||
def gen_samples(
|
||||
draw: st.DrawFn, *, min_size: int, max_size: int, dtype: Union[Type[AllowedDtype], DType]
|
||||
draw: st.DrawFn, min_size: int, max_size: int, dtype_: Union[Type[AllowedDtype], dtype]
|
||||
) -> List[Tuple[float, AllowedDtype]]:
|
||||
"""
|
||||
Generate arbitrary samples for a signal where the time stamps are strictly
|
||||
monotonically increasing
|
||||
"""
|
||||
elements = gen_element_fn(dtype)
|
||||
elements = gen_element_fn(dtype_)
|
||||
values = draw(st.lists(elements, min_size=min_size, max_size=max_size))
|
||||
xs = draw(
|
||||
st.lists(
|
||||
|
|
@ -55,28 +53,28 @@ def gen_samples(
|
|||
return xs
|
||||
|
||||
|
||||
def empty_signal(*, dtype: Union[Type[AllowedDtype], DType]) -> SearchStrategy[argus.Signal]:
|
||||
new_dtype: DType = DType.convert(dtype)
|
||||
def empty_signal(*, dtype_: Union[Type[AllowedDtype], dtype]) -> SearchStrategy[argus.Signal]:
|
||||
new_dtype: dtype = dtype.convert(dtype_)
|
||||
sig: argus.Signal
|
||||
if new_dtype == DType.Bool:
|
||||
if new_dtype == dtype.bool_:
|
||||
sig = argus.BoolSignal()
|
||||
assert sig.kind is bool
|
||||
elif new_dtype == DType.UnsignedInt:
|
||||
assert sig.kind == dtype.bool_
|
||||
elif new_dtype == dtype.uint64:
|
||||
sig = argus.UnsignedIntSignal()
|
||||
assert sig.kind is int
|
||||
elif new_dtype == DType.Int:
|
||||
assert sig.kind == dtype.uint64
|
||||
elif new_dtype == dtype.int64:
|
||||
sig = argus.IntSignal()
|
||||
assert sig.kind is int
|
||||
elif new_dtype == DType.Float:
|
||||
assert sig.kind == dtype.int64
|
||||
elif new_dtype == dtype.float64:
|
||||
sig = argus.FloatSignal()
|
||||
assert sig.kind is float
|
||||
assert sig.kind == dtype.float64
|
||||
else:
|
||||
raise ValueError("unknown dtype")
|
||||
return st.just(sig)
|
||||
|
||||
|
||||
def constant_signal(dtype: Union[Type[AllowedDtype], DType]) -> SearchStrategy[argus.Signal]:
|
||||
return gen_element_fn(dtype).map(lambda val: argus.signal(dtype, data=val))
|
||||
def constant_signal(dtype_: Union[Type[AllowedDtype], dtype]) -> SearchStrategy[argus.Signal]:
|
||||
return gen_element_fn(dtype_).map(lambda val: argus.signal(dtype_, data=val))
|
||||
|
||||
|
||||
@composite
|
||||
|
|
@ -87,16 +85,16 @@ def draw_index(draw: st.DrawFn, vec: List) -> int:
|
|||
return draw(st.just(0))
|
||||
|
||||
|
||||
def gen_dtype() -> SearchStrategy[Union[Type[AllowedDtype], DType]]:
|
||||
def gen_dtype() -> SearchStrategy[Union[Type[AllowedDtype], dtype]]:
|
||||
return st.one_of(
|
||||
list(map(st.just, [DType.Bool, DType.UnsignedInt, DType.Int, DType.Float, bool, int, float])), # type: ignore[arg-type]
|
||||
list(map(st.just, [dtype.bool_, dtype.uint64, dtype.int64, dtype.float64, bool, int, float])), # type: ignore[arg-type]
|
||||
)
|
||||
|
||||
|
||||
@given(st.data())
|
||||
def test_correct_constant_signals(data: st.DataObject) -> None:
|
||||
dtype = data.draw(gen_dtype())
|
||||
signal = data.draw(constant_signal(dtype))
|
||||
dtype_ = data.draw(gen_dtype())
|
||||
signal = data.draw(constant_signal(dtype_))
|
||||
|
||||
assert not signal.is_empty()
|
||||
assert signal.start_time is None
|
||||
|
|
@ -105,11 +103,11 @@ def test_correct_constant_signals(data: st.DataObject) -> None:
|
|||
|
||||
@given(st.data())
|
||||
def test_correctly_create_signals(data: st.DataObject) -> None:
|
||||
dtype = data.draw(gen_dtype())
|
||||
xs = data.draw(gen_samples(min_size=0, max_size=100, dtype=dtype))
|
||||
dtype_ = data.draw(gen_dtype())
|
||||
xs = data.draw(gen_samples(min_size=0, max_size=100, dtype_=dtype_))
|
||||
|
||||
note(f"Samples: {gen_samples}")
|
||||
signal = argus.signal(dtype, data=xs)
|
||||
signal = argus.signal(dtype_, data=xs)
|
||||
if len(xs) > 0:
|
||||
expected_start_time = xs[0][0]
|
||||
expected_end_time = xs[-1][0]
|
||||
|
|
@ -132,7 +130,7 @@ def test_correctly_create_signals(data: st.DataObject) -> None:
|
|||
|
||||
# generate one more sample
|
||||
new_time = actual_end_time + 1
|
||||
new_value = data.draw(gen_element_fn(dtype))
|
||||
new_value = data.draw(gen_element_fn(dtype_))
|
||||
signal.push(new_time, new_value) # type: ignore[arg-type]
|
||||
|
||||
get_val = signal.at(new_time)
|
||||
|
|
@ -148,8 +146,8 @@ def test_correctly_create_signals(data: st.DataObject) -> None:
|
|||
|
||||
@given(st.data())
|
||||
def test_signal_create_should_fail(data: st.DataObject) -> None:
|
||||
dtype = data.draw(gen_dtype())
|
||||
xs = data.draw(gen_samples(min_size=10, max_size=100, dtype=dtype))
|
||||
dtype_ = data.draw(gen_dtype())
|
||||
xs = data.draw(gen_samples(min_size=10, max_size=100, dtype_=dtype_))
|
||||
a = data.draw(draw_index(xs))
|
||||
b = data.draw(draw_index(xs))
|
||||
assume(a != b)
|
||||
|
|
@ -161,24 +159,24 @@ def test_signal_create_should_fail(data: st.DataObject) -> None:
|
|||
xs[b], xs[a] = xs[a], xs[b]
|
||||
|
||||
with pytest.raises(RuntimeError, match=r"trying to create a non-monotonically signal.+"):
|
||||
_ = argus.signal(dtype, data=xs)
|
||||
_ = argus.signal(dtype_, data=xs)
|
||||
|
||||
|
||||
@given(st.data())
|
||||
def test_push_to_empty_signal(data: st.DataObject) -> None:
|
||||
dtype = data.draw(gen_dtype())
|
||||
sig = data.draw(empty_signal(dtype=dtype))
|
||||
dtype_ = data.draw(gen_dtype())
|
||||
sig = data.draw(empty_signal(dtype_=dtype_))
|
||||
assert sig.is_empty()
|
||||
element = data.draw(gen_element_fn(dtype))
|
||||
element = data.draw(gen_element_fn(dtype_))
|
||||
with pytest.raises(RuntimeError, match="cannot push value to non-sampled signal"):
|
||||
sig.push(0.0, element) # type: ignore[attr-defined]
|
||||
|
||||
|
||||
@given(st.data())
|
||||
def test_push_to_constant_signal(data: st.DataObject) -> None:
|
||||
dtype = data.draw(gen_dtype())
|
||||
sig = data.draw(constant_signal(dtype=dtype))
|
||||
dtype_ = data.draw(gen_dtype())
|
||||
sig = data.draw(constant_signal(dtype_=dtype_))
|
||||
assert not sig.is_empty()
|
||||
sample = data.draw(gen_samples(min_size=1, max_size=1, dtype=dtype))[0]
|
||||
sample = data.draw(gen_samples(min_size=1, max_size=1, dtype_=dtype_))[0]
|
||||
with pytest.raises(RuntimeError, match="cannot push value to non-sampled signal"):
|
||||
sig.push(*sample) # type: ignore[attr-defined]
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue