summaryrefslogtreecommitdiff
path: root/numpy
diff options
context:
space:
mode:
authorBas van Beek <b.f.van.beek@vu.nl>2021-04-28 19:46:58 +0200
committerBas van Beek <b.f.van.beek@vu.nl>2021-04-30 22:09:51 +0200
commit3888fa81497b59a8ae33204b53da6f281d290ea0 (patch)
tree1a97811d16014806694da458ea3ce661e1b7b048 /numpy
parent235e4f32b87f8a3b93281c19bee03ac67ec32e5c (diff)
downloadnumpy-3888fa81497b59a8ae33204b53da6f281d290ea0.tar.gz
MAINT: Remove unsafe unions from `np.core.einsumfunc`
Diffstat (limited to 'numpy')
-rw-r--r--numpy/core/einsumfunc.pyi18
-rw-r--r--numpy/typing/tests/data/reveal/einsumfunc.py16
2 files changed, 19 insertions, 15 deletions
diff --git a/numpy/core/einsumfunc.pyi b/numpy/core/einsumfunc.pyi
index b33aff29f..2457e8719 100644
--- a/numpy/core/einsumfunc.pyi
+++ b/numpy/core/einsumfunc.pyi
@@ -13,7 +13,6 @@ from numpy import (
_OrderKACF,
)
from numpy.typing import (
- _ArrayOrScalar,
_ArrayLikeBool_co,
_ArrayLikeUInt_co,
_ArrayLikeInt_co,
@@ -46,6 +45,11 @@ _CastingUnsafe = Literal["unsafe"]
__all__: List[str]
# TODO: Properly handle the `casting`-based combinatorics
+# TODO: We need to evaluate the content `__subscripts` in order
+# to identify whether or an array or scalar is returned. At a cursory
+# glance this seems like something that can quite easilly be done with
+# a mypy plugin.
+# Something like `is_scalar = bool(__subscripts.partition("->")[-1])`
@overload
def einsum(
__subscripts: str,
@@ -55,7 +59,7 @@ def einsum(
order: _OrderKACF = ...,
casting: _CastingSafe = ...,
optimize: _OptimizeKind = ...,
-) -> _ArrayOrScalar[bool_]: ...
+) -> Any: ...
@overload
def einsum(
__subscripts: str,
@@ -65,7 +69,7 @@ def einsum(
order: _OrderKACF = ...,
casting: _CastingSafe = ...,
optimize: _OptimizeKind = ...,
-) -> _ArrayOrScalar[unsignedinteger[Any]]: ...
+) -> Any: ...
@overload
def einsum(
__subscripts: str,
@@ -75,7 +79,7 @@ def einsum(
order: _OrderKACF = ...,
casting: _CastingSafe = ...,
optimize: _OptimizeKind = ...,
-) -> _ArrayOrScalar[signedinteger[Any]]: ...
+) -> Any: ...
@overload
def einsum(
__subscripts: str,
@@ -85,7 +89,7 @@ def einsum(
order: _OrderKACF = ...,
casting: _CastingSafe = ...,
optimize: _OptimizeKind = ...,
-) -> _ArrayOrScalar[floating[Any]]: ...
+) -> Any: ...
@overload
def einsum(
__subscripts: str,
@@ -95,7 +99,7 @@ def einsum(
order: _OrderKACF = ...,
casting: _CastingSafe = ...,
optimize: _OptimizeKind = ...,
-) -> _ArrayOrScalar[complexfloating[Any, Any]]: ...
+) -> Any: ...
@overload
def einsum(
__subscripts: str,
@@ -105,7 +109,7 @@ def einsum(
out: None = ...,
order: _OrderKACF = ...,
optimize: _OptimizeKind = ...,
-) -> _ArrayOrScalar[Any]: ...
+) -> Any: ...
@overload
def einsum(
__subscripts: str,
diff --git a/numpy/typing/tests/data/reveal/einsumfunc.py b/numpy/typing/tests/data/reveal/einsumfunc.py
index 18c192b0b..f1a90428d 100644
--- a/numpy/typing/tests/data/reveal/einsumfunc.py
+++ b/numpy/typing/tests/data/reveal/einsumfunc.py
@@ -10,17 +10,17 @@ AR_LIKE_U: List[str]
OUT_f: np.ndarray[Any, np.dtype[np.float64]]
-reveal_type(np.einsum("i,i->i", AR_LIKE_b, AR_LIKE_b)) # E: Union[numpy.bool_, numpy.ndarray[Any, numpy.dtype[numpy.bool_]]
-reveal_type(np.einsum("i,i->i", AR_LIKE_u, AR_LIKE_u)) # E: Union[numpy.unsignedinteger[Any], numpy.ndarray[Any, numpy.dtype[numpy.unsignedinteger[Any]]]
-reveal_type(np.einsum("i,i->i", AR_LIKE_i, AR_LIKE_i)) # E: Union[numpy.signedinteger[Any], numpy.ndarray[Any, numpy.dtype[numpy.signedinteger[Any]]]
-reveal_type(np.einsum("i,i->i", AR_LIKE_f, AR_LIKE_f)) # E: Union[numpy.floating[Any], numpy.ndarray[Any, numpy.dtype[numpy.floating[Any]]]
-reveal_type(np.einsum("i,i->i", AR_LIKE_c, AR_LIKE_c)) # E: Union[numpy.complexfloating[Any, Any], numpy.ndarray[Any, numpy.dtype[numpy.complexfloating[Any, Any]]]
-reveal_type(np.einsum("i,i->i", AR_LIKE_b, AR_LIKE_i)) # E: Union[numpy.signedinteger[Any], numpy.ndarray[Any, numpy.dtype[numpy.signedinteger[Any]]]
-reveal_type(np.einsum("i,i,i,i->i", AR_LIKE_b, AR_LIKE_u, AR_LIKE_i, AR_LIKE_c)) # E: Union[numpy.complexfloating[Any, Any], numpy.ndarray[Any, numpy.dtype[numpy.complexfloating[Any, Any]]]
+reveal_type(np.einsum("i,i->i", AR_LIKE_b, AR_LIKE_b)) # E: Any
+reveal_type(np.einsum("i,i->i", AR_LIKE_u, AR_LIKE_u)) # E: Any
+reveal_type(np.einsum("i,i->i", AR_LIKE_i, AR_LIKE_i)) # E: Any
+reveal_type(np.einsum("i,i->i", AR_LIKE_f, AR_LIKE_f)) # E: Any
+reveal_type(np.einsum("i,i->i", AR_LIKE_c, AR_LIKE_c)) # E: Any
+reveal_type(np.einsum("i,i->i", AR_LIKE_b, AR_LIKE_i)) # E: Any
+reveal_type(np.einsum("i,i,i,i->i", AR_LIKE_b, AR_LIKE_u, AR_LIKE_i, AR_LIKE_c)) # E: Any
reveal_type(np.einsum("i,i->i", AR_LIKE_c, AR_LIKE_c, out=OUT_f)) # E: numpy.ndarray[Any, numpy.dtype[{float64}]
reveal_type(np.einsum("i,i->i", AR_LIKE_U, AR_LIKE_U, dtype=bool, casting="unsafe", out=OUT_f)) # E: numpy.ndarray[Any, numpy.dtype[{float64}]
-reveal_type(np.einsum("i,i->i", AR_LIKE_f, AR_LIKE_f, dtype="c16")) # E: Union[numpy.complexfloating[Any, Any], numpy.ndarray[Any, numpy.dtype[numpy.complexfloating[Any, Any]]]
+reveal_type(np.einsum("i,i->i", AR_LIKE_f, AR_LIKE_f, dtype="c16")) # E: Any
reveal_type(np.einsum("i,i->i", AR_LIKE_U, AR_LIKE_U, dtype=bool, casting="unsafe")) # E: Any
reveal_type(np.einsum_path("i,i->i", AR_LIKE_b, AR_LIKE_b)) # E: Tuple[builtins.list[Any], builtins.str]