diff options
Diffstat (limited to 'numpy/core')
-rw-r--r-- | numpy/core/src/multiarray/methods.c | 17 | ||||
-rw-r--r-- | numpy/core/tests/test_multiarray.py | 16 |
2 files changed, 33 insertions, 0 deletions
diff --git a/numpy/core/src/multiarray/methods.c b/numpy/core/src/multiarray/methods.c index cb63c7f74..cdbd0d6ae 100644 --- a/numpy/core/src/multiarray/methods.c +++ b/numpy/core/src/multiarray/methods.c @@ -356,6 +356,7 @@ PyArray_GetField(PyArrayObject *self, PyArray_Descr *typed, int offset) PyObject *ret = NULL; PyObject *safe; static PyObject *checkfunc = NULL; + int self_elsize, typed_elsize; /* check that we are not reinterpreting memory containing Objects. */ if (_may_have_objects(PyArray_DESCR(self)) || _may_have_objects(typed)) { @@ -373,6 +374,22 @@ PyArray_GetField(PyArrayObject *self, PyArray_Descr *typed, int offset) } Py_DECREF(safe); } + self_elsize = PyArray_ITEMSIZE(self); + typed_elsize = typed->elsize; + + /* check that values are valid */ + if (typed_elsize > self_elsize) { + PyErr_SetString(PyExc_ValueError, "new type is larger than original type"); + return NULL; + } + if (offset < 0) { + PyErr_SetString(PyExc_ValueError, "offset is negative"); + return NULL; + } + if (offset > self_elsize - typed_elsize) { + PyErr_SetString(PyExc_ValueError, "new type plus offset is larger than original type"); + return NULL; + } ret = PyArray_NewFromDescr_int( Py_TYPE(self), typed, diff --git a/numpy/core/tests/test_multiarray.py b/numpy/core/tests/test_multiarray.py index 36d148f55..e8353a702 100644 --- a/numpy/core/tests/test_multiarray.py +++ b/numpy/core/tests/test_multiarray.py @@ -7662,3 +7662,19 @@ def test_uintalignment_and_alignment(): # check that copy code doesn't complain in debug mode dst = np.zeros((2,2), dtype='c8') dst[:,1] = src[:,1] # assert in lowlevel_strided_loops fails? + +def test_getfield(): + a = np.arange(32, dtype='uint16') + if sys.byteorder == 'little': + i = 0 + j = 1 + else: + i = 1 + j = 0 + b = a.getfield('int8', i) + assert_equal(b, a) + b = a.getfield('int8', j) + assert_equal(b, 0) + pytest.raises(ValueError, a.getfield, 'uint8', -1) + pytest.raises(ValueError, a.getfield, 'uint8', 16) + pytest.raises(ValueError, a.getfield, 'uint64', 0) |