summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMark Wiebe <mwwiebe@gmail.com>2010-10-18 17:57:31 -0700
committerPauli Virtanen <pav@iki.fi>2010-11-20 00:33:05 +0100
commitced0e77e34e3d46edd4e47178060b3abfc1f9221 (patch)
treedf3756e092b55ef5cd8c8ebed582e81c26015502
parent223f553725e41599a11313169642f064ad248059 (diff)
downloadnumpy-ced0e77e34e3d46edd4e47178060b3abfc1f9221.tar.gz
ENH: core: allow record array comparison when sub-arrays are present
-rw-r--r--numpy/core/src/multiarray/arrayobject.c43
-rw-r--r--numpy/core/tests/test_multiarray.py55
2 files changed, 97 insertions, 1 deletions
diff --git a/numpy/core/src/multiarray/arrayobject.c b/numpy/core/src/multiarray/arrayobject.c
index 587e45d26..0b59c6b59 100644
--- a/numpy/core/src/multiarray/arrayobject.c
+++ b/numpy/core/src/multiarray/arrayobject.c
@@ -827,6 +827,8 @@ _void_compare(PyArrayObject *self, PyArrayObject *other, int cmp_op)
PyObject *key, *value, *temp2;
PyObject *op;
Py_ssize_t pos = 0;
+ intp result_ndim = PyArray_NDIM(self) > PyArray_NDIM(other) ?
+ PyArray_NDIM(self) : PyArray_NDIM(other);
op = (cmp_op == Py_EQ ? n_ops.logical_and : n_ops.logical_or);
while (PyDict_Next(self->descr->fields, &pos, &key, &value)) {
@@ -851,6 +853,45 @@ _void_compare(PyArrayObject *self, PyArrayObject *other, int cmp_op)
Py_XDECREF(res);
return NULL;
}
+
+ /*
+ * If the field type has a non-trivial shape, additional
+ * dimensions will have been appended to `a` and `b`.
+ * In that case, reduce them using `op`.
+ */
+ if (PyArray_NDIM(temp) > result_ndim) {
+ /* If the type was multidimensional, collapse that part to 1-D
+ */
+ if (PyArray_NDIM(temp) != result_ndim+1) {
+ intp dimensions[NPY_MAXDIMS];
+ PyArray_Dims newdims;
+
+ newdims.ptr = dimensions;
+ newdims.len = result_ndim+1;
+ memcpy(dimensions, PyArray_DIMS(temp),
+ sizeof(intp)*result_ndim);
+ dimensions[result_ndim] = -1;
+ temp2 = PyArray_Newshape(temp, &newdims, PyArray_ANYORDER);
+ if (temp2 == NULL) {
+ Py_DECREF(temp);
+ Py_XDECREF(res);
+ return NULL;
+ }
+ Py_DECREF(temp);
+ temp = temp2;
+ }
+ /* Reduce the extra dimension of `temp` using `op` */
+ temp2 = PyArray_GenericReduceFunction(temp, op, result_ndim,
+ PyArray_BOOL, NULL);
+ if (temp2 == NULL) {
+ Py_DECREF(temp);
+ Py_XDECREF(res);
+ return NULL;
+ }
+ Py_DECREF(temp);
+ temp = temp2;
+ }
+
if (res == NULL) {
res = temp;
}
@@ -1215,7 +1256,7 @@ array_new(PyTypeObject *subtype, PyObject *args, PyObject *kwds)
}
else if ((strides.ptr == NULL) &&
(buffer.len < (offset + (((intp)itemsize)*
- PyArray_MultiplyList(dims.ptr,
+ PyArray_MultiplyList(dims.ptr,
dims.len))))) {
PyErr_SetString(PyExc_TypeError,
"buffer is too small for " \
diff --git a/numpy/core/tests/test_multiarray.py b/numpy/core/tests/test_multiarray.py
index 6bb8aced9..b3bf209a1 100644
--- a/numpy/core/tests/test_multiarray.py
+++ b/numpy/core/tests/test_multiarray.py
@@ -269,6 +269,61 @@ class TestStructured(TestCase):
assert_equal(a['a'].shape, b['a'].shape)
assert_equal(a.T['a'].shape, a.T.copy()['a'].shape)
+
+ def test_subarray_comparison(self):
+ # Check that comparisons between record arrays with
+ # multi-dimensional field types work properly
+ a = np.rec.fromrecords(
+ [([1,2,3],'a', [[1,2],[3,4]]),([3,3,3],'b',[[0,0],[0,0]])],
+ dtype=[('a', ('f4',3)), ('b', np.object), ('c', ('i4',(2,2)))])
+ b = a.copy()
+ assert_equal(a==b, [True,True])
+ assert_equal(a!=b, [False,False])
+ b[1].b = 'c'
+ assert_equal(a==b, [True,False])
+ assert_equal(a!=b, [False,True])
+ for i in range(3):
+ b[0].a = a[0].a
+ b[0].a[i] = 5
+ assert_equal(a==b, [False,False])
+ assert_equal(a!=b, [True,True])
+ for i in range(2):
+ for j in range(2):
+ b = a.copy()
+ b[0].c[i,j] = 10
+ assert_equal(a==b, [False,True])
+ assert_equal(a!=b, [True,False])
+
+ # Check that broadcasting with a subarray works
+ a = np.array([[(0,)],[(1,)]],dtype=[('a','f8')])
+ b = np.array([(0,),(0,),(1,)],dtype=[('a','f8')])
+ assert_equal(a==b, [[True, True, False], [False, False, True]])
+ assert_equal(b==a, [[True, True, False], [False, False, True]])
+ a = np.array([[(0,)],[(1,)]],dtype=[('a','f8',(1,))])
+ b = np.array([(0,),(0,),(1,)],dtype=[('a','f8',(1,))])
+ assert_equal(a==b, [[True, True, False], [False, False, True]])
+ assert_equal(b==a, [[True, True, False], [False, False, True]])
+ a = np.array([[([0,0],)],[([1,1],)]],dtype=[('a','f8',(2,))])
+ b = np.array([([0,0],),([0,1],),([1,1],)],dtype=[('a','f8',(2,))])
+ assert_equal(a==b, [[True, False, False], [False, False, True]])
+ assert_equal(b==a, [[True, False, False], [False, False, True]])
+
+ # Check that broadcasting Fortran-style arrays with a subarray work
+ a = np.array([[([0,0],)],[([1,1],)]],dtype=[('a','f8',(2,))], order='F')
+ b = np.array([([0,0],),([0,1],),([1,1],)],dtype=[('a','f8',(2,))])
+ assert_equal(a==b, [[True, False, False], [False, False, True]])
+ assert_equal(b==a, [[True, False, False], [False, False, True]])
+
+ # Check that incompatible sub-array shapes don't result to broadcasting
+ x = np.zeros((1,), dtype=[('a', ('f4', (1,2))), ('b', 'i1')])
+ y = np.zeros((1,), dtype=[('a', ('f4', (2,))), ('b', 'i1')])
+ assert_equal(x == y, False)
+
+ x = np.zeros((1,), dtype=[('a', ('f4', (2,1))), ('b', 'i1')])
+ y = np.zeros((1,), dtype=[('a', ('f4', (2,))), ('b', 'i1')])
+ assert_equal(x == y, False)
+
+
class TestBool(TestCase):
def test_test_interning(self):
a0 = bool_(0)