diff options
Diffstat (limited to 'numpy')
| -rw-r--r-- | numpy/core/src/multiarray/descriptor.c | 6 | ||||
| -rw-r--r-- | numpy/core/tests/test_dtype.py | 13 |
2 files changed, 16 insertions, 3 deletions
diff --git a/numpy/core/src/multiarray/descriptor.c b/numpy/core/src/multiarray/descriptor.c index a23ee6d2c..4955062c0 100644 --- a/numpy/core/src/multiarray/descriptor.c +++ b/numpy/core/src/multiarray/descriptor.c @@ -2164,6 +2164,7 @@ arraydescr_names_set( N = PyTuple_GET_SIZE(self->names); if (!PySequence_Check(val) || PyObject_Size((PyObject *)val) != N) { + /* Should be a TypeError, but this should be deprecated anyway. */ PyErr_Format(PyExc_ValueError, "must replace all names at once with a sequence of length %d", N); @@ -2172,16 +2173,17 @@ arraydescr_names_set( /* Make sure all entries are strings */ for (i = 0; i < N; i++) { PyObject *item; - int valid = 1; + int valid; item = PySequence_GetItem(val, i); valid = PyUnicode_Check(item); - Py_DECREF(item); if (!valid) { PyErr_Format(PyExc_ValueError, "item #%d of names is of type %s and not string", i, Py_TYPE(item)->tp_name); + Py_DECREF(item); return -1; } + Py_DECREF(item); } /* Invalidate cached hash value */ self->hash = -1; diff --git a/numpy/core/tests/test_dtype.py b/numpy/core/tests/test_dtype.py index 356b53df9..32e2c6842 100644 --- a/numpy/core/tests/test_dtype.py +++ b/numpy/core/tests/test_dtype.py @@ -223,7 +223,8 @@ class TestRecord: assert refcounts == refcounts_new def test_mutate(self): - # Mutating a dtype should reset the cached hash value + # Mutating a dtype should reset the cached hash value. + # NOTE: Mutating should be deprecated, but new API added to replace it. a = np.dtype([('yo', int)]) b = np.dtype([('yo', int)]) c = np.dtype([('ye', int)]) @@ -237,6 +238,16 @@ class TestRecord: assert_dtype_equal(a, b) assert_dtype_not_equal(a, c) + def test_mutate_error(self): + # NOTE: Mutating should be deprecated, but new API added to replace it. + a = np.dtype("i,i") + + with pytest.raises(ValueError, match="must replace all names at once"): + a.names = ["f0"] + + with pytest.raises(ValueError, match=".*and not string"): + a.names = ["f0", b"not a unicode name"] + def test_not_lists(self): """Test if an appropriate exception is raised when passing bad values to the dtype constructor. |
