test(pyargus): improve test coverage

This commit is contained in:
Anand Balakrishnan 2023-09-07 13:32:27 -07:00
parent 7129177ca0
commit 4942a78899
No known key found for this signature in database
5 changed files with 111 additions and 47 deletions

View file

@ -5,7 +5,7 @@ from typing import List, Optional, Tuple, Type, Union
from argus import _argus from argus import _argus
from argus._argus import DType as DType from argus._argus import DType as DType
from argus.exprs import ConstBool, ConstFloat, ConstInt, ConstUInt, VarBool, VarFloat, VarInt, VarUInt from argus.exprs import ConstBool, ConstFloat, ConstInt, ConstUInt, VarBool, VarFloat, VarInt, VarUInt
from argus.signals import BoolSignal, FloatSignal, IntSignal, UnsignedIntSignal from argus.signals import BoolSignal, FloatSignal, IntSignal, Signal, UnsignedIntSignal
try: try:
__doc__ = _argus.__doc__ __doc__ = _argus.__doc__
@ -96,4 +96,5 @@ __all__ = [
"declare_var", "declare_var",
"literal", "literal",
"signal", "signal",
"Signal",
] ]

View file

@ -1,4 +1,4 @@
from typing import ClassVar, Protocol, TypeVar, final from typing import ClassVar, Generic, Protocol, TypeVar, final
from typing_extensions import Self from typing_extensions import Self
@ -136,7 +136,7 @@ class DType:
_SignalKind = TypeVar("_SignalKind", bool, int, float, covariant=True) _SignalKind = TypeVar("_SignalKind", bool, int, float, covariant=True)
class Signal(Protocol[_SignalKind]): class Signal(Generic[_SignalKind], Protocol):
def is_empty(self) -> bool: ... def is_empty(self) -> bool: ...
@property @property
def start_time(self) -> float | None: ... def start_time(self) -> float | None: ...

View file

@ -43,6 +43,7 @@ testpaths = ["tests"]
[tool.mypy] [tool.mypy]
# ignore_missing_imports = true # ignore_missing_imports = true
show_error_codes = true show_error_codes = true
plugins = ["numpy.typing.mypy_plugin"]
[tool.ruff] [tool.ruff]
line-length = 127 line-length = 127

View file

@ -1,5 +1,3 @@
use std::time::Duration;
use argus_core::signals::interpolation::Linear; use argus_core::signals::interpolation::Linear;
use argus_core::signals::Signal; use argus_core::signals::Signal;
use pyo3::prelude::*; use pyo3::prelude::*;
@ -119,11 +117,6 @@ macro_rules! impl_signals {
(Self, Self::super_type(Signal::<$ty>::new().into())) (Self, Self::super_type(Signal::<$ty>::new().into()))
} }
#[pyo3(signature = ())]
fn __init__(self_: PyRef<'_, Self>) -> PyRef<'_, Self> {
self_
}
/// Create a new signal with constant value /// Create a new signal with constant value
#[classmethod] #[classmethod]
fn constant(_: &PyType, py: Python<'_>, value: $ty) -> PyResult<Py<Self>> { fn constant(_: &PyType, py: Python<'_>, value: $ty) -> PyResult<Py<Self>> {
@ -138,7 +131,7 @@ macro_rules! impl_signals {
fn from_samples(_: &PyType, samples: Vec<(f64, $ty)>) -> PyResult<Py<Self>> { fn from_samples(_: &PyType, samples: Vec<(f64, $ty)>) -> PyResult<Py<Self>> {
let ret: Signal::<$ty> = Signal::<$ty>::try_from_iter(samples let ret: Signal::<$ty> = Signal::<$ty>::try_from_iter(samples
.into_iter() .into_iter()
.map(|(t, v)| (Duration::try_from_secs_f64(t).unwrap_or_else(|err| panic!("Value = {}, {}", t, err)), v)) .map(|(t, v)| (core::time::Duration::try_from_secs_f64(t).unwrap_or_else(|err| panic!("Value = {}, {}", t, err)), v))
).map_err(PyArgusError::from)?; ).map_err(PyArgusError::from)?;
Python::with_gil(|py| { Python::with_gil(|py| {
Py::new( Py::new(
@ -153,7 +146,7 @@ macro_rules! impl_signals {
fn push(mut self_: PyRefMut<'_, Self>, time: f64, value: $ty) -> Result<(), PyArgusError> { fn push(mut self_: PyRefMut<'_, Self>, time: f64, value: $ty) -> Result<(), PyArgusError> {
let super_: &mut PySignal = self_.as_mut(); let super_: &mut PySignal = self_.as_mut();
let signal: &mut Signal<$ty> = (&mut super_.signal).try_into().unwrap(); let signal: &mut Signal<$ty> = (&mut super_.signal).try_into().unwrap();
signal.push(Duration::from_secs_f64(time), value)?; signal.push(core::time::Duration::from_secs_f64(time), value)?;
Ok(()) Ok(())
} }

View file

@ -1,34 +1,46 @@
from typing import List, Tuple, Type, Union from typing import List, Tuple, Type, Union
import pytest import pytest
from hypothesis import Verbosity, given, note, settings from hypothesis import assume, given, note
from hypothesis import strategies as st from hypothesis import strategies as st
from hypothesis.strategies import SearchStrategy, composite from hypothesis.strategies import SearchStrategy, composite
import argus import argus
from argus import DType
AllowedDtype = Union[bool, int, float] AllowedDtype = Union[bool, int, float]
def gen_element_fn(dtype: Union[Type[AllowedDtype], DType]) -> SearchStrategy[AllowedDtype]:
new_dtype = DType.convert(dtype)
if new_dtype == DType.Bool:
return st.booleans()
elif new_dtype == DType.Int:
size = 2**64
return st.integers(min_value=(-size // 2), max_value=((size - 1) // 2))
elif new_dtype == DType.UnsignedInt:
size = 2**64
return st.integers(min_value=0, max_value=(size - 1))
elif new_dtype == DType.Float:
return st.floats(
width=64,
allow_nan=False,
allow_infinity=False,
allow_subnormal=False,
)
else:
raise ValueError(f"invalid dtype {dtype}")
@composite @composite
def gen_samples( def gen_samples(
draw: st.DrawFn, *, min_size: int, max_size: int, dtype: Type[AllowedDtype] draw: st.DrawFn, *, min_size: int, max_size: int, dtype: Union[Type[AllowedDtype], DType]
) -> List[Tuple[float, AllowedDtype]]: ) -> List[Tuple[float, AllowedDtype]]:
""" """
Generate arbitrary samples for a signal where the time stamps are strictly Generate arbitrary samples for a signal where the time stamps are strictly
monotonically increasing monotonically increasing
""" """
elements: st.SearchStrategy[AllowedDtype] elements = gen_element_fn(dtype)
if dtype == bool:
elements = st.booleans()
elif dtype == int:
size = 2**64
elements = st.integers(min_value=(-size // 2), max_value=((size - 1) // 2))
elif dtype == float:
elements = st.floats(width=64)
else:
raise ValueError(f"invalid dtype {dtype}")
values = draw(st.lists(elements, min_size=min_size, max_size=max_size)) values = draw(st.lists(elements, min_size=min_size, max_size=max_size))
xs = draw( xs = draw(
st.lists( st.lists(
@ -43,6 +55,30 @@ def gen_samples(
return xs return xs
def empty_signal(*, dtype: Union[Type[AllowedDtype], DType]) -> SearchStrategy[argus.Signal]:
new_dtype: DType = DType.convert(dtype)
sig: argus.Signal
if new_dtype == DType.Bool:
sig = argus.BoolSignal()
assert sig.kind is bool
elif new_dtype == DType.UnsignedInt:
sig = argus.UnsignedIntSignal()
assert sig.kind is int
elif new_dtype == DType.Int:
sig = argus.IntSignal()
assert sig.kind is int
elif new_dtype == DType.Float:
sig = argus.FloatSignal()
assert sig.kind is float
else:
raise ValueError("unknown dtype")
return st.just(sig)
def constant_signal(dtype: Union[Type[AllowedDtype], DType]) -> SearchStrategy[argus.Signal]:
return gen_element_fn(dtype).map(lambda val: argus.signal(dtype, data=val))
@composite @composite
def draw_index(draw: st.DrawFn, vec: List) -> int: def draw_index(draw: st.DrawFn, vec: List) -> int:
if len(vec) > 0: if len(vec) > 0:
@ -51,8 +87,20 @@ def draw_index(draw: st.DrawFn, vec: List) -> int:
return draw(st.just(0)) return draw(st.just(0))
def gen_dtype() -> SearchStrategy[Type[AllowedDtype]]: def gen_dtype() -> SearchStrategy[Union[Type[AllowedDtype], DType]]:
return st.one_of(st.just(bool), st.just(int), st.just(float)) return st.one_of(
list(map(st.just, [DType.Bool, DType.UnsignedInt, DType.Int, DType.Float, bool, int, float])), # type: ignore[arg-type]
)
@given(st.data())
def test_correct_constant_signals(data: st.DataObject) -> None:
dtype = data.draw(gen_dtype())
signal = data.draw(constant_signal(dtype))
assert not signal.is_empty()
assert signal.start_time is None
assert signal.end_time is None
@given(st.data()) @given(st.data())
@ -74,36 +122,37 @@ def test_correctly_create_signals(data: st.DataObject) -> None:
assert actual_end_time is not None assert actual_end_time is not None
assert actual_end_time == expected_end_time assert actual_end_time == expected_end_time
a = data.draw(draw_index(xs))
assert a < len(xs)
at, expected_val = xs[a]
actual_val = signal.at(at)
assert actual_val is not None
assert actual_val == expected_val
# generate one more sample
new_time = actual_end_time + 1
new_value = data.draw(gen_element_fn(dtype))
signal.push(new_time, new_value) # type: ignore[arg-type]
get_val = signal.at(new_time)
assert get_val is not None
assert get_val == new_value
else: else:
assert signal.is_empty() assert signal.is_empty()
assert signal.start_time is None
assert signal.end_time is None
assert signal.at(0) is None assert signal.at(0) is None
@settings(verbosity=Verbosity.verbose)
@given(st.data())
def test_signal_at(data: st.DataObject) -> None:
dtype = data.draw(gen_dtype())
xs = data.draw(gen_samples(min_size=10, max_size=100, dtype=dtype))
a = data.draw(draw_index(xs))
assert len(xs) > 2
assert a < len(xs)
signal = argus.signal(dtype, data=xs)
at, expected_val = xs[a]
actual_val = signal.at(at)
assert actual_val is not None
assert actual_val == expected_val
@given(st.data()) @given(st.data())
def test_signal_create_should_fail(data: st.DataObject) -> None: def test_signal_create_should_fail(data: st.DataObject) -> None:
dtype = data.draw(gen_dtype()) dtype = data.draw(gen_dtype())
xs = data.draw(gen_samples(min_size=10, max_size=100, dtype=dtype)) xs = data.draw(gen_samples(min_size=10, max_size=100, dtype=dtype))
a = data.draw(draw_index(xs)) a = data.draw(draw_index(xs))
b = data.draw(draw_index(xs)) b = data.draw(draw_index(xs))
assume(a != b)
assert len(xs) > 2 assert len(xs) > 2
assert a < len(xs) assert a < len(xs)
@ -111,5 +160,25 @@ def test_signal_create_should_fail(data: st.DataObject) -> None:
# Swap two indices in the samples # Swap two indices in the samples
xs[b], xs[a] = xs[a], xs[b] xs[b], xs[a] = xs[a], xs[b]
with pytest.raises(RuntimeError): with pytest.raises(RuntimeError, match=r"trying to create a non-monotonically signal.+"):
_ = argus.signal(dtype, data=xs) _ = argus.signal(dtype, data=xs)
@given(st.data())
def test_push_to_empty_signal(data: st.DataObject) -> None:
dtype = data.draw(gen_dtype())
sig = data.draw(empty_signal(dtype=dtype))
assert sig.is_empty()
element = data.draw(gen_element_fn(dtype))
with pytest.raises(RuntimeError, match="cannot push value to non-sampled signal"):
sig.push(0.0, element) # type: ignore[attr-defined]
@given(st.data())
def test_push_to_constant_signal(data: st.DataObject) -> None:
dtype = data.draw(gen_dtype())
sig = data.draw(constant_signal(dtype=dtype))
assert not sig.is_empty()
sample = data.draw(gen_samples(min_size=1, max_size=1, dtype=dtype))[0]
with pytest.raises(RuntimeError, match="cannot push value to non-sampled signal"):
sig.push(*sample) # type: ignore[attr-defined]