summaryrefslogtreecommitdiff
path: root/numpy/core
diff options
context:
space:
mode:
authorBas van Beek <b.f.van.beek@vu.nl>2021-04-28 19:46:58 +0200
committerBas van Beek <b.f.van.beek@vu.nl>2021-04-30 22:09:51 +0200
commit3888fa81497b59a8ae33204b53da6f281d290ea0 (patch)
tree1a97811d16014806694da458ea3ce661e1b7b048 /numpy/core
parent235e4f32b87f8a3b93281c19bee03ac67ec32e5c (diff)
downloadnumpy-3888fa81497b59a8ae33204b53da6f281d290ea0.tar.gz
MAINT: Remove unsafe unions from `np.core.einsumfunc`
Diffstat (limited to 'numpy/core')
-rw-r--r--numpy/core/einsumfunc.pyi18
1 files changed, 11 insertions, 7 deletions
diff --git a/numpy/core/einsumfunc.pyi b/numpy/core/einsumfunc.pyi
index b33aff29f..2457e8719 100644
--- a/numpy/core/einsumfunc.pyi
+++ b/numpy/core/einsumfunc.pyi
@@ -13,7 +13,6 @@ from numpy import (
_OrderKACF,
)
from numpy.typing import (
- _ArrayOrScalar,
_ArrayLikeBool_co,
_ArrayLikeUInt_co,
_ArrayLikeInt_co,
@@ -46,6 +45,11 @@ _CastingUnsafe = Literal["unsafe"]
__all__: List[str]
# TODO: Properly handle the `casting`-based combinatorics
+# TODO: We need to evaluate the content `__subscripts` in order
+# to identify whether or an array or scalar is returned. At a cursory
+# glance this seems like something that can quite easilly be done with
+# a mypy plugin.
+# Something like `is_scalar = bool(__subscripts.partition("->")[-1])`
@overload
def einsum(
__subscripts: str,
@@ -55,7 +59,7 @@ def einsum(
order: _OrderKACF = ...,
casting: _CastingSafe = ...,
optimize: _OptimizeKind = ...,
-) -> _ArrayOrScalar[bool_]: ...
+) -> Any: ...
@overload
def einsum(
__subscripts: str,
@@ -65,7 +69,7 @@ def einsum(
order: _OrderKACF = ...,
casting: _CastingSafe = ...,
optimize: _OptimizeKind = ...,
-) -> _ArrayOrScalar[unsignedinteger[Any]]: ...
+) -> Any: ...
@overload
def einsum(
__subscripts: str,
@@ -75,7 +79,7 @@ def einsum(
order: _OrderKACF = ...,
casting: _CastingSafe = ...,
optimize: _OptimizeKind = ...,
-) -> _ArrayOrScalar[signedinteger[Any]]: ...
+) -> Any: ...
@overload
def einsum(
__subscripts: str,
@@ -85,7 +89,7 @@ def einsum(
order: _OrderKACF = ...,
casting: _CastingSafe = ...,
optimize: _OptimizeKind = ...,
-) -> _ArrayOrScalar[floating[Any]]: ...
+) -> Any: ...
@overload
def einsum(
__subscripts: str,
@@ -95,7 +99,7 @@ def einsum(
order: _OrderKACF = ...,
casting: _CastingSafe = ...,
optimize: _OptimizeKind = ...,
-) -> _ArrayOrScalar[complexfloating[Any, Any]]: ...
+) -> Any: ...
@overload
def einsum(
__subscripts: str,
@@ -105,7 +109,7 @@ def einsum(
out: None = ...,
order: _OrderKACF = ...,
optimize: _OptimizeKind = ...,
-) -> _ArrayOrScalar[Any]: ...
+) -> Any: ...
@overload
def einsum(
__subscripts: str,