diff options
author | Bas van Beek <43369155+BvB93@users.noreply.github.com> | 2021-01-22 00:44:58 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2021-01-21 16:44:58 -0700 |
commit | 33273e4ae770cac1ee0cb304ad5a0056bb59ad56 (patch) | |
tree | 546680d3c4778ad3bceed02e3fab73277ca46fe3 /numpy/typing/_array_like.py | |
parent | b91f3c00ee113596acd3e508a593187258291f61 (diff) | |
download | numpy-33273e4ae770cac1ee0cb304ad5a0056bb59ad56.tar.gz |
ENH: Add dtype support to the array comparison ops (#18128)
* ENH: Added `_ArrayLikeNumber`
* ENH: Added dtype support to the array comparison ops
* MAINT: Made `dtype` and `ndarray` covariant
The dtypes scalar-type and ndarrays' dtype are now covariant instead of invariant.
This change is necasary in order to ensure that all generic subclasses can be used as underlying scalar type.
* TST: Updated the comparison typing tests
* MAINT: Fixed an issue where certain `array > arraylike` operations where neglected
More specifically operations between array-likes of `timedelta64` and `ndarray`s that can be cast into `timedelta64`.
For example:
ar_i = np.array([1])
seq_m = [np.timedelta64()]
ar_i > seq_m
Diffstat (limited to 'numpy/typing/_array_like.py')
-rw-r--r-- | numpy/typing/_array_like.py | 18 |
1 files changed, 16 insertions, 2 deletions
diff --git a/numpy/typing/_array_like.py b/numpy/typing/_array_like.py index 35413393c..133f38800 100644 --- a/numpy/typing/_array_like.py +++ b/numpy/typing/_array_like.py @@ -12,6 +12,7 @@ from numpy import ( integer, floating, complexfloating, + number, timedelta64, datetime64, object_, @@ -33,15 +34,17 @@ else: HAVE_PROTOCOL = True _T = TypeVar("_T") +_ScalarType = TypeVar("_ScalarType", bound=generic) _DType = TypeVar("_DType", bound="dtype[Any]") +_DType_co = TypeVar("_DType_co", covariant=True, bound="dtype[Any]") if TYPE_CHECKING or HAVE_PROTOCOL: # The `_SupportsArray` protocol only cares about the default dtype # (i.e. `dtype=None`) of the to-be returned array. # Concrete implementations of the protocol are responsible for adding # any and all remaining overloads - class _SupportsArray(Protocol[_DType]): - def __array__(self, dtype: None = ...) -> ndarray[Any, _DType]: ... + class _SupportsArray(Protocol[_DType_co]): + def __array__(self, dtype: None = ...) -> ndarray[Any, _DType_co]: ... else: _SupportsArray = Any @@ -100,6 +103,10 @@ _ArrayLikeComplex_co = _ArrayLike[ "dtype[Union[bool_, integer[Any], floating[Any], complexfloating[Any, Any]]]", Union[bool, int, float, complex], ] +_ArrayLikeNumber_co = _ArrayLike[ + "dtype[Union[bool_, number[Any]]]", + Union[bool, int, float, complex], +] _ArrayLikeTD64_co = _ArrayLike[ "dtype[Union[bool_, integer[Any], timedelta64]]", Union[bool, int], @@ -116,3 +123,10 @@ _ArrayLikeBytes_co = _ArrayLike[ "dtype[bytes_]", bytes, ] + +if TYPE_CHECKING: + _ArrayND = ndarray[Any, dtype[_ScalarType]] + _ArrayOrScalar = Union[_ScalarType, _ArrayND[_ScalarType]] +else: + _ArrayND = Any + _ArrayOrScalar = Any |