summaryrefslogtreecommitdiff
path: root/numpy/core
diff options
context:
space:
mode:
authorSebastian Berg <sebastianb@nvidia.com>2023-04-26 14:26:59 +0200
committerSebastian Berg <sebastianb@nvidia.com>2023-04-26 14:57:52 +0200
commitca3df13ea111d08ac1b365040b45c13117510ded (patch)
treee998db7706f42ce7fbc6f6743511c39b0a186ea4 /numpy/core
parent9b62c3859f11094b664546e2f4a0fc92ed5c493c (diff)
downloadnumpy-ca3df13ea111d08ac1b365040b45c13117510ded.tar.gz
MAINT: Fixup handling of subarray dtype in ufunc.resolve_dtypes
This is now OK to just support, we won't replace things and things should work out for the most part (probably).
Diffstat (limited to 'numpy/core')
-rw-r--r--numpy/core/src/umath/ufunc_object.c6
-rw-r--r--numpy/core/tests/test_ufunc.py6
2 files changed, 3 insertions, 9 deletions
diff --git a/numpy/core/src/umath/ufunc_object.c b/numpy/core/src/umath/ufunc_object.c
index e13c4cd24..39e64decb 100644
--- a/numpy/core/src/umath/ufunc_object.c
+++ b/numpy/core/src/umath/ufunc_object.c
@@ -6611,12 +6611,6 @@ py_resolve_dtypes_generic(PyUFuncObject *ufunc, npy_bool return_context,
if (dummy_arrays[i] == NULL) {
goto finish;
}
- if (PyArray_DESCR(dummy_arrays[i]) != descr) {
- PyErr_SetString(PyExc_NotImplementedError,
- "dtype was replaced during array creation, the dtype is "
- "unsupported currently (a subarray dtype?).");
- goto finish;
- }
DTypes[i] = NPY_DTYPE(descr);
Py_INCREF(DTypes[i]);
if (!NPY_DT_is_legacy(DTypes[i])) {
diff --git a/numpy/core/tests/test_ufunc.py b/numpy/core/tests/test_ufunc.py
index 71af2ccb7..f716e2104 100644
--- a/numpy/core/tests/test_ufunc.py
+++ b/numpy/core/tests/test_ufunc.py
@@ -2865,10 +2865,10 @@ class TestLowlevelAPIAccess:
r = np.equal.resolve_dtypes((S0, S0, None))
assert r == (S0, S0, np.dtype(bool))
- # Subarray dtypes are weird and only really exist nested, they need
- # the shift to full NEP 50 to be implemented nicely:
+ # Subarray dtypes are weird and may not work fully, we preserve them
+ # leading to a TypeError (currently no equal loop for void/structured)
dts = np.dtype("10i")
- with pytest.raises(NotImplementedError):
+ with pytest.raises(TypeError):
np.equal.resolve_dtypes((dts, dts, None))
def test_resolve_dtypes_reduction(self):