summaryrefslogtreecommitdiff
path: root/numpy
diff options
context:
space:
mode:
Diffstat (limited to 'numpy')
-rw-r--r--numpy/core/src/multiarray/scalartypes.c.src153
-rw-r--r--numpy/core/tests/test_multiarray.py23
-rw-r--r--numpy/core/tests/test_records.py6
3 files changed, 107 insertions, 75 deletions
diff --git a/numpy/core/src/multiarray/scalartypes.c.src b/numpy/core/src/multiarray/scalartypes.c.src
index b91a644cd..478c5bb78 100644
--- a/numpy/core/src/multiarray/scalartypes.c.src
+++ b/numpy/core/src/multiarray/scalartypes.c.src
@@ -1651,33 +1651,13 @@ gentype_@name@(PyObject *self, PyObject *args, PyObject *kwds)
static PyObject *
voidtype_getfield(PyVoidScalarObject *self, PyObject *args, PyObject *kwds)
{
- PyObject *ret, *newargs;
-
- newargs = PyTuple_GetSlice(args, 0, 2);
- if (newargs == NULL) {
- return NULL;
- }
- ret = gentype_generic_method((PyObject *)self, newargs, kwds, "getfield");
- Py_DECREF(newargs);
- if (!ret) {
- return ret;
- }
- if (PyArray_IsScalar(ret, Generic) && \
- (!PyArray_IsScalar(ret, Void))) {
- PyArray_Descr *new;
- void *ptr;
- if (!PyArray_ISNBO(self->descr->byteorder)) {
- new = PyArray_DescrFromScalar(ret);
- ptr = scalar_value(ret, new);
- byte_swap_vector(ptr, 1, new->elsize);
- Py_DECREF(new);
- }
- }
- return ret;
+ /* Use ndarray's getfield to obtain the field safely */
+ return gentype_generic_method((PyObject *)self, args, kwds, "getfield");
}
static PyObject *
-gentype_setfield(PyObject *NPY_UNUSED(self), PyObject *NPY_UNUSED(args), PyObject *NPY_UNUSED(kwds))
+gentype_setfield(PyObject *NPY_UNUSED(self), PyObject *NPY_UNUSED(args),
+ PyObject *NPY_UNUSED(kwds))
{
PyErr_SetString(PyExc_TypeError,
"Can't set fields in a non-void array scalar.");
@@ -1687,59 +1667,75 @@ gentype_setfield(PyObject *NPY_UNUSED(self), PyObject *NPY_UNUSED(args), PyObjec
static PyObject *
voidtype_setfield(PyVoidScalarObject *self, PyObject *args, PyObject *kwds)
{
- PyArray_Descr *typecode = NULL;
- int offset = 0;
- PyObject *value;
- PyArrayObject *src;
- int mysize;
- char *dptr;
- static char *kwlist[] = {"value", "dtype", "offset", 0};
-
- if ((self->flags & NPY_ARRAY_WRITEABLE) != NPY_ARRAY_WRITEABLE) {
- PyErr_SetString(PyExc_RuntimeError, "Can't write to memory");
+ /*
+ * We would like to use ndarray's setfield because it performs safety
+ * checks on the field datatypes and because it broadcasts properly.
+ * However, as a special case, void-scalar assignment broadcasts
+ * differently from ndarrays when assigning to an object field: Assignment
+ * to an ndarray object field broadcasts, but assignment to a void-scalar
+ * object-field should not, in order to allow nested ndarrays.
+ * These lines should then behave identically:
+ *
+ * b = np.zeros(1, dtype=[('x', 'O')])
+ * b[0]['x'] = arange(3) # uses voidtype_setfield
+ * b['x'][0] = arange(3) # uses ndarray setitem
+ *
+ * Ndarray's setfield would try to broadcast the lhs. Instead we use
+ * ndarray getfield to get the field safely, then setitem to set the value
+ * without broadcast. Note we also want subarrays to be set properly, ie
+ *
+ * a = np.zeros(1, dtype=[('x', 'i', 5)])
+ * a[0]['x'] = 1
+ *
+ * sets all values to 1. Setitem does this.
+ */
+ PyObject *getfield_args, *value, *arr, *meth, *arr_field, *emptytuple;
+
+ value = PyTuple_GetItem(args, 0);
+ if (value == NULL) {
return NULL;
}
- if (!PyArg_ParseTupleAndKeywords(args, kwds, "OO&|i", kwlist,
- &value,
- PyArray_DescrConverter,
- &typecode, &offset)) {
- Py_XDECREF(typecode);
+ getfield_args = PyTuple_GetSlice(args, 1, 3);
+ if (getfield_args == NULL) {
return NULL;
}
- mysize = Py_SIZE(self);
-
- if (offset < 0 || (offset + typecode->elsize) > mysize) {
- PyErr_Format(PyExc_ValueError,
- "Need 0 <= offset <= %d for requested type " \
- "but received offset = %d",
- mysize-typecode->elsize, offset);
- Py_DECREF(typecode);
+ /* 1. Convert to 0-d array and use getfield */
+ arr = PyArray_FromScalar(self, NULL);
+ if (arr == NULL) {
+ Py_DECREF(getfield_args);
return NULL;
}
-
- dptr = self->obval + offset;
-
- if (typecode->type_num == NPY_OBJECT) {
- PyObject *temp;
- Py_INCREF(value);
- NPY_COPY_PYOBJECT_PTR(&temp, dptr);
- Py_XDECREF(temp);
- NPY_COPY_PYOBJECT_PTR(dptr, &value);
- Py_DECREF(typecode);
+ meth = PyObject_GetAttrString(arr, "getfield");
+ if (meth == NULL) {
+ Py_DECREF(getfield_args);
+ Py_DECREF(arr);
+ return NULL;
+ }
+ if (kwds == NULL) {
+ arr_field = PyObject_CallObject(meth, getfield_args);
}
else {
- /* Copy data from value to correct place in dptr */
- src = (PyArrayObject *)PyArray_FromAny(value, typecode,
- 0, 0, NPY_ARRAY_CARRAY, NULL);
- if (src == NULL) {
- return NULL;
- }
- typecode->f->copyswap(dptr, PyArray_DATA(src),
- !PyArray_ISNBO(self->descr->byteorder),
- src);
- Py_DECREF(src);
+ arr_field = PyObject_Call(meth, getfield_args, kwds);
}
+ Py_DECREF(getfield_args);
+ Py_DECREF(meth);
+ Py_DECREF(arr);
+
+ if(arr_field == NULL){
+ return NULL;
+ }
+
+ /* 2. Fill the resulting array using setitem */
+ emptytuple = PyTuple_New(0);
+ if (PyObject_SetItem(arr_field, emptytuple, value) < 0) {
+ Py_DECREF(arr_field);
+ Py_DECREF(emptytuple);
+ return NULL;
+ }
+ Py_DECREF(arr_field);
+ Py_DECREF(emptytuple);
+
Py_RETURN_NONE;
}
@@ -2165,7 +2161,7 @@ static PyObject *
voidtype_item(PyVoidScalarObject *self, Py_ssize_t n)
{
npy_intp m;
- PyObject *flist=NULL, *fieldinfo;
+ PyObject *flist=NULL, *fieldind, *fieldparam, *fieldinfo, *ret;
if (!(PyDataType_HASFIELDS(self->descr))) {
PyErr_SetString(PyExc_IndexError,
@@ -2181,9 +2177,13 @@ voidtype_item(PyVoidScalarObject *self, Py_ssize_t n)
PyErr_Format(PyExc_IndexError, "invalid index (%d)", (int) n);
return NULL;
}
- fieldinfo = PyDict_GetItem(self->descr->fields,
- PyTuple_GET_ITEM(flist, n));
- return voidtype_getfield(self, fieldinfo, NULL);
+ /* no error checking needed: descr->names is well structured */
+ fieldind = PyTuple_GET_ITEM(flist, n);
+ fieldparam = PyDict_GetItem(self->descr->fields, fieldind);
+ fieldinfo = PyTuple_GetSlice(fieldparam, 0, 2);
+ ret = voidtype_getfield(self, fieldinfo, NULL);
+ Py_DECREF(fieldinfo);
+ return ret;
}
@@ -2192,7 +2192,7 @@ static PyObject *
voidtype_subscript(PyVoidScalarObject *self, PyObject *ind)
{
npy_intp n;
- PyObject *fieldinfo;
+ PyObject *ret, *fieldinfo, *fieldparam;
if (!(PyDataType_HASFIELDS(self->descr))) {
PyErr_SetString(PyExc_IndexError,
@@ -2206,11 +2206,14 @@ voidtype_subscript(PyVoidScalarObject *self, PyObject *ind)
if (PyBytes_Check(ind) || PyUnicode_Check(ind)) {
#endif
/* look up in fields */
- fieldinfo = PyDict_GetItem(self->descr->fields, ind);
- if (!fieldinfo) {
+ fieldparam = PyDict_GetItem(self->descr->fields, ind);
+ if (!fieldparam) {
goto fail;
}
- return voidtype_getfield(self, fieldinfo, NULL);
+ fieldinfo = PyTuple_GetSlice(fieldparam, 0, 2);
+ ret = voidtype_getfield(self, fieldinfo, NULL);
+ Py_DECREF(fieldinfo);
+ return ret;
}
/* try to convert it to a number */
diff --git a/numpy/core/tests/test_multiarray.py b/numpy/core/tests/test_multiarray.py
index 74c57e18a..d392d575a 100644
--- a/numpy/core/tests/test_multiarray.py
+++ b/numpy/core/tests/test_multiarray.py
@@ -818,6 +818,29 @@ class TestStructured(TestCase):
dat2 = np.zeros(3, [('A', 'i'), ('B', '|O')])
new2 = dat2[['B', 'A']] # TypeError?
+ def test_setfield(self):
+ # https://github.com/numpy/numpy/issues/3126
+ struct_dt = np.dtype([('elem', 'i4', 5),])
+ dt = np.dtype([('field', 'i4', 10),('struct', struct_dt)])
+ x = np.zeros(1, dt)
+ x[0]['field'] = np.ones(10, dtype='i4')
+ x[0]['struct'] = np.ones(1, dtype=struct_dt)
+ assert_equal(x[0]['field'], np.ones(10, dtype='i4'))
+
+ def test_setfield_object(self):
+ # make sure object field assignment with ndarray value
+ # on void scalar mimics setitem behavior
+ b = np.zeros(1, dtype=[('x', 'O')])
+ # next line should work identically to b['x'][0] = np.arange(3)
+ b[0]['x'] = np.arange(3)
+ assert_equal(b[0]['x'], np.arange(3))
+
+ #check that broadcasting check still works
+ c = np.zeros(1, dtype=[('x', 'O', 5)])
+ def testassign():
+ c[0]['x'] = np.arange(3)
+ assert_raises(ValueError, testassign)
+
class TestBool(TestCase):
def test_test_interning(self):
a0 = bool_(0)
diff --git a/numpy/core/tests/test_records.py b/numpy/core/tests/test_records.py
index b07b3c876..52cfd1868 100644
--- a/numpy/core/tests/test_records.py
+++ b/numpy/core/tests/test_records.py
@@ -229,6 +229,12 @@ class TestRecord(TestCase):
ra = np.recarray((2,), dtype=[('x', object), ('y', float), ('z', int)])
ra[['x','y']] #TypeError?
+ def test_record_scalar_setitem(self):
+ # https://github.com/numpy/numpy/issues/3561
+ rec = np.recarray(1, dtype=[('x', float, 5)])
+ rec[0].x = 1
+ assert_equal(rec[0].x, np.ones(5))
+
def test_find_duplicate():
l1 = [1, 2, 3, 4, 5, 6]
assert_(np.rec.find_duplicate(l1) == [])