refactor(pyargus): data type name

This commit is contained in:
Anand Balakrishnan 2023-09-07 15:43:04 -07:00
parent 8093ab7c9f
commit e2cfe3da56
No known key found for this signature in database
5 changed files with 85 additions and 81 deletions

View file

@ -3,7 +3,7 @@ from __future__ import annotations
from typing import List, Optional, Tuple, Type, Union
from argus import _argus
from argus._argus import DType as DType
from argus._argus import dtype
from argus.exprs import ConstBool, ConstFloat, ConstInt, ConstUInt, VarBool, VarFloat, VarInt, VarUInt
from argus.signals import BoolSignal, FloatSignal, IntSignal, Signal, UnsignedIntSignal
@ -15,25 +15,19 @@ except AttributeError:
AllowedDtype = Union[bool, int, float]
def declare_var(name: str, dtype: Union[DType, Type[AllowedDtype]]) -> Union[VarBool, VarInt, VarUInt, VarFloat]:
def declare_var(name: str, dtype_: Union[dtype, Type[AllowedDtype]]) -> Union[VarBool, VarInt, VarUInt, VarFloat]:
"""Declare a variable with the given name and type"""
if isinstance(dtype, type):
if dtype == bool:
dtype = DType.Bool
elif dtype == int:
dtype = DType.Int
elif dtype == float:
dtype = DType.Float
dtype_ = dtype.convert(dtype_)
if dtype == DType.Bool:
if dtype_ == dtype.bool_:
return VarBool(name)
elif dtype == DType.Int:
elif dtype_ == dtype.int64:
return VarInt(name)
elif dtype == DType.UnsignedInt:
elif dtype_ == dtype.uint64:
return VarUInt(name)
elif dtype == DType.Float:
elif dtype_ == dtype.float64:
return VarFloat(name)
raise TypeError(f"unsupported variable type `{dtype}`")
raise TypeError(f"unsupported variable type `{dtype_}`")
def literal(value: AllowedDtype) -> Union[ConstBool, ConstInt, ConstUInt, ConstFloat]:
@ -48,7 +42,7 @@ def literal(value: AllowedDtype) -> Union[ConstBool, ConstInt, ConstUInt, ConstF
def signal(
dtype: Union[DType, Type[AllowedDtype]],
dtype_: Union[dtype, Type[AllowedDtype]],
*,
data: Optional[Union[AllowedDtype, List[Tuple[float, AllowedDtype]]]] = None,
) -> Union[BoolSignal, UnsignedIntSignal, IntSignal, FloatSignal]:
@ -57,7 +51,7 @@ def signal(
Parameters
----------
dtype:
dtype_:
Type of the signal
data :
@ -67,21 +61,21 @@ def signal(
factory: Type[Union[BoolSignal, UnsignedIntSignal, IntSignal, FloatSignal]]
expected_type: Type[AllowedDtype]
dtype = DType.convert(dtype)
if dtype == DType.Bool:
dtype_ = dtype.convert(dtype_)
if dtype_ == dtype.bool_:
factory = BoolSignal
expected_type = bool
elif dtype == DType.UnsignedInt:
elif dtype_ == dtype.uint64:
factory = UnsignedIntSignal
expected_type = int
elif dtype == DType.Int:
elif dtype_ == dtype.int64:
factory = IntSignal
expected_type = int
elif dtype == DType.Float:
elif dtype_ == dtype.float64:
factory = FloatSignal
expected_type = float
else:
raise ValueError(f"unsupported dtype {dtype}")
raise ValueError(f"unsupported dtype_ {dtype}")
if data is None:
return factory.from_samples([])
@ -92,7 +86,7 @@ def signal(
__all__ = [
"DType",
"dtype",
"declare_var",
"literal",
"signal",

View file

@ -123,11 +123,11 @@ class Until(BoolExpr):
def __init__(self, lhs: BoolExpr, rhs: BoolExpr) -> None: ...
@final
class DType:
Bool: ClassVar[DType] = ...
Float: ClassVar[DType] = ...
Int: ClassVar[DType] = ...
UnsignedInt: ClassVar[DType] = ...
class dtype: # noqa: N801
bool_: ClassVar[dtype] = ...
float64: ClassVar[dtype] = ...
int64: ClassVar[dtype] = ...
uint64: ClassVar[dtype] = ...
@classmethod
def convert(cls, dtype: type[bool | int | float] | Self) -> Self: ... # noqa: Y041
@ -143,7 +143,7 @@ class Signal(Generic[_SignalKind], Protocol):
@property
def end_time(self) -> float | None: ...
@property
def kind(self) -> type[bool | int | float]: ...
def kind(self) -> dtype: ...
@final
class BoolSignal(Signal[bool]):