refactor(pyargus): define a runtime checkable Signal protocol
This commit is contained in:
parent
8027f86213
commit
3d6157e03a
6 changed files with 70 additions and 17 deletions
|
|
@ -89,10 +89,16 @@ def ruff(session: nox.Session):
|
||||||
def mypy(session: nox.Session):
|
def mypy(session: nox.Session):
|
||||||
session.conda_install("mypy", "typing-extensions", "pytest", "hypothesis", "numpy")
|
session.conda_install("mypy", "typing-extensions", "pytest", "hypothesis", "numpy")
|
||||||
session.env.update(ENV)
|
session.env.update(ENV)
|
||||||
|
|
||||||
with session.chdir(CURRENT_DIR / "pyargus"):
|
with session.chdir(CURRENT_DIR / "pyargus"):
|
||||||
session.install("-e", ".")
|
session.install("-e", ".")
|
||||||
session.run("mypy", ".")
|
session.run("mypy", ".")
|
||||||
session.run("stubtest", "argus")
|
session.run(
|
||||||
|
"stubtest",
|
||||||
|
"argus",
|
||||||
|
"--allowlist",
|
||||||
|
str(CURRENT_DIR / "pyargus/stubtest_allow.txt"),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@nox.session
|
@nox.session
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,4 @@
|
||||||
from typing import ClassVar, Generic, Protocol, TypeVar, final
|
from typing import ClassVar, Protocol, final
|
||||||
|
|
||||||
from typing_extensions import Self
|
from typing_extensions import Self
|
||||||
|
|
||||||
|
|
@ -134,9 +134,7 @@ class dtype: # noqa: N801
|
||||||
def __eq__(self, other: object) -> bool: ...
|
def __eq__(self, other: object) -> bool: ...
|
||||||
def __int__(self) -> int: ...
|
def __int__(self) -> int: ...
|
||||||
|
|
||||||
_SignalKind = TypeVar("_SignalKind", bool, int, float, covariant=True)
|
class Signal:
|
||||||
|
|
||||||
class Signal(Generic[_SignalKind], Protocol):
|
|
||||||
def is_empty(self) -> bool: ...
|
def is_empty(self) -> bool: ...
|
||||||
@property
|
@property
|
||||||
def start_time(self) -> float | None: ...
|
def start_time(self) -> float | None: ...
|
||||||
|
|
@ -146,16 +144,16 @@ class Signal(Generic[_SignalKind], Protocol):
|
||||||
def kind(self) -> dtype: ...
|
def kind(self) -> dtype: ...
|
||||||
|
|
||||||
@final
|
@final
|
||||||
class BoolSignal(Signal[bool]):
|
class BoolSignal(Signal):
|
||||||
@classmethod
|
@classmethod
|
||||||
def constant(cls, value: bool) -> Self: ...
|
def constant(cls, value: bool) -> Self: ...
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_samples(cls, samples: list[tuple[float, bool]]) -> Self: ...
|
def from_samples(cls, samples: list[tuple[float, bool]]) -> Self: ...
|
||||||
def push(self, time: float, value: bool) -> None: ...
|
def push(self, time: float, value: bool) -> None: ...
|
||||||
def at(self, time: float) -> _SignalKind | None: ...
|
def at(self, time: float) -> bool | None: ...
|
||||||
|
|
||||||
@final
|
@final
|
||||||
class IntSignal(Signal[int]):
|
class IntSignal(Signal):
|
||||||
@classmethod
|
@classmethod
|
||||||
def constant(cls, value: int) -> Self: ...
|
def constant(cls, value: int) -> Self: ...
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|
@ -164,7 +162,7 @@ class IntSignal(Signal[int]):
|
||||||
def at(self, time: float) -> int | None: ...
|
def at(self, time: float) -> int | None: ...
|
||||||
|
|
||||||
@final
|
@final
|
||||||
class UnsignedIntSignal(Signal[int]):
|
class UnsignedIntSignal(Signal):
|
||||||
@classmethod
|
@classmethod
|
||||||
def constant(cls, value: int) -> Self: ...
|
def constant(cls, value: int) -> Self: ...
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|
@ -173,7 +171,7 @@ class UnsignedIntSignal(Signal[int]):
|
||||||
def at(self, time: float) -> int | None: ...
|
def at(self, time: float) -> int | None: ...
|
||||||
|
|
||||||
@final
|
@final
|
||||||
class FloatSignal(Signal[float]):
|
class FloatSignal(Signal):
|
||||||
@classmethod
|
@classmethod
|
||||||
def constant(cls, value: float) -> Self: ...
|
def constant(cls, value: float) -> Self: ...
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,43 @@
|
||||||
from argus._argus import BoolSignal, FloatSignal, IntSignal, Signal, UnsignedIntSignal
|
from typing import List, Optional, Protocol, Tuple, TypeVar, runtime_checkable
|
||||||
|
|
||||||
|
from typing_extensions import Self
|
||||||
|
|
||||||
|
from argus._argus import BoolSignal, FloatSignal, IntSignal, UnsignedIntSignal, dtype
|
||||||
|
|
||||||
|
T = TypeVar("T", bool, int, float)
|
||||||
|
|
||||||
|
|
||||||
|
@runtime_checkable
|
||||||
|
class Signal(Protocol[T]):
|
||||||
|
def is_empty(self) -> bool:
|
||||||
|
...
|
||||||
|
|
||||||
|
@property
|
||||||
|
def start_time(self) -> Optional[float]:
|
||||||
|
...
|
||||||
|
|
||||||
|
@property
|
||||||
|
def end_time(self) -> Optional[float]:
|
||||||
|
...
|
||||||
|
|
||||||
|
@property
|
||||||
|
def kind(self) -> dtype:
|
||||||
|
...
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def constant(cls, value: T) -> Self:
|
||||||
|
...
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_samples(cls, samples: List[Tuple[float, T]]) -> Self:
|
||||||
|
...
|
||||||
|
|
||||||
|
def push(self, time: float, value: T) -> None:
|
||||||
|
...
|
||||||
|
|
||||||
|
def at(self, time: float) -> Optional[T]:
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"Signal",
|
"Signal",
|
||||||
|
|
|
||||||
|
|
@ -41,6 +41,7 @@ addopts = ["--import-mode=importlib"]
|
||||||
testpaths = ["tests"]
|
testpaths = ["tests"]
|
||||||
|
|
||||||
[tool.mypy]
|
[tool.mypy]
|
||||||
|
packages = ["argus"]
|
||||||
# ignore_missing_imports = true
|
# ignore_missing_imports = true
|
||||||
show_error_codes = true
|
show_error_codes = true
|
||||||
plugins = ["numpy.typing.mypy_plugin"]
|
plugins = ["numpy.typing.mypy_plugin"]
|
||||||
|
|
|
||||||
5
pyargus/stubtest_allow.txt
Normal file
5
pyargus/stubtest_allow.txt
Normal file
|
|
@ -0,0 +1,5 @@
|
||||||
|
argus.signals.Protocol
|
||||||
|
argus.signals.TypeVar.__bound__
|
||||||
|
argus.signals.TypeVar.__constraints__
|
||||||
|
argus.signals.TypeVar.__contravariant__
|
||||||
|
argus.signals.TypeVar.__covariant__
|
||||||
|
|
@ -95,6 +95,7 @@ def gen_dtype() -> SearchStrategy[Union[Type[AllowedDtype], dtype]]:
|
||||||
def test_correct_constant_signals(data: st.DataObject) -> None:
|
def test_correct_constant_signals(data: st.DataObject) -> None:
|
||||||
dtype_ = data.draw(gen_dtype())
|
dtype_ = data.draw(gen_dtype())
|
||||||
signal = data.draw(constant_signal(dtype_))
|
signal = data.draw(constant_signal(dtype_))
|
||||||
|
assert isinstance(signal, argus.Signal)
|
||||||
|
|
||||||
assert not signal.is_empty()
|
assert not signal.is_empty()
|
||||||
assert signal.start_time is None
|
assert signal.start_time is None
|
||||||
|
|
@ -108,6 +109,7 @@ def test_correctly_create_signals(data: st.DataObject) -> None:
|
||||||
|
|
||||||
note(f"Samples: {gen_samples}")
|
note(f"Samples: {gen_samples}")
|
||||||
signal = argus.signal(dtype_, data=xs)
|
signal = argus.signal(dtype_, data=xs)
|
||||||
|
assert isinstance(signal, argus.Signal)
|
||||||
if len(xs) > 0:
|
if len(xs) > 0:
|
||||||
expected_start_time = xs[0][0]
|
expected_start_time = xs[0][0]
|
||||||
expected_end_time = xs[-1][0]
|
expected_end_time = xs[-1][0]
|
||||||
|
|
@ -165,18 +167,20 @@ def test_signal_create_should_fail(data: st.DataObject) -> None:
|
||||||
@given(st.data())
|
@given(st.data())
|
||||||
def test_push_to_empty_signal(data: st.DataObject) -> None:
|
def test_push_to_empty_signal(data: st.DataObject) -> None:
|
||||||
dtype_ = data.draw(gen_dtype())
|
dtype_ = data.draw(gen_dtype())
|
||||||
sig = data.draw(empty_signal(dtype_=dtype_))
|
signal = data.draw(empty_signal(dtype_=dtype_))
|
||||||
assert sig.is_empty()
|
assert isinstance(signal, argus.Signal)
|
||||||
|
assert signal.is_empty()
|
||||||
element = data.draw(gen_element_fn(dtype_))
|
element = data.draw(gen_element_fn(dtype_))
|
||||||
with pytest.raises(RuntimeError, match="cannot push value to non-sampled signal"):
|
with pytest.raises(RuntimeError, match="cannot push value to non-sampled signal"):
|
||||||
sig.push(0.0, element) # type: ignore[attr-defined]
|
signal.push(0.0, element) # type: ignore[attr-defined]
|
||||||
|
|
||||||
|
|
||||||
@given(st.data())
|
@given(st.data())
|
||||||
def test_push_to_constant_signal(data: st.DataObject) -> None:
|
def test_push_to_constant_signal(data: st.DataObject) -> None:
|
||||||
dtype_ = data.draw(gen_dtype())
|
dtype_ = data.draw(gen_dtype())
|
||||||
sig = data.draw(constant_signal(dtype_=dtype_))
|
signal = data.draw(constant_signal(dtype_=dtype_))
|
||||||
assert not sig.is_empty()
|
assert isinstance(signal, argus.Signal)
|
||||||
|
assert not signal.is_empty()
|
||||||
sample = data.draw(gen_samples(min_size=1, max_size=1, dtype_=dtype_))[0]
|
sample = data.draw(gen_samples(min_size=1, max_size=1, dtype_=dtype_))[0]
|
||||||
with pytest.raises(RuntimeError, match="cannot push value to non-sampled signal"):
|
with pytest.raises(RuntimeError, match="cannot push value to non-sampled signal"):
|
||||||
sig.push(*sample) # type: ignore[attr-defined]
|
signal.push(*sample) # type: ignore[attr-defined]
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue