refactor(pyargus): data type name

This commit is contained in:
Anand Balakrishnan 2023-09-07 15:43:04 -07:00
parent 8093ab7c9f
commit e2cfe3da56
No known key found for this signature in database
5 changed files with 85 additions and 81 deletions

View file

@ -6,22 +6,20 @@ from hypothesis import strategies as st
from hypothesis.strategies import SearchStrategy, composite
import argus
from argus import DType
AllowedDtype = Union[bool, int, float]
from argus import AllowedDtype, dtype
def gen_element_fn(dtype: Union[Type[AllowedDtype], DType]) -> SearchStrategy[AllowedDtype]:
new_dtype = DType.convert(dtype)
if new_dtype == DType.Bool:
def gen_element_fn(dtype_: Union[Type[AllowedDtype], dtype]) -> SearchStrategy[AllowedDtype]:
new_dtype = dtype.convert(dtype_)
if new_dtype == dtype.bool_:
return st.booleans()
elif new_dtype == DType.Int:
elif new_dtype == dtype.int64:
size = 2**64
return st.integers(min_value=(-size // 2), max_value=((size - 1) // 2))
elif new_dtype == DType.UnsignedInt:
elif new_dtype == dtype.uint64:
size = 2**64
return st.integers(min_value=0, max_value=(size - 1))
elif new_dtype == DType.Float:
elif new_dtype == dtype.float64:
return st.floats(
width=64,
allow_nan=False,
@ -29,18 +27,18 @@ def gen_element_fn(dtype: Union[Type[AllowedDtype], DType]) -> SearchStrategy[Al
allow_subnormal=False,
)
else:
raise ValueError(f"invalid dtype {dtype}")
raise ValueError(f"invalid dtype {dtype_}")
@composite
def gen_samples(
draw: st.DrawFn, *, min_size: int, max_size: int, dtype: Union[Type[AllowedDtype], DType]
draw: st.DrawFn, min_size: int, max_size: int, dtype_: Union[Type[AllowedDtype], dtype]
) -> List[Tuple[float, AllowedDtype]]:
"""
Generate arbitrary samples for a signal where the time stamps are strictly
monotonically increasing
"""
elements = gen_element_fn(dtype)
elements = gen_element_fn(dtype_)
values = draw(st.lists(elements, min_size=min_size, max_size=max_size))
xs = draw(
st.lists(
@ -55,28 +53,28 @@ def gen_samples(
return xs
def empty_signal(*, dtype: Union[Type[AllowedDtype], DType]) -> SearchStrategy[argus.Signal]:
new_dtype: DType = DType.convert(dtype)
def empty_signal(*, dtype_: Union[Type[AllowedDtype], dtype]) -> SearchStrategy[argus.Signal]:
new_dtype: dtype = dtype.convert(dtype_)
sig: argus.Signal
if new_dtype == DType.Bool:
if new_dtype == dtype.bool_:
sig = argus.BoolSignal()
assert sig.kind is bool
elif new_dtype == DType.UnsignedInt:
assert sig.kind == dtype.bool_
elif new_dtype == dtype.uint64:
sig = argus.UnsignedIntSignal()
assert sig.kind is int
elif new_dtype == DType.Int:
assert sig.kind == dtype.uint64
elif new_dtype == dtype.int64:
sig = argus.IntSignal()
assert sig.kind is int
elif new_dtype == DType.Float:
assert sig.kind == dtype.int64
elif new_dtype == dtype.float64:
sig = argus.FloatSignal()
assert sig.kind is float
assert sig.kind == dtype.float64
else:
raise ValueError("unknown dtype")
return st.just(sig)
def constant_signal(dtype: Union[Type[AllowedDtype], DType]) -> SearchStrategy[argus.Signal]:
return gen_element_fn(dtype).map(lambda val: argus.signal(dtype, data=val))
def constant_signal(dtype_: Union[Type[AllowedDtype], dtype]) -> SearchStrategy[argus.Signal]:
return gen_element_fn(dtype_).map(lambda val: argus.signal(dtype_, data=val))
@composite
@ -87,16 +85,16 @@ def draw_index(draw: st.DrawFn, vec: List) -> int:
return draw(st.just(0))
def gen_dtype() -> SearchStrategy[Union[Type[AllowedDtype], DType]]:
def gen_dtype() -> SearchStrategy[Union[Type[AllowedDtype], dtype]]:
return st.one_of(
list(map(st.just, [DType.Bool, DType.UnsignedInt, DType.Int, DType.Float, bool, int, float])), # type: ignore[arg-type]
list(map(st.just, [dtype.bool_, dtype.uint64, dtype.int64, dtype.float64, bool, int, float])), # type: ignore[arg-type]
)
@given(st.data())
def test_correct_constant_signals(data: st.DataObject) -> None:
dtype = data.draw(gen_dtype())
signal = data.draw(constant_signal(dtype))
dtype_ = data.draw(gen_dtype())
signal = data.draw(constant_signal(dtype_))
assert not signal.is_empty()
assert signal.start_time is None
@ -105,11 +103,11 @@ def test_correct_constant_signals(data: st.DataObject) -> None:
@given(st.data())
def test_correctly_create_signals(data: st.DataObject) -> None:
dtype = data.draw(gen_dtype())
xs = data.draw(gen_samples(min_size=0, max_size=100, dtype=dtype))
dtype_ = data.draw(gen_dtype())
xs = data.draw(gen_samples(min_size=0, max_size=100, dtype_=dtype_))
note(f"Samples: {gen_samples}")
signal = argus.signal(dtype, data=xs)
signal = argus.signal(dtype_, data=xs)
if len(xs) > 0:
expected_start_time = xs[0][0]
expected_end_time = xs[-1][0]
@ -132,7 +130,7 @@ def test_correctly_create_signals(data: st.DataObject) -> None:
# generate one more sample
new_time = actual_end_time + 1
new_value = data.draw(gen_element_fn(dtype))
new_value = data.draw(gen_element_fn(dtype_))
signal.push(new_time, new_value) # type: ignore[arg-type]
get_val = signal.at(new_time)
@ -148,8 +146,8 @@ def test_correctly_create_signals(data: st.DataObject) -> None:
@given(st.data())
def test_signal_create_should_fail(data: st.DataObject) -> None:
dtype = data.draw(gen_dtype())
xs = data.draw(gen_samples(min_size=10, max_size=100, dtype=dtype))
dtype_ = data.draw(gen_dtype())
xs = data.draw(gen_samples(min_size=10, max_size=100, dtype_=dtype_))
a = data.draw(draw_index(xs))
b = data.draw(draw_index(xs))
assume(a != b)
@ -161,24 +159,24 @@ def test_signal_create_should_fail(data: st.DataObject) -> None:
xs[b], xs[a] = xs[a], xs[b]
with pytest.raises(RuntimeError, match=r"trying to create a non-monotonically signal.+"):
_ = argus.signal(dtype, data=xs)
_ = argus.signal(dtype_, data=xs)
@given(st.data())
def test_push_to_empty_signal(data: st.DataObject) -> None:
dtype = data.draw(gen_dtype())
sig = data.draw(empty_signal(dtype=dtype))
dtype_ = data.draw(gen_dtype())
sig = data.draw(empty_signal(dtype_=dtype_))
assert sig.is_empty()
element = data.draw(gen_element_fn(dtype))
element = data.draw(gen_element_fn(dtype_))
with pytest.raises(RuntimeError, match="cannot push value to non-sampled signal"):
sig.push(0.0, element) # type: ignore[attr-defined]
@given(st.data())
def test_push_to_constant_signal(data: st.DataObject) -> None:
dtype = data.draw(gen_dtype())
sig = data.draw(constant_signal(dtype=dtype))
dtype_ = data.draw(gen_dtype())
sig = data.draw(constant_signal(dtype_=dtype_))
assert not sig.is_empty()
sample = data.draw(gen_samples(min_size=1, max_size=1, dtype=dtype))[0]
sample = data.draw(gen_samples(min_size=1, max_size=1, dtype_=dtype_))[0]
with pytest.raises(RuntimeError, match="cannot push value to non-sampled signal"):
sig.push(*sample) # type: ignore[attr-defined]