summaryrefslogtreecommitdiff
path: root/numpy/core
diff options
context:
space:
mode:
Diffstat (limited to 'numpy/core')
-rw-r--r--numpy/core/src/multiarray/methods.c17
-rw-r--r--numpy/core/tests/test_multiarray.py16
2 files changed, 33 insertions, 0 deletions
diff --git a/numpy/core/src/multiarray/methods.c b/numpy/core/src/multiarray/methods.c
index cb63c7f74..cdbd0d6ae 100644
--- a/numpy/core/src/multiarray/methods.c
+++ b/numpy/core/src/multiarray/methods.c
@@ -356,6 +356,7 @@ PyArray_GetField(PyArrayObject *self, PyArray_Descr *typed, int offset)
PyObject *ret = NULL;
PyObject *safe;
static PyObject *checkfunc = NULL;
+ int self_elsize, typed_elsize;
/* check that we are not reinterpreting memory containing Objects. */
if (_may_have_objects(PyArray_DESCR(self)) || _may_have_objects(typed)) {
@@ -373,6 +374,22 @@ PyArray_GetField(PyArrayObject *self, PyArray_Descr *typed, int offset)
}
Py_DECREF(safe);
}
+ self_elsize = PyArray_ITEMSIZE(self);
+ typed_elsize = typed->elsize;
+
+ /* check that values are valid */
+ if (typed_elsize > self_elsize) {
+ PyErr_SetString(PyExc_ValueError, "new type is larger than original type");
+ return NULL;
+ }
+ if (offset < 0) {
+ PyErr_SetString(PyExc_ValueError, "offset is negative");
+ return NULL;
+ }
+ if (offset > self_elsize - typed_elsize) {
+ PyErr_SetString(PyExc_ValueError, "new type plus offset is larger than original type");
+ return NULL;
+ }
ret = PyArray_NewFromDescr_int(
Py_TYPE(self), typed,
diff --git a/numpy/core/tests/test_multiarray.py b/numpy/core/tests/test_multiarray.py
index 36d148f55..e8353a702 100644
--- a/numpy/core/tests/test_multiarray.py
+++ b/numpy/core/tests/test_multiarray.py
@@ -7662,3 +7662,19 @@ def test_uintalignment_and_alignment():
# check that copy code doesn't complain in debug mode
dst = np.zeros((2,2), dtype='c8')
dst[:,1] = src[:,1] # assert in lowlevel_strided_loops fails?
+
+def test_getfield():
+ a = np.arange(32, dtype='uint16')
+ if sys.byteorder == 'little':
+ i = 0
+ j = 1
+ else:
+ i = 1
+ j = 0
+ b = a.getfield('int8', i)
+ assert_equal(b, a)
+ b = a.getfield('int8', j)
+ assert_equal(b, 0)
+ pytest.raises(ValueError, a.getfield, 'uint8', -1)
+ pytest.raises(ValueError, a.getfield, 'uint8', 16)
+ pytest.raises(ValueError, a.getfield, 'uint64', 0)