diff options
Diffstat (limited to 'numpy/core')
-rw-r--r-- | numpy/core/oldnumeric.py | 2 | ||||
-rw-r--r-- | numpy/core/tests/test_oldnumeric.py | 18 |
2 files changed, 17 insertions, 3 deletions
diff --git a/numpy/core/oldnumeric.py b/numpy/core/oldnumeric.py index 978104469..b8685d69e 100644 --- a/numpy/core/oldnumeric.py +++ b/numpy/core/oldnumeric.py @@ -166,7 +166,7 @@ def _wrapit(obj, method, *args, **kwds): except AttributeError: wrap = None result = getattr(asarray(obj),method)(*args, **kwds) - if wrap: + if wrap and isinstance(result, mu.ndarray): if not isinstance(result, mu.ndarray): result = asarray(result) result = wrap(result) diff --git a/numpy/core/tests/test_oldnumeric.py b/numpy/core/tests/test_oldnumeric.py index 2821aa899..df8d9a3db 100644 --- a/numpy/core/tests/test_oldnumeric.py +++ b/numpy/core/tests/test_oldnumeric.py @@ -1,6 +1,6 @@ from numpy.testing import * -from numpy import array +from numpy import array, ndarray, arange, argmax from numpy.core.oldnumeric import put class test_put(ScipyTestCase): @@ -9,7 +9,21 @@ class test_put(ScipyTestCase): put(a,[1],[1.2]) assert_array_equal(a,[0,1,0]) put(a,[1],array([2.2])) - assert_array_equal(a,[0,2,0]) + assert_array_equal(a,[0,2,0])
+
+class test_wrapit(ScipyTestCase):
+ def check_array_subclass(self, level=1):
+ class subarray(ndarray): + def get_argmax(self):
+ raise AttributeError
+ argmax = property(get_argmax)
+ a = subarray([3], int, arange(3))
+ assert_equal(argmax(a), 2)
+ b = subarray([3, 3], int, arange(9))
+ bmax = argmax(b)
+ assert_array_equal(bmax, [2,2,2])
+ assert_equal(type(bmax), subarray)
+ if __name__ == "__main__": ScipyTest().run() |