diff options
| author | Sebastian Berg <sebastianb@nvidia.com> | 2023-04-26 14:26:59 +0200 |
|---|---|---|
| committer | Sebastian Berg <sebastianb@nvidia.com> | 2023-04-26 14:57:52 +0200 |
| commit | ca3df13ea111d08ac1b365040b45c13117510ded (patch) | |
| tree | e998db7706f42ce7fbc6f6743511c39b0a186ea4 /numpy/core | |
| parent | 9b62c3859f11094b664546e2f4a0fc92ed5c493c (diff) | |
| download | numpy-ca3df13ea111d08ac1b365040b45c13117510ded.tar.gz | |
MAINT: Fixup handling of subarray dtype in ufunc.resolve_dtypes
This is now OK to just support, we won't replace things and things
should work out for the most part (probably).
Diffstat (limited to 'numpy/core')
| -rw-r--r-- | numpy/core/src/umath/ufunc_object.c | 6 | ||||
| -rw-r--r-- | numpy/core/tests/test_ufunc.py | 6 |
2 files changed, 3 insertions, 9 deletions
diff --git a/numpy/core/src/umath/ufunc_object.c b/numpy/core/src/umath/ufunc_object.c index e13c4cd24..39e64decb 100644 --- a/numpy/core/src/umath/ufunc_object.c +++ b/numpy/core/src/umath/ufunc_object.c @@ -6611,12 +6611,6 @@ py_resolve_dtypes_generic(PyUFuncObject *ufunc, npy_bool return_context, if (dummy_arrays[i] == NULL) { goto finish; } - if (PyArray_DESCR(dummy_arrays[i]) != descr) { - PyErr_SetString(PyExc_NotImplementedError, - "dtype was replaced during array creation, the dtype is " - "unsupported currently (a subarray dtype?)."); - goto finish; - } DTypes[i] = NPY_DTYPE(descr); Py_INCREF(DTypes[i]); if (!NPY_DT_is_legacy(DTypes[i])) { diff --git a/numpy/core/tests/test_ufunc.py b/numpy/core/tests/test_ufunc.py index 71af2ccb7..f716e2104 100644 --- a/numpy/core/tests/test_ufunc.py +++ b/numpy/core/tests/test_ufunc.py @@ -2865,10 +2865,10 @@ class TestLowlevelAPIAccess: r = np.equal.resolve_dtypes((S0, S0, None)) assert r == (S0, S0, np.dtype(bool)) - # Subarray dtypes are weird and only really exist nested, they need - # the shift to full NEP 50 to be implemented nicely: + # Subarray dtypes are weird and may not work fully, we preserve them + # leading to a TypeError (currently no equal loop for void/structured) dts = np.dtype("10i") - with pytest.raises(NotImplementedError): + with pytest.raises(TypeError): np.equal.resolve_dtypes((dts, dts, None)) def test_resolve_dtypes_reduction(self): |
