summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorEric Wieser <wieser.eric@gmail.com>2017-03-23 12:23:26 +0000
committerEric Wieser <wieser.eric@gmail.com>2017-11-02 09:06:46 -0700
commit2bc15496aa2d1717aeb0d2be3c882eb1982215f8 (patch)
tree69797bd8c9f396c7ae5dc89a51d844262353a17d
parent103f23cfbf2fe89cb3d77cd1f730fef7ec9ec112 (diff)
downloadnumpy-2bc15496aa2d1717aeb0d2be3c882eb1982215f8.tar.gz
BUG: Passing an incorrect type to dtype.__getitem__ should raise TypeError
Also allows any object implementing __index__ to be passed to dtype.__getitem__
-rw-r--r--doc/release/1.14.0-notes.rst4
-rw-r--r--numpy/core/src/multiarray/descriptor.c15
-rw-r--r--numpy/core/tests/test_dtype.py6
-rw-r--r--numpy/core/tests/test_multiarray.py2
-rw-r--r--numpy/core/tests/test_regression.py2
5 files changed, 20 insertions, 9 deletions
diff --git a/doc/release/1.14.0-notes.rst b/doc/release/1.14.0-notes.rst
index 0aeeadd40..6feea0e76 100644
--- a/doc/release/1.14.0-notes.rst
+++ b/doc/release/1.14.0-notes.rst
@@ -177,6 +177,10 @@ The ufunc ``isnat`` used to raise a ``ValueError`` when it was not passed
variables of type ``datetime`` or ``timedelta``. This has been changed to
raising a ``TypeError``.
+``dtype.__getitem__`` raises ``TypeError`` when passed wrong type
+-----------------------------------------------------------------
+When indexed with a float, the dtype object used to raise ``ValueError``.
+
C API changes
=============
diff --git a/numpy/core/src/multiarray/descriptor.c b/numpy/core/src/multiarray/descriptor.c
index 49f086d21..b6f9a4103 100644
--- a/numpy/core/src/multiarray/descriptor.c
+++ b/numpy/core/src/multiarray/descriptor.c
@@ -3825,18 +3825,19 @@ descr_subscript(PyArray_Descr *self, PyObject *op)
#endif
return _subscript_by_name(self, op);
}
- else if (PyInt_Check(op)) {
+ else {
Py_ssize_t i = PyArray_PyIntAsIntp(op);
- if (PyErr_Occurred()) {
+ if (error_converting(i)) {
+ /* if converting to an int gives a type error, adjust the message */
+ PyObject *err = PyErr_Occurred();
+ if (PyErr_GivenExceptionMatches(err, PyExc_TypeError)) {
+ PyErr_SetString(PyExc_TypeError,
+ "Field key must be an integer, string, or unicode.");
+ }
return NULL;
}
return _subscript_by_index(self, i);
}
- else {
- PyErr_SetString(PyExc_ValueError,
- "Field key must be an integer, string, or unicode.");
- return NULL;
- }
}
static PySequenceMethods descr_as_sequence = {
diff --git a/numpy/core/tests/test_dtype.py b/numpy/core/tests/test_dtype.py
index 7f5ab2c9d..b48983e2e 100644
--- a/numpy/core/tests/test_dtype.py
+++ b/numpy/core/tests/test_dtype.py
@@ -2,6 +2,7 @@ from __future__ import division, absolute_import, print_function
import pickle
import sys
+import operator
import numpy as np
from numpy.core.test_rational import rational
@@ -306,6 +307,11 @@ class TestRecord(object):
assert_dtype_equal(dt[-2], dt[0])
assert_raises(IndexError, lambda: dt[-3])
+ assert_raises(TypeError, operator.getitem, dt, 3.0)
+ assert_raises(TypeError, operator.getitem, dt, [])
+
+ assert_equal(dt[1], dt[np.int8(1)])
+
class TestSubarray(object):
def test_single_subarray(self):
diff --git a/numpy/core/tests/test_multiarray.py b/numpy/core/tests/test_multiarray.py
index f6a5b4983..a02075a1e 100644
--- a/numpy/core/tests/test_multiarray.py
+++ b/numpy/core/tests/test_multiarray.py
@@ -4606,7 +4606,7 @@ class TestRecord(object):
assert_raises(TypeError, np.dtype, [(('b', b'a'), int)])
dt = np.dtype([((b'a', 'b'), int)])
- assert_raises(ValueError, dt.__getitem__, b'a')
+ assert_raises(TypeError, dt.__getitem__, b'a')
x = np.array([(1,), (2,), (3,)], dtype=dt)
assert_raises(IndexError, x.__getitem__, b'a')
diff --git a/numpy/core/tests/test_regression.py b/numpy/core/tests/test_regression.py
index f791f6725..aac752037 100644
--- a/numpy/core/tests/test_regression.py
+++ b/numpy/core/tests/test_regression.py
@@ -1352,7 +1352,7 @@ class TestRegression(object):
dt = np.dtype([('f1', np.uint)])
assert_raises(KeyError, dt.__getitem__, "f2")
assert_raises(IndexError, dt.__getitem__, 1)
- assert_raises(ValueError, dt.__getitem__, 0.0)
+ assert_raises(TypeError, dt.__getitem__, 0.0)
def test_lexsort_buffer_length(self):
# Ticket #1217, don't segfault.