diff options
author | Travis Oliphant <oliphant@enthought.com> | 2006-06-27 18:50:37 +0000 |
---|---|---|
committer | Travis Oliphant <oliphant@enthought.com> | 2006-06-27 18:50:37 +0000 |
commit | 47c92e4224ee36090289fcafa6820f749bc3d9b1 (patch) | |
tree | 29fd98ac4b941c118dc7ce1aa5fc6cb11c8946a9 /numpy | |
parent | f3ecc6649d31ccb1e83be632c43fd7bc61701934 (diff) | |
download | numpy-47c92e4224ee36090289fcafa6820f749bc3d9b1.tar.gz |
Fix setting of sub-classes that over-ride __getitem__
Diffstat (limited to 'numpy')
-rw-r--r-- | numpy/core/src/arrayobject.c | 24 | ||||
-rw-r--r-- | numpy/core/tests/test_defmatrix.py | 21 |
2 files changed, 33 insertions, 12 deletions
diff --git a/numpy/core/src/arrayobject.c b/numpy/core/src/arrayobject.c index 59a83a04f..bd064b3f4 100644 --- a/numpy/core/src/arrayobject.c +++ b/numpy/core/src/arrayobject.c @@ -2578,11 +2578,23 @@ array_ass_sub_simple(PyArrayObject *self, PyObject *index, PyObject *op) /* Rest of standard (view-based) indexing */ - if ((tmp = (PyArrayObject*)\ - array_subscript_simple(self, index)) == NULL) { - return -1; + if (PyArray_CheckExact(self)) { + tmp = (PyArrayObject *)array_subscript_simple(self, index); + if (tmp == NULL) return -1; } - + else { + PyObject *tmp0; + tmp0 = PyObject_GetItem((PyObject *)self, index); + if (tmp0 == NULL) return -1; + if (!PyArray_Check(tmp0)) { + PyErr_SetString(PyExc_RuntimeError, + "Getitem not returning array."); + Py_DECREF(tmp0); + return -1; + } + tmp = (PyArrayObject *)tmp0; + } + if (PyArray_ISOBJECT(self) && (tmp->nd == 0)) { ret = tmp->descr->f->setitem(op, tmp->data, tmp); } @@ -8866,13 +8878,13 @@ PyArray_MapIterBind(PyArrayMapIterObject *mit, PyArrayObject *arr) /* But, be sure to do it with a true array. */ if (PyArray_CheckExact(arr)) { - sub = array_subscript(arr, mit->indexobj); + sub = array_subscript_simple(arr, mit->indexobj); } else { Py_INCREF(arr); obj = PyArray_EnsureArray((PyObject *)arr); if (obj == NULL) goto fail; - sub = array_subscript((PyArrayObject *)obj, mit->indexobj); + sub = array_subscript_simple((PyArrayObject *)obj, mit->indexobj); Py_DECREF(obj); } diff --git a/numpy/core/tests/test_defmatrix.py b/numpy/core/tests/test_defmatrix.py index b2735edca..0f5099599 100644 --- a/numpy/core/tests/test_defmatrix.py +++ b/numpy/core/tests/test_defmatrix.py @@ -4,7 +4,7 @@ import numpy.core;reload(numpy.core) from numpy.core import * restore_path() -class test_ctor(ScipyTestCase): +class test_ctor(NumpyTestCase): def check_basic(self): A = array([[1,2],[3,4]]) mA = matrix(A) @@ -23,7 +23,7 @@ class test_ctor(ScipyTestCase): mvec = matrix(vec) assert mvec.shape == (1,5) -class test_properties(ScipyTestCase): +class test_properties(NumpyTestCase): def check_sum(self): """Test whether matrix.sum(axis=1) preserves orientation. Fails in NumPy <= 0.9.6.2127. @@ -91,7 +91,7 @@ class test_properties(ScipyTestCase): assert A.sum() == matrix(2) assert A.mean() == matrix(0.5) -class test_casting(ScipyTestCase): +class test_casting(NumpyTestCase): def check_basic(self): A = arange(100).reshape(10,10) mA = matrix(A) @@ -109,7 +109,7 @@ class test_casting(ScipyTestCase): assert mC.dtype.type == complex128 assert all(mA != mB) -class test_algebra(ScipyTestCase): +class test_algebra(NumpyTestCase): def check_basic(self): import numpy.linalg as linalg @@ -132,7 +132,7 @@ class test_algebra(ScipyTestCase): assert allclose((mA + mA).A, (A + A)) assert allclose((3*mA).A, (3*A)) -class test_matrix_return(ScipyTestCase): +class test_matrix_return(NumpyTestCase): def check_instance_methods(self): a = matrix([1.0], dtype='f8') methodargs = { @@ -171,5 +171,14 @@ class test_matrix_return(ScipyTestCase): assert type(c) is matrix assert type(d) is matrix +class test_indexing(NumpyTestCase): + def check_basic(self): + x = asmatrix(zeros((3,2),float)) + y = zeros((3,1),float) + y[:,0] = [0.8,0.2,0.3] + x[:,1] = y>0.5 + assert_equal(x, [[0,1],[0,0],[0,0]]) + + if __name__ == "__main__": - ScipyTest().run() + NumpyTest().run() |