diff options
Diffstat (limited to 'numpy')
-rw-r--r-- | numpy/core/src/multiarray/scalartypes.c.src | 153 | ||||
-rw-r--r-- | numpy/core/tests/test_multiarray.py | 23 | ||||
-rw-r--r-- | numpy/core/tests/test_records.py | 6 |
3 files changed, 107 insertions, 75 deletions
diff --git a/numpy/core/src/multiarray/scalartypes.c.src b/numpy/core/src/multiarray/scalartypes.c.src index b91a644cd..478c5bb78 100644 --- a/numpy/core/src/multiarray/scalartypes.c.src +++ b/numpy/core/src/multiarray/scalartypes.c.src @@ -1651,33 +1651,13 @@ gentype_@name@(PyObject *self, PyObject *args, PyObject *kwds) static PyObject * voidtype_getfield(PyVoidScalarObject *self, PyObject *args, PyObject *kwds) { - PyObject *ret, *newargs; - - newargs = PyTuple_GetSlice(args, 0, 2); - if (newargs == NULL) { - return NULL; - } - ret = gentype_generic_method((PyObject *)self, newargs, kwds, "getfield"); - Py_DECREF(newargs); - if (!ret) { - return ret; - } - if (PyArray_IsScalar(ret, Generic) && \ - (!PyArray_IsScalar(ret, Void))) { - PyArray_Descr *new; - void *ptr; - if (!PyArray_ISNBO(self->descr->byteorder)) { - new = PyArray_DescrFromScalar(ret); - ptr = scalar_value(ret, new); - byte_swap_vector(ptr, 1, new->elsize); - Py_DECREF(new); - } - } - return ret; + /* Use ndarray's getfield to obtain the field safely */ + return gentype_generic_method((PyObject *)self, args, kwds, "getfield"); } static PyObject * -gentype_setfield(PyObject *NPY_UNUSED(self), PyObject *NPY_UNUSED(args), PyObject *NPY_UNUSED(kwds)) +gentype_setfield(PyObject *NPY_UNUSED(self), PyObject *NPY_UNUSED(args), + PyObject *NPY_UNUSED(kwds)) { PyErr_SetString(PyExc_TypeError, "Can't set fields in a non-void array scalar."); @@ -1687,59 +1667,75 @@ gentype_setfield(PyObject *NPY_UNUSED(self), PyObject *NPY_UNUSED(args), PyObjec static PyObject * voidtype_setfield(PyVoidScalarObject *self, PyObject *args, PyObject *kwds) { - PyArray_Descr *typecode = NULL; - int offset = 0; - PyObject *value; - PyArrayObject *src; - int mysize; - char *dptr; - static char *kwlist[] = {"value", "dtype", "offset", 0}; - - if ((self->flags & NPY_ARRAY_WRITEABLE) != NPY_ARRAY_WRITEABLE) { - PyErr_SetString(PyExc_RuntimeError, "Can't write to memory"); + /* + * We would like to use ndarray's setfield because it performs safety + * checks on the field datatypes and because it broadcasts properly. + * However, as a special case, void-scalar assignment broadcasts + * differently from ndarrays when assigning to an object field: Assignment + * to an ndarray object field broadcasts, but assignment to a void-scalar + * object-field should not, in order to allow nested ndarrays. + * These lines should then behave identically: + * + * b = np.zeros(1, dtype=[('x', 'O')]) + * b[0]['x'] = arange(3) # uses voidtype_setfield + * b['x'][0] = arange(3) # uses ndarray setitem + * + * Ndarray's setfield would try to broadcast the lhs. Instead we use + * ndarray getfield to get the field safely, then setitem to set the value + * without broadcast. Note we also want subarrays to be set properly, ie + * + * a = np.zeros(1, dtype=[('x', 'i', 5)]) + * a[0]['x'] = 1 + * + * sets all values to 1. Setitem does this. + */ + PyObject *getfield_args, *value, *arr, *meth, *arr_field, *emptytuple; + + value = PyTuple_GetItem(args, 0); + if (value == NULL) { return NULL; } - if (!PyArg_ParseTupleAndKeywords(args, kwds, "OO&|i", kwlist, - &value, - PyArray_DescrConverter, - &typecode, &offset)) { - Py_XDECREF(typecode); + getfield_args = PyTuple_GetSlice(args, 1, 3); + if (getfield_args == NULL) { return NULL; } - mysize = Py_SIZE(self); - - if (offset < 0 || (offset + typecode->elsize) > mysize) { - PyErr_Format(PyExc_ValueError, - "Need 0 <= offset <= %d for requested type " \ - "but received offset = %d", - mysize-typecode->elsize, offset); - Py_DECREF(typecode); + /* 1. Convert to 0-d array and use getfield */ + arr = PyArray_FromScalar(self, NULL); + if (arr == NULL) { + Py_DECREF(getfield_args); return NULL; } - - dptr = self->obval + offset; - - if (typecode->type_num == NPY_OBJECT) { - PyObject *temp; - Py_INCREF(value); - NPY_COPY_PYOBJECT_PTR(&temp, dptr); - Py_XDECREF(temp); - NPY_COPY_PYOBJECT_PTR(dptr, &value); - Py_DECREF(typecode); + meth = PyObject_GetAttrString(arr, "getfield"); + if (meth == NULL) { + Py_DECREF(getfield_args); + Py_DECREF(arr); + return NULL; + } + if (kwds == NULL) { + arr_field = PyObject_CallObject(meth, getfield_args); } else { - /* Copy data from value to correct place in dptr */ - src = (PyArrayObject *)PyArray_FromAny(value, typecode, - 0, 0, NPY_ARRAY_CARRAY, NULL); - if (src == NULL) { - return NULL; - } - typecode->f->copyswap(dptr, PyArray_DATA(src), - !PyArray_ISNBO(self->descr->byteorder), - src); - Py_DECREF(src); + arr_field = PyObject_Call(meth, getfield_args, kwds); } + Py_DECREF(getfield_args); + Py_DECREF(meth); + Py_DECREF(arr); + + if(arr_field == NULL){ + return NULL; + } + + /* 2. Fill the resulting array using setitem */ + emptytuple = PyTuple_New(0); + if (PyObject_SetItem(arr_field, emptytuple, value) < 0) { + Py_DECREF(arr_field); + Py_DECREF(emptytuple); + return NULL; + } + Py_DECREF(arr_field); + Py_DECREF(emptytuple); + Py_RETURN_NONE; } @@ -2165,7 +2161,7 @@ static PyObject * voidtype_item(PyVoidScalarObject *self, Py_ssize_t n) { npy_intp m; - PyObject *flist=NULL, *fieldinfo; + PyObject *flist=NULL, *fieldind, *fieldparam, *fieldinfo, *ret; if (!(PyDataType_HASFIELDS(self->descr))) { PyErr_SetString(PyExc_IndexError, @@ -2181,9 +2177,13 @@ voidtype_item(PyVoidScalarObject *self, Py_ssize_t n) PyErr_Format(PyExc_IndexError, "invalid index (%d)", (int) n); return NULL; } - fieldinfo = PyDict_GetItem(self->descr->fields, - PyTuple_GET_ITEM(flist, n)); - return voidtype_getfield(self, fieldinfo, NULL); + /* no error checking needed: descr->names is well structured */ + fieldind = PyTuple_GET_ITEM(flist, n); + fieldparam = PyDict_GetItem(self->descr->fields, fieldind); + fieldinfo = PyTuple_GetSlice(fieldparam, 0, 2); + ret = voidtype_getfield(self, fieldinfo, NULL); + Py_DECREF(fieldinfo); + return ret; } @@ -2192,7 +2192,7 @@ static PyObject * voidtype_subscript(PyVoidScalarObject *self, PyObject *ind) { npy_intp n; - PyObject *fieldinfo; + PyObject *ret, *fieldinfo, *fieldparam; if (!(PyDataType_HASFIELDS(self->descr))) { PyErr_SetString(PyExc_IndexError, @@ -2206,11 +2206,14 @@ voidtype_subscript(PyVoidScalarObject *self, PyObject *ind) if (PyBytes_Check(ind) || PyUnicode_Check(ind)) { #endif /* look up in fields */ - fieldinfo = PyDict_GetItem(self->descr->fields, ind); - if (!fieldinfo) { + fieldparam = PyDict_GetItem(self->descr->fields, ind); + if (!fieldparam) { goto fail; } - return voidtype_getfield(self, fieldinfo, NULL); + fieldinfo = PyTuple_GetSlice(fieldparam, 0, 2); + ret = voidtype_getfield(self, fieldinfo, NULL); + Py_DECREF(fieldinfo); + return ret; } /* try to convert it to a number */ diff --git a/numpy/core/tests/test_multiarray.py b/numpy/core/tests/test_multiarray.py index 74c57e18a..d392d575a 100644 --- a/numpy/core/tests/test_multiarray.py +++ b/numpy/core/tests/test_multiarray.py @@ -818,6 +818,29 @@ class TestStructured(TestCase): dat2 = np.zeros(3, [('A', 'i'), ('B', '|O')]) new2 = dat2[['B', 'A']] # TypeError? + def test_setfield(self): + # https://github.com/numpy/numpy/issues/3126 + struct_dt = np.dtype([('elem', 'i4', 5),]) + dt = np.dtype([('field', 'i4', 10),('struct', struct_dt)]) + x = np.zeros(1, dt) + x[0]['field'] = np.ones(10, dtype='i4') + x[0]['struct'] = np.ones(1, dtype=struct_dt) + assert_equal(x[0]['field'], np.ones(10, dtype='i4')) + + def test_setfield_object(self): + # make sure object field assignment with ndarray value + # on void scalar mimics setitem behavior + b = np.zeros(1, dtype=[('x', 'O')]) + # next line should work identically to b['x'][0] = np.arange(3) + b[0]['x'] = np.arange(3) + assert_equal(b[0]['x'], np.arange(3)) + + #check that broadcasting check still works + c = np.zeros(1, dtype=[('x', 'O', 5)]) + def testassign(): + c[0]['x'] = np.arange(3) + assert_raises(ValueError, testassign) + class TestBool(TestCase): def test_test_interning(self): a0 = bool_(0) diff --git a/numpy/core/tests/test_records.py b/numpy/core/tests/test_records.py index b07b3c876..52cfd1868 100644 --- a/numpy/core/tests/test_records.py +++ b/numpy/core/tests/test_records.py @@ -229,6 +229,12 @@ class TestRecord(TestCase): ra = np.recarray((2,), dtype=[('x', object), ('y', float), ('z', int)]) ra[['x','y']] #TypeError? + def test_record_scalar_setitem(self): + # https://github.com/numpy/numpy/issues/3561 + rec = np.recarray(1, dtype=[('x', float, 5)]) + rec[0].x = 1 + assert_equal(rec[0].x, np.ones(5)) + def test_find_duplicate(): l1 = [1, 2, 3, 4, 5, 6] assert_(np.rec.find_duplicate(l1) == []) |