summaryrefslogtreecommitdiff
path: root/numpy
diff options
context:
space:
mode:
Diffstat (limited to 'numpy')
-rw-r--r--numpy/core/src/multiarray/multiarraymodule.c37
-rw-r--r--numpy/core/tests/test_dtype.py27
2 files changed, 63 insertions, 1 deletions
diff --git a/numpy/core/src/multiarray/multiarraymodule.c b/numpy/core/src/multiarray/multiarraymodule.c
index fcd68e079..60642ccd3 100644
--- a/numpy/core/src/multiarray/multiarraymodule.c
+++ b/numpy/core/src/multiarray/multiarraymodule.c
@@ -1361,6 +1361,37 @@ _equivalent_units(PyObject *meta1, PyObject *meta2)
&& (data1->events == data2->events));
}
+/*
+ * Compare the subarray data for two types.
+ * Return 1 if they are the same, 0 if not.
+ */
+static int
+_equivalent_subarrays(PyArray_ArrayDescr *sub1, PyArray_ArrayDescr *sub2)
+{
+ int val;
+
+ if (sub1 == sub2) {
+ return 1;
+
+ }
+ if (sub1 == NULL || sub2 == NULL) {
+ return 0;
+ }
+
+#if defined(NPY_PY3K)
+ val = PyObject_RichCompareBool(sub1->shape, sub2->shape, Py_EQ);
+ if (val != 1 || PyErr_Occurred()) {
+#else
+ val = PyObject_Compare(sub1->shape, sub2->shape);
+ if (val != 0 || PyErr_Occurred()) {
+#endif
+ PyErr_Clear();
+ return 0;
+ }
+
+ return PyArray_EquivTypes(sub1->base, sub2->base);
+}
+
/*NUMPY_API
*
@@ -1381,6 +1412,10 @@ PyArray_EquivTypes(PyArray_Descr *typ1, PyArray_Descr *typ2)
if (PyArray_ISNBO(typ1->byteorder) != PyArray_ISNBO(typ2->byteorder)) {
return FALSE;
}
+ if (typ1->subarray || typ2->subarray) {
+ return ((typenum1 == typenum2)
+ && _equivalent_subarrays(typ1->subarray, typ2->subarray));
+ }
if (typenum1 == PyArray_VOID
|| typenum2 == PyArray_VOID) {
return ((typenum1 == typenum2)
@@ -1874,7 +1909,7 @@ array_arange(PyObject *NPY_UNUSED(ignored), PyObject *args, PyObject *kws) {
}
range = PyArray_ArangeObj(o_start, o_stop, o_step, typecode);
Py_XDECREF(typecode);
- return range;
+ return range;
}
/*NUMPY_API
diff --git a/numpy/core/tests/test_dtype.py b/numpy/core/tests/test_dtype.py
index 8d450b39d..e9e506204 100644
--- a/numpy/core/tests/test_dtype.py
+++ b/numpy/core/tests/test_dtype.py
@@ -55,6 +55,33 @@ class TestRecord(TestCase):
self.assertRaises(TypeError, np.dtype,
dict(names=['A', 'B'], formats=set(['f8', 'i4'])))
+class TestShape(TestCase):
+ def test_equal(self):
+ """Test some data types that are equal"""
+ self.assertEqual(np.dtype('f8'), np.dtype(('f8',tuple())))
+ self.assertEqual(np.dtype('f8'), np.dtype(('f8',1)))
+ self.assertEqual(np.dtype((np.int,2)), np.dtype((np.int,(2,))))
+ self.assertEqual(np.dtype(('<f4',(3,2))), np.dtype(('<f4',(3,2))))
+ d = ([('a','f4',(1,2)),('b','f8',(3,1))],(3,2))
+ self.assertEqual(np.dtype(d), np.dtype(d))
+
+ def test_simple(self):
+ """Test some simple cases that shouldn't be equal"""
+ self.assertNotEqual(np.dtype('f8'), np.dtype(('f8',(1,))))
+ self.assertNotEqual(np.dtype(('f8',(1,))), np.dtype(('f8',(1,1))))
+ self.assertNotEqual(np.dtype(('f4',(3,2))), np.dtype(('f4',(2,3))))
+
+ def test_monster(self):
+ """Test some more complicated cases that shouldn't be equal"""
+ self.assertNotEqual(np.dtype(([('a','f4',(2,1)), ('b','f8',(1,3))],(2,2))),
+ np.dtype(([('a','f4',(1,2)), ('b','f8',(1,3))],(2,2))))
+ self.assertNotEqual(np.dtype(([('a','f4',(2,1)), ('b','f8',(1,3))],(2,2))),
+ np.dtype(([('a','f4',(2,1)), ('b','i8',(1,3))],(2,2))))
+ self.assertNotEqual(np.dtype(([('a','f4',(2,1)), ('b','f8',(1,3))],(2,2))),
+ np.dtype(([('e','f8',(1,3)), ('d','f4',(2,1))],(2,2))))
+ self.assertNotEqual(np.dtype(([('a',[('a','i4',6)],(2,1)), ('b','f8',(1,3))],(2,2))),
+ np.dtype(([('a',[('a','u4',6)],(2,1)), ('b','f8',(1,3))],(2,2))))
+
class TestSubarray(TestCase):
def test_single_subarray(self):
a = np.dtype((np.int, (2)))