diff options
author | Mark Wiebe <mwwiebe@gmail.com> | 2010-10-18 17:57:31 -0700 |
---|---|---|
committer | Pauli Virtanen <pav@iki.fi> | 2010-11-20 00:33:05 +0100 |
commit | ced0e77e34e3d46edd4e47178060b3abfc1f9221 (patch) | |
tree | df3756e092b55ef5cd8c8ebed582e81c26015502 | |
parent | 223f553725e41599a11313169642f064ad248059 (diff) | |
download | numpy-ced0e77e34e3d46edd4e47178060b3abfc1f9221.tar.gz |
ENH: core: allow record array comparison when sub-arrays are present
-rw-r--r-- | numpy/core/src/multiarray/arrayobject.c | 43 | ||||
-rw-r--r-- | numpy/core/tests/test_multiarray.py | 55 |
2 files changed, 97 insertions, 1 deletions
diff --git a/numpy/core/src/multiarray/arrayobject.c b/numpy/core/src/multiarray/arrayobject.c index 587e45d26..0b59c6b59 100644 --- a/numpy/core/src/multiarray/arrayobject.c +++ b/numpy/core/src/multiarray/arrayobject.c @@ -827,6 +827,8 @@ _void_compare(PyArrayObject *self, PyArrayObject *other, int cmp_op) PyObject *key, *value, *temp2; PyObject *op; Py_ssize_t pos = 0; + intp result_ndim = PyArray_NDIM(self) > PyArray_NDIM(other) ? + PyArray_NDIM(self) : PyArray_NDIM(other); op = (cmp_op == Py_EQ ? n_ops.logical_and : n_ops.logical_or); while (PyDict_Next(self->descr->fields, &pos, &key, &value)) { @@ -851,6 +853,45 @@ _void_compare(PyArrayObject *self, PyArrayObject *other, int cmp_op) Py_XDECREF(res); return NULL; } + + /* + * If the field type has a non-trivial shape, additional + * dimensions will have been appended to `a` and `b`. + * In that case, reduce them using `op`. + */ + if (PyArray_NDIM(temp) > result_ndim) { + /* If the type was multidimensional, collapse that part to 1-D + */ + if (PyArray_NDIM(temp) != result_ndim+1) { + intp dimensions[NPY_MAXDIMS]; + PyArray_Dims newdims; + + newdims.ptr = dimensions; + newdims.len = result_ndim+1; + memcpy(dimensions, PyArray_DIMS(temp), + sizeof(intp)*result_ndim); + dimensions[result_ndim] = -1; + temp2 = PyArray_Newshape(temp, &newdims, PyArray_ANYORDER); + if (temp2 == NULL) { + Py_DECREF(temp); + Py_XDECREF(res); + return NULL; + } + Py_DECREF(temp); + temp = temp2; + } + /* Reduce the extra dimension of `temp` using `op` */ + temp2 = PyArray_GenericReduceFunction(temp, op, result_ndim, + PyArray_BOOL, NULL); + if (temp2 == NULL) { + Py_DECREF(temp); + Py_XDECREF(res); + return NULL; + } + Py_DECREF(temp); + temp = temp2; + } + if (res == NULL) { res = temp; } @@ -1215,7 +1256,7 @@ array_new(PyTypeObject *subtype, PyObject *args, PyObject *kwds) } else if ((strides.ptr == NULL) && (buffer.len < (offset + (((intp)itemsize)* - PyArray_MultiplyList(dims.ptr, + PyArray_MultiplyList(dims.ptr, dims.len))))) { PyErr_SetString(PyExc_TypeError, "buffer is too small for " \ diff --git a/numpy/core/tests/test_multiarray.py b/numpy/core/tests/test_multiarray.py index 6bb8aced9..b3bf209a1 100644 --- a/numpy/core/tests/test_multiarray.py +++ b/numpy/core/tests/test_multiarray.py @@ -269,6 +269,61 @@ class TestStructured(TestCase): assert_equal(a['a'].shape, b['a'].shape) assert_equal(a.T['a'].shape, a.T.copy()['a'].shape) + + def test_subarray_comparison(self): + # Check that comparisons between record arrays with + # multi-dimensional field types work properly + a = np.rec.fromrecords( + [([1,2,3],'a', [[1,2],[3,4]]),([3,3,3],'b',[[0,0],[0,0]])], + dtype=[('a', ('f4',3)), ('b', np.object), ('c', ('i4',(2,2)))]) + b = a.copy() + assert_equal(a==b, [True,True]) + assert_equal(a!=b, [False,False]) + b[1].b = 'c' + assert_equal(a==b, [True,False]) + assert_equal(a!=b, [False,True]) + for i in range(3): + b[0].a = a[0].a + b[0].a[i] = 5 + assert_equal(a==b, [False,False]) + assert_equal(a!=b, [True,True]) + for i in range(2): + for j in range(2): + b = a.copy() + b[0].c[i,j] = 10 + assert_equal(a==b, [False,True]) + assert_equal(a!=b, [True,False]) + + # Check that broadcasting with a subarray works + a = np.array([[(0,)],[(1,)]],dtype=[('a','f8')]) + b = np.array([(0,),(0,),(1,)],dtype=[('a','f8')]) + assert_equal(a==b, [[True, True, False], [False, False, True]]) + assert_equal(b==a, [[True, True, False], [False, False, True]]) + a = np.array([[(0,)],[(1,)]],dtype=[('a','f8',(1,))]) + b = np.array([(0,),(0,),(1,)],dtype=[('a','f8',(1,))]) + assert_equal(a==b, [[True, True, False], [False, False, True]]) + assert_equal(b==a, [[True, True, False], [False, False, True]]) + a = np.array([[([0,0],)],[([1,1],)]],dtype=[('a','f8',(2,))]) + b = np.array([([0,0],),([0,1],),([1,1],)],dtype=[('a','f8',(2,))]) + assert_equal(a==b, [[True, False, False], [False, False, True]]) + assert_equal(b==a, [[True, False, False], [False, False, True]]) + + # Check that broadcasting Fortran-style arrays with a subarray work + a = np.array([[([0,0],)],[([1,1],)]],dtype=[('a','f8',(2,))], order='F') + b = np.array([([0,0],),([0,1],),([1,1],)],dtype=[('a','f8',(2,))]) + assert_equal(a==b, [[True, False, False], [False, False, True]]) + assert_equal(b==a, [[True, False, False], [False, False, True]]) + + # Check that incompatible sub-array shapes don't result to broadcasting + x = np.zeros((1,), dtype=[('a', ('f4', (1,2))), ('b', 'i1')]) + y = np.zeros((1,), dtype=[('a', ('f4', (2,))), ('b', 'i1')]) + assert_equal(x == y, False) + + x = np.zeros((1,), dtype=[('a', ('f4', (2,1))), ('b', 'i1')]) + y = np.zeros((1,), dtype=[('a', ('f4', (2,))), ('b', 'i1')]) + assert_equal(x == y, False) + + class TestBool(TestCase): def test_test_interning(self): a0 = bool_(0) |