diff options
author | Mark Wiebe <mwwiebe@gmail.com> | 2010-10-20 16:23:36 -0700 |
---|---|---|
committer | Pauli Virtanen <pav@iki.fi> | 2010-10-31 18:29:21 +0100 |
commit | 50125047593acd877f72da200a3e49343ca51cdd (patch) | |
tree | 1cebed6d39f2cdb2dca4942ed1cd90c7da73b5f6 /numpy | |
parent | 33b3e601043a66491d468a5ba7485181ca7506cc (diff) | |
download | numpy-50125047593acd877f72da200a3e49343ca51cdd.tar.gz |
BUG: core: handle sub-arrays in dtype comparisons
Diffstat (limited to 'numpy')
-rw-r--r-- | numpy/core/src/multiarray/multiarraymodule.c | 37 | ||||
-rw-r--r-- | numpy/core/tests/test_dtype.py | 27 |
2 files changed, 63 insertions, 1 deletions
diff --git a/numpy/core/src/multiarray/multiarraymodule.c b/numpy/core/src/multiarray/multiarraymodule.c index fcd68e079..60642ccd3 100644 --- a/numpy/core/src/multiarray/multiarraymodule.c +++ b/numpy/core/src/multiarray/multiarraymodule.c @@ -1361,6 +1361,37 @@ _equivalent_units(PyObject *meta1, PyObject *meta2) && (data1->events == data2->events)); } +/* + * Compare the subarray data for two types. + * Return 1 if they are the same, 0 if not. + */ +static int +_equivalent_subarrays(PyArray_ArrayDescr *sub1, PyArray_ArrayDescr *sub2) +{ + int val; + + if (sub1 == sub2) { + return 1; + + } + if (sub1 == NULL || sub2 == NULL) { + return 0; + } + +#if defined(NPY_PY3K) + val = PyObject_RichCompareBool(sub1->shape, sub2->shape, Py_EQ); + if (val != 1 || PyErr_Occurred()) { +#else + val = PyObject_Compare(sub1->shape, sub2->shape); + if (val != 0 || PyErr_Occurred()) { +#endif + PyErr_Clear(); + return 0; + } + + return PyArray_EquivTypes(sub1->base, sub2->base); +} + /*NUMPY_API * @@ -1381,6 +1412,10 @@ PyArray_EquivTypes(PyArray_Descr *typ1, PyArray_Descr *typ2) if (PyArray_ISNBO(typ1->byteorder) != PyArray_ISNBO(typ2->byteorder)) { return FALSE; } + if (typ1->subarray || typ2->subarray) { + return ((typenum1 == typenum2) + && _equivalent_subarrays(typ1->subarray, typ2->subarray)); + } if (typenum1 == PyArray_VOID || typenum2 == PyArray_VOID) { return ((typenum1 == typenum2) @@ -1874,7 +1909,7 @@ array_arange(PyObject *NPY_UNUSED(ignored), PyObject *args, PyObject *kws) { } range = PyArray_ArangeObj(o_start, o_stop, o_step, typecode); Py_XDECREF(typecode); - return range; + return range; } /*NUMPY_API diff --git a/numpy/core/tests/test_dtype.py b/numpy/core/tests/test_dtype.py index 8d450b39d..e9e506204 100644 --- a/numpy/core/tests/test_dtype.py +++ b/numpy/core/tests/test_dtype.py @@ -55,6 +55,33 @@ class TestRecord(TestCase): self.assertRaises(TypeError, np.dtype, dict(names=['A', 'B'], formats=set(['f8', 'i4']))) +class TestShape(TestCase): + def test_equal(self): + """Test some data types that are equal""" + self.assertEqual(np.dtype('f8'), np.dtype(('f8',tuple()))) + self.assertEqual(np.dtype('f8'), np.dtype(('f8',1))) + self.assertEqual(np.dtype((np.int,2)), np.dtype((np.int,(2,)))) + self.assertEqual(np.dtype(('<f4',(3,2))), np.dtype(('<f4',(3,2)))) + d = ([('a','f4',(1,2)),('b','f8',(3,1))],(3,2)) + self.assertEqual(np.dtype(d), np.dtype(d)) + + def test_simple(self): + """Test some simple cases that shouldn't be equal""" + self.assertNotEqual(np.dtype('f8'), np.dtype(('f8',(1,)))) + self.assertNotEqual(np.dtype(('f8',(1,))), np.dtype(('f8',(1,1)))) + self.assertNotEqual(np.dtype(('f4',(3,2))), np.dtype(('f4',(2,3)))) + + def test_monster(self): + """Test some more complicated cases that shouldn't be equal""" + self.assertNotEqual(np.dtype(([('a','f4',(2,1)), ('b','f8',(1,3))],(2,2))), + np.dtype(([('a','f4',(1,2)), ('b','f8',(1,3))],(2,2)))) + self.assertNotEqual(np.dtype(([('a','f4',(2,1)), ('b','f8',(1,3))],(2,2))), + np.dtype(([('a','f4',(2,1)), ('b','i8',(1,3))],(2,2)))) + self.assertNotEqual(np.dtype(([('a','f4',(2,1)), ('b','f8',(1,3))],(2,2))), + np.dtype(([('e','f8',(1,3)), ('d','f4',(2,1))],(2,2)))) + self.assertNotEqual(np.dtype(([('a',[('a','i4',6)],(2,1)), ('b','f8',(1,3))],(2,2))), + np.dtype(([('a',[('a','u4',6)],(2,1)), ('b','f8',(1,3))],(2,2)))) + class TestSubarray(TestCase): def test_single_subarray(self): a = np.dtype((np.int, (2))) |