summaryrefslogtreecommitdiff
path: root/numpy
diff options
context:
space:
mode:
authorTravis Oliphant <oliphant@enthought.com>2006-06-27 18:50:37 +0000
committerTravis Oliphant <oliphant@enthought.com>2006-06-27 18:50:37 +0000
commit47c92e4224ee36090289fcafa6820f749bc3d9b1 (patch)
tree29fd98ac4b941c118dc7ce1aa5fc6cb11c8946a9 /numpy
parentf3ecc6649d31ccb1e83be632c43fd7bc61701934 (diff)
downloadnumpy-47c92e4224ee36090289fcafa6820f749bc3d9b1.tar.gz
Fix setting of sub-classes that over-ride __getitem__
Diffstat (limited to 'numpy')
-rw-r--r--numpy/core/src/arrayobject.c24
-rw-r--r--numpy/core/tests/test_defmatrix.py21
2 files changed, 33 insertions, 12 deletions
diff --git a/numpy/core/src/arrayobject.c b/numpy/core/src/arrayobject.c
index 59a83a04f..bd064b3f4 100644
--- a/numpy/core/src/arrayobject.c
+++ b/numpy/core/src/arrayobject.c
@@ -2578,11 +2578,23 @@ array_ass_sub_simple(PyArrayObject *self, PyObject *index, PyObject *op)
/* Rest of standard (view-based) indexing */
- if ((tmp = (PyArrayObject*)\
- array_subscript_simple(self, index)) == NULL) {
- return -1;
+ if (PyArray_CheckExact(self)) {
+ tmp = (PyArrayObject *)array_subscript_simple(self, index);
+ if (tmp == NULL) return -1;
}
-
+ else {
+ PyObject *tmp0;
+ tmp0 = PyObject_GetItem((PyObject *)self, index);
+ if (tmp0 == NULL) return -1;
+ if (!PyArray_Check(tmp0)) {
+ PyErr_SetString(PyExc_RuntimeError,
+ "Getitem not returning array.");
+ Py_DECREF(tmp0);
+ return -1;
+ }
+ tmp = (PyArrayObject *)tmp0;
+ }
+
if (PyArray_ISOBJECT(self) && (tmp->nd == 0)) {
ret = tmp->descr->f->setitem(op, tmp->data, tmp);
}
@@ -8866,13 +8878,13 @@ PyArray_MapIterBind(PyArrayMapIterObject *mit, PyArrayObject *arr)
/* But, be sure to do it with a true array.
*/
if (PyArray_CheckExact(arr)) {
- sub = array_subscript(arr, mit->indexobj);
+ sub = array_subscript_simple(arr, mit->indexobj);
}
else {
Py_INCREF(arr);
obj = PyArray_EnsureArray((PyObject *)arr);
if (obj == NULL) goto fail;
- sub = array_subscript((PyArrayObject *)obj, mit->indexobj);
+ sub = array_subscript_simple((PyArrayObject *)obj, mit->indexobj);
Py_DECREF(obj);
}
diff --git a/numpy/core/tests/test_defmatrix.py b/numpy/core/tests/test_defmatrix.py
index b2735edca..0f5099599 100644
--- a/numpy/core/tests/test_defmatrix.py
+++ b/numpy/core/tests/test_defmatrix.py
@@ -4,7 +4,7 @@ import numpy.core;reload(numpy.core)
from numpy.core import *
restore_path()
-class test_ctor(ScipyTestCase):
+class test_ctor(NumpyTestCase):
def check_basic(self):
A = array([[1,2],[3,4]])
mA = matrix(A)
@@ -23,7 +23,7 @@ class test_ctor(ScipyTestCase):
mvec = matrix(vec)
assert mvec.shape == (1,5)
-class test_properties(ScipyTestCase):
+class test_properties(NumpyTestCase):
def check_sum(self):
"""Test whether matrix.sum(axis=1) preserves orientation.
Fails in NumPy <= 0.9.6.2127.
@@ -91,7 +91,7 @@ class test_properties(ScipyTestCase):
assert A.sum() == matrix(2)
assert A.mean() == matrix(0.5)
-class test_casting(ScipyTestCase):
+class test_casting(NumpyTestCase):
def check_basic(self):
A = arange(100).reshape(10,10)
mA = matrix(A)
@@ -109,7 +109,7 @@ class test_casting(ScipyTestCase):
assert mC.dtype.type == complex128
assert all(mA != mB)
-class test_algebra(ScipyTestCase):
+class test_algebra(NumpyTestCase):
def check_basic(self):
import numpy.linalg as linalg
@@ -132,7 +132,7 @@ class test_algebra(ScipyTestCase):
assert allclose((mA + mA).A, (A + A))
assert allclose((3*mA).A, (3*A))
-class test_matrix_return(ScipyTestCase):
+class test_matrix_return(NumpyTestCase):
def check_instance_methods(self):
a = matrix([1.0], dtype='f8')
methodargs = {
@@ -171,5 +171,14 @@ class test_matrix_return(ScipyTestCase):
assert type(c) is matrix
assert type(d) is matrix
+class test_indexing(NumpyTestCase):
+ def check_basic(self):
+ x = asmatrix(zeros((3,2),float))
+ y = zeros((3,1),float)
+ y[:,0] = [0.8,0.2,0.3]
+ x[:,1] = y>0.5
+ assert_equal(x, [[0,1],[0,0],[0,0]])
+
+
if __name__ == "__main__":
- ScipyTest().run()
+ NumpyTest().run()