refactor(pyargus): data type name
This commit is contained in:
parent
8093ab7c9f
commit
e2cfe3da56
5 changed files with 85 additions and 81 deletions
|
|
@ -3,7 +3,7 @@ from __future__ import annotations
|
|||
from typing import List, Optional, Tuple, Type, Union
|
||||
|
||||
from argus import _argus
|
||||
from argus._argus import DType as DType
|
||||
from argus._argus import dtype
|
||||
from argus.exprs import ConstBool, ConstFloat, ConstInt, ConstUInt, VarBool, VarFloat, VarInt, VarUInt
|
||||
from argus.signals import BoolSignal, FloatSignal, IntSignal, Signal, UnsignedIntSignal
|
||||
|
||||
|
|
@ -15,25 +15,19 @@ except AttributeError:
|
|||
AllowedDtype = Union[bool, int, float]
|
||||
|
||||
|
||||
def declare_var(name: str, dtype: Union[DType, Type[AllowedDtype]]) -> Union[VarBool, VarInt, VarUInt, VarFloat]:
|
||||
def declare_var(name: str, dtype_: Union[dtype, Type[AllowedDtype]]) -> Union[VarBool, VarInt, VarUInt, VarFloat]:
|
||||
"""Declare a variable with the given name and type"""
|
||||
if isinstance(dtype, type):
|
||||
if dtype == bool:
|
||||
dtype = DType.Bool
|
||||
elif dtype == int:
|
||||
dtype = DType.Int
|
||||
elif dtype == float:
|
||||
dtype = DType.Float
|
||||
dtype_ = dtype.convert(dtype_)
|
||||
|
||||
if dtype == DType.Bool:
|
||||
if dtype_ == dtype.bool_:
|
||||
return VarBool(name)
|
||||
elif dtype == DType.Int:
|
||||
elif dtype_ == dtype.int64:
|
||||
return VarInt(name)
|
||||
elif dtype == DType.UnsignedInt:
|
||||
elif dtype_ == dtype.uint64:
|
||||
return VarUInt(name)
|
||||
elif dtype == DType.Float:
|
||||
elif dtype_ == dtype.float64:
|
||||
return VarFloat(name)
|
||||
raise TypeError(f"unsupported variable type `{dtype}`")
|
||||
raise TypeError(f"unsupported variable type `{dtype_}`")
|
||||
|
||||
|
||||
def literal(value: AllowedDtype) -> Union[ConstBool, ConstInt, ConstUInt, ConstFloat]:
|
||||
|
|
@ -48,7 +42,7 @@ def literal(value: AllowedDtype) -> Union[ConstBool, ConstInt, ConstUInt, ConstF
|
|||
|
||||
|
||||
def signal(
|
||||
dtype: Union[DType, Type[AllowedDtype]],
|
||||
dtype_: Union[dtype, Type[AllowedDtype]],
|
||||
*,
|
||||
data: Optional[Union[AllowedDtype, List[Tuple[float, AllowedDtype]]]] = None,
|
||||
) -> Union[BoolSignal, UnsignedIntSignal, IntSignal, FloatSignal]:
|
||||
|
|
@ -57,7 +51,7 @@ def signal(
|
|||
Parameters
|
||||
----------
|
||||
|
||||
dtype:
|
||||
dtype_:
|
||||
Type of the signal
|
||||
|
||||
data :
|
||||
|
|
@ -67,21 +61,21 @@ def signal(
|
|||
factory: Type[Union[BoolSignal, UnsignedIntSignal, IntSignal, FloatSignal]]
|
||||
expected_type: Type[AllowedDtype]
|
||||
|
||||
dtype = DType.convert(dtype)
|
||||
if dtype == DType.Bool:
|
||||
dtype_ = dtype.convert(dtype_)
|
||||
if dtype_ == dtype.bool_:
|
||||
factory = BoolSignal
|
||||
expected_type = bool
|
||||
elif dtype == DType.UnsignedInt:
|
||||
elif dtype_ == dtype.uint64:
|
||||
factory = UnsignedIntSignal
|
||||
expected_type = int
|
||||
elif dtype == DType.Int:
|
||||
elif dtype_ == dtype.int64:
|
||||
factory = IntSignal
|
||||
expected_type = int
|
||||
elif dtype == DType.Float:
|
||||
elif dtype_ == dtype.float64:
|
||||
factory = FloatSignal
|
||||
expected_type = float
|
||||
else:
|
||||
raise ValueError(f"unsupported dtype {dtype}")
|
||||
raise ValueError(f"unsupported dtype_ {dtype}")
|
||||
|
||||
if data is None:
|
||||
return factory.from_samples([])
|
||||
|
|
@ -92,7 +86,7 @@ def signal(
|
|||
|
||||
|
||||
__all__ = [
|
||||
"DType",
|
||||
"dtype",
|
||||
"declare_var",
|
||||
"literal",
|
||||
"signal",
|
||||
|
|
|
|||
|
|
@ -123,11 +123,11 @@ class Until(BoolExpr):
|
|||
def __init__(self, lhs: BoolExpr, rhs: BoolExpr) -> None: ...
|
||||
|
||||
@final
|
||||
class DType:
|
||||
Bool: ClassVar[DType] = ...
|
||||
Float: ClassVar[DType] = ...
|
||||
Int: ClassVar[DType] = ...
|
||||
UnsignedInt: ClassVar[DType] = ...
|
||||
class dtype: # noqa: N801
|
||||
bool_: ClassVar[dtype] = ...
|
||||
float64: ClassVar[dtype] = ...
|
||||
int64: ClassVar[dtype] = ...
|
||||
uint64: ClassVar[dtype] = ...
|
||||
|
||||
@classmethod
|
||||
def convert(cls, dtype: type[bool | int | float] | Self) -> Self: ... # noqa: Y041
|
||||
|
|
@ -143,7 +143,7 @@ class Signal(Generic[_SignalKind], Protocol):
|
|||
@property
|
||||
def end_time(self) -> float | None: ...
|
||||
@property
|
||||
def kind(self) -> type[bool | int | float]: ...
|
||||
def kind(self) -> dtype: ...
|
||||
|
||||
@final
|
||||
class BoolSignal(Signal[bool]):
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue