argus/pyargus/tests/test_signals.py
Anand Balakrishnan d39e3d3e12
feat!(pyargus): simplify the API surface
- Get rid of helper functions. It is not much more verbose to
  create signals with `argus.FloatSignal(...)` than with
  `argus.signal(..., dtype=argus.dtype.float64)`.

- Make the package hierarchy flat: everything is under `argus`. If this
  is an issue, it can be changed in the future.

- Add type hints for interval types.
2023-10-05 15:28:59 -07:00

209 lines
7.1 KiB
Python

import typing
from typing import List, Tuple, Type, Union
import pytest
from hypothesis import assume, given
from hypothesis import strategies as st
from hypothesis.strategies import SearchStrategy, composite
import argus
from argus import AllowedDtype, dtype
def gen_element_fn(dtype_: Union[Type[AllowedDtype], dtype]) -> SearchStrategy[AllowedDtype]:
    """Return a Hypothesis strategy for scalar values of the given dtype.

    ``dtype_`` is first normalised with ``dtype.convert``. Integer strategies
    cover the full 64-bit signed / unsigned range; float values are finite,
    non-NaN, non-subnormal 64-bit floats.

    Raises:
        ValueError: if the converted dtype is not bool_, int64, uint64 or float64.
    """
    kind = dtype.convert(dtype_)
    if kind == dtype.bool_:
        return st.booleans()
    if kind == dtype.int64:
        # Signed 64-bit: [-2**63, 2**63 - 1]
        return st.integers(min_value=-(2**63), max_value=2**63 - 1)
    if kind == dtype.uint64:
        # Unsigned 64-bit: [0, 2**64 - 1]
        return st.integers(min_value=0, max_value=2**64 - 1)
    if kind == dtype.float64:
        return st.floats(width=64, allow_nan=False, allow_infinity=False, allow_subnormal=False)
    raise ValueError(f"invalid dtype {dtype_}")
@composite
def gen_samples(
    draw: st.DrawFn, min_size: int, max_size: int, dtype_: Union[Type[AllowedDtype], dtype]
) -> List[Tuple[float, AllowedDtype]]:
    """
    Generate arbitrary samples for a signal where the time stamps are strictly
    monotonically increasing.

    The values are drawn first (between ``min_size`` and ``max_size`` of them).
    Then exactly that many distinct integer time stamps in ``[0, 2**32 - 1]``
    are drawn, sorted, converted to ``float`` and paired with the values in order.
    """
    sample_values = draw(st.lists(gen_element_fn(dtype_), min_size=min_size, max_size=max_size))
    count = len(sample_values)
    raw_times = draw(
        st.lists(
            st.integers(min_value=0, max_value=2**32 - 1),
            unique=True,
            min_size=count,
            max_size=count,
        )
    )
    # unique=True guarantees distinct stamps, so sorting yields a strictly
    # increasing sequence.
    return [(float(stamp), value) for stamp, value in zip(sorted(raw_times), sample_values)]
def empty_signal(dtype_: Union[Type[AllowedDtype], dtype]) -> SearchStrategy[argus.Signal]:
    """Return a strategy that yields a fresh, empty signal of the given dtype.

    ``dtype_`` is normalised with ``dtype.convert``. Every draw constructs a
    new signal object (via ``st.builds``) rather than reusing one instance
    wrapped in ``st.just``. A test that mutates the drawn signal (e.g. with
    ``push``) therefore cannot leak state into a later draw from the same
    strategy. This mirrors ``constant_signal``, which also builds a new signal
    per draw.

    Raises:
        ValueError: if ``dtype_`` does not map to a supported signal kind.
    """
    new_dtype: dtype = dtype.convert(dtype_)
    make: typing.Callable[[], argus.Signal]
    if new_dtype == dtype.bool_:
        make = argus.BoolSignal
    elif new_dtype == dtype.uint64:
        make = argus.UnsignedIntSignal
    elif new_dtype == dtype.int64:
        make = argus.IntSignal
    elif new_dtype == dtype.float64:
        make = argus.FloatSignal
    else:
        raise ValueError("unknown dtype")

    def _fresh_signal() -> argus.Signal:
        # Construct a new signal per draw and sanity-check its reported kind.
        sig = make()
        assert sig.kind == new_dtype
        return sig

    return st.builds(_fresh_signal)
def constant_signal(dtype_: Union[Type[AllowedDtype], dtype]) -> SearchStrategy[argus.Signal]:
    """Return a strategy producing constant (non-sampled) signals of the given dtype.

    The constant value is drawn from ``gen_element_fn(dtype_)``. Each draw then
    builds a new signal with the matching ``<Kind>Signal.constant`` constructor.

    Raises:
        ValueError: if ``dtype_`` is not a supported signal dtype.
    """
    values = gen_element_fn(dtype_)
    kind = dtype.convert(dtype_)
    make: typing.Callable[[typing.Any], argus.Signal]
    if kind == dtype.bool_:
        make = argus.BoolSignal.constant
    elif kind == dtype.uint64:
        make = argus.UnsignedIntSignal.constant
    elif kind == dtype.int64:
        make = argus.IntSignal.constant
    elif kind == dtype.float64:
        make = argus.FloatSignal.constant
    else:
        raise ValueError("unsupported data type for signal")
    return values.map(make)
def sampled_signal(xs: List[Tuple[float, AllowedDtype]], dtype_: Union[Type[AllowedDtype], dtype]) -> argus.Signal:
    """Build a sampled signal of the requested dtype from ``(time, value)`` pairs.

    ``dtype_`` is normalised with ``dtype.convert`` and selects which
    ``<Kind>Signal.from_samples`` constructor receives ``xs``. Any error raised
    by that constructor propagates to the caller.

    Raises:
        ValueError: if ``dtype_`` is not a supported signal dtype.
    """
    kind = dtype.convert(dtype_)
    make: typing.Callable[[typing.Any], argus.Signal]
    if kind == dtype.bool_:
        make = argus.BoolSignal.from_samples
    elif kind == dtype.uint64:
        make = argus.UnsignedIntSignal.from_samples
    elif kind == dtype.int64:
        make = argus.IntSignal.from_samples
    elif kind == dtype.float64:
        make = argus.FloatSignal.from_samples
    else:
        raise ValueError("unsupported data type for signal")
    return make(xs)
@composite
def draw_index(draw: st.DrawFn, vec: List) -> int:
    """Draw an index into ``vec``.

    For a non-empty ``vec`` the index lies in ``[0, len(vec) - 1]``. For an
    empty ``vec`` it is always ``0``, which is *not* a valid index. Callers
    must not subscript with it in that case.
    """
    if not vec:
        return draw(st.just(0))
    return draw(st.integers(min_value=0, max_value=len(vec) - 1))
def gen_dtype() -> SearchStrategy[Union[Type[AllowedDtype], dtype]]:
    """Return a strategy choosing one of the supported signal dtypes.

    Both ``argus.dtype`` members and the plain Python types ``bool``, ``int``
    and ``float`` are produced, so the ``dtype.convert`` path is exercised too.
    """
    candidates = (dtype.bool_, dtype.uint64, dtype.int64, dtype.float64, bool, int, float)
    return st.one_of([st.just(candidate) for candidate in candidates])  # type: ignore[arg-type]
@given(st.data())
def test_correct_constant_signals(data: st.DataObject) -> None:
    """Constant signals are non-empty and have neither a start nor an end time."""
    sig = data.draw(constant_signal(data.draw(gen_dtype())))
    assert isinstance(sig, argus.Signal)
    assert not sig.is_empty()
    assert sig.start_time is None
    assert sig.end_time is None
@given(st.data())
def test_correctly_create_signals(data: st.DataObject) -> None:
    """Sampled signals report correct bounds and values, and accept new samples.

    An empty sample list must produce an empty signal with no time bounds that
    returns ``None`` when queried.
    """
    dtype_ = data.draw(gen_dtype())
    samples = data.draw(gen_samples(min_size=0, max_size=100, dtype_=dtype_))
    signal = sampled_signal(samples, dtype_)
    assert isinstance(signal, argus.Signal)

    if not samples:
        assert signal.is_empty()
        assert signal.start_time is None
        assert signal.end_time is None
        assert signal.at(0) is None
        return

    first_t = samples[0][0]
    last_t = samples[-1][0]
    start = signal.start_time
    end = signal.end_time
    assert start is not None
    assert start == first_t
    assert end is not None
    assert end == last_t

    # Query the signal at one randomly chosen sample time.
    idx = data.draw(draw_index(samples))
    assert idx < len(samples)
    query_t, expected = samples[idx]
    got = signal.at(query_t)
    assert got is not None
    assert got == expected

    # Append one more sample strictly after the current end and read it back.
    t_next = end + 1
    v_next = data.draw(gen_element_fn(dtype_))
    signal.push(t_next, v_next)  # type: ignore[arg-type]
    pushed = signal.at(t_next)
    assert pushed is not None
    assert pushed == v_next
@given(st.data())
def test_signal_create_should_fail(data: st.DataObject) -> None:
    """Building a signal from out-of-order samples must raise a RuntimeError."""
    dtype_ = data.draw(gen_dtype())
    samples = data.draw(gen_samples(min_size=10, max_size=100, dtype_=dtype_))
    i = data.draw(draw_index(samples))
    j = data.draw(draw_index(samples))
    assume(i != j)
    assert len(samples) > 2
    assert i < len(samples)
    assert j < len(samples)

    # Swapping two distinct entries of a strictly increasing sequence breaks
    # monotonicity of the time stamps.
    samples[i], samples[j] = samples[j], samples[i]
    with pytest.raises(RuntimeError, match=r"trying to create a non-monotonically signal.+"):
        sampled_signal(samples, dtype_)
@given(st.data())
def test_push_to_empty_signal(data: st.DataObject) -> None:
    """After pushing a sample at t=0.0 onto an empty signal, it reads back at that time."""
    dtype_ = data.draw(gen_dtype())
    sig = data.draw(empty_signal(dtype_=dtype_))
    assert isinstance(sig, argus.Signal)
    assert sig.is_empty()
    value = data.draw(gen_element_fn(dtype_))
    sig.push(0.0, value)
    assert sig.at(0.0) == value
@given(st.data())
def test_push_to_constant_signal(data: st.DataObject) -> None:
    """Pushing a sample onto a constant (non-sampled) signal must raise."""
    dtype_ = data.draw(gen_dtype())
    sig = data.draw(constant_signal(dtype_=dtype_))
    assert isinstance(sig, argus.Signal)
    assert not sig.is_empty()
    # min_size == max_size == 1, so exactly one sample is generated.
    (sample,) = data.draw(gen_samples(min_size=1, max_size=1, dtype_=dtype_))
    with pytest.raises(RuntimeError, match="cannot push value to non-sampled signal"):
        sig.push(*sample)  # type: ignore[attr-defined]