summaryrefslogtreecommitdiff
path: root/numpy/typing/_array_like.py
diff options
context:
space:
mode:
authorBas van Beek <43369155+BvB93@users.noreply.github.com>2021-01-22 00:44:58 +0100
committerGitHub <noreply@github.com>2021-01-21 16:44:58 -0700
commit33273e4ae770cac1ee0cb304ad5a0056bb59ad56 (patch)
tree546680d3c4778ad3bceed02e3fab73277ca46fe3 /numpy/typing/_array_like.py
parentb91f3c00ee113596acd3e508a593187258291f61 (diff)
downloadnumpy-33273e4ae770cac1ee0cb304ad5a0056bb59ad56.tar.gz
ENH: Add dtype support to the array comparison ops (#18128)
* ENH: Added `_ArrayLikeNumber` * ENH: Added dtype support to the array comparison ops * MAINT: Made `dtype` and `ndarray` covariant The dtypes scalar-type and ndarrays' dtype are now covariant instead of invariant. This change is necasary in order to ensure that all generic subclasses can be used as underlying scalar type. * TST: Updated the comparison typing tests * MAINT: Fixed an issue where certain `array > arraylike` operations where neglected More specifically operations between array-likes of `timedelta64` and `ndarray`s that can be cast into `timedelta64`. For example: ar_i = np.array([1]) seq_m = [np.timedelta64()] ar_i > seq_m
Diffstat (limited to 'numpy/typing/_array_like.py')
-rw-r--r--numpy/typing/_array_like.py18
1 files changed, 16 insertions, 2 deletions
diff --git a/numpy/typing/_array_like.py b/numpy/typing/_array_like.py
index 35413393c..133f38800 100644
--- a/numpy/typing/_array_like.py
+++ b/numpy/typing/_array_like.py
@@ -12,6 +12,7 @@ from numpy import (
integer,
floating,
complexfloating,
+ number,
timedelta64,
datetime64,
object_,
@@ -33,15 +34,17 @@ else:
HAVE_PROTOCOL = True
_T = TypeVar("_T")
+_ScalarType = TypeVar("_ScalarType", bound=generic)
_DType = TypeVar("_DType", bound="dtype[Any]")
+_DType_co = TypeVar("_DType_co", covariant=True, bound="dtype[Any]")
if TYPE_CHECKING or HAVE_PROTOCOL:
# The `_SupportsArray` protocol only cares about the default dtype
# (i.e. `dtype=None`) of the to-be returned array.
# Concrete implementations of the protocol are responsible for adding
# any and all remaining overloads
- class _SupportsArray(Protocol[_DType]):
- def __array__(self, dtype: None = ...) -> ndarray[Any, _DType]: ...
+ class _SupportsArray(Protocol[_DType_co]):
+ def __array__(self, dtype: None = ...) -> ndarray[Any, _DType_co]: ...
else:
_SupportsArray = Any
@@ -100,6 +103,10 @@ _ArrayLikeComplex_co = _ArrayLike[
"dtype[Union[bool_, integer[Any], floating[Any], complexfloating[Any, Any]]]",
Union[bool, int, float, complex],
]
+_ArrayLikeNumber_co = _ArrayLike[
+ "dtype[Union[bool_, number[Any]]]",
+ Union[bool, int, float, complex],
+]
_ArrayLikeTD64_co = _ArrayLike[
"dtype[Union[bool_, integer[Any], timedelta64]]",
Union[bool, int],
@@ -116,3 +123,10 @@ _ArrayLikeBytes_co = _ArrayLike[
"dtype[bytes_]",
bytes,
]
+
+if TYPE_CHECKING:
+ _ArrayND = ndarray[Any, dtype[_ScalarType]]
+ _ArrayOrScalar = Union[_ScalarType, _ArrayND[_ScalarType]]
+else:
+ _ArrayND = Any
+ _ArrayOrScalar = Any