diff options
author | Julian Taylor <juliantaylor108@gmail.com> | 2014-02-11 19:38:42 +0100 |
---|---|---|
committer | Julian Taylor <juliantaylor108@gmail.com> | 2014-02-11 19:38:42 +0100 |
commit | e2addbd77fbed5aa64a07f9b08e217c63e15467f (patch) | |
tree | d9dc3bdf81d77e822f8e660b2f2d277e03d63b17 | |
parent | bd5894b29b897f16da8a3d64e0df94e93d6b2d4a (diff) | |
parent | 4d686a1e44462c34b60de9049dd499aed60de643 (diff) | |
download | numpy-e2addbd77fbed5aa64a07f9b08e217c63e15467f.tar.gz |
Merge pull request #4276 from charris/fix-stat-methods
BUG: Fix mean, var, std methods for object arrays.
-rw-r--r-- | numpy/core/_methods.py | 12 | ||||
-rw-r--r-- | numpy/core/tests/test_multiarray.py | 52 |
2 files changed, 31 insertions, 33 deletions
diff --git a/numpy/core/_methods.py b/numpy/core/_methods.py index a064f70c7..5f58940fd 100644 --- a/numpy/core/_methods.py +++ b/numpy/core/_methods.py @@ -63,8 +63,10 @@ def _mean(a, axis=None, dtype=None, out=None, keepdims=False): if isinstance(ret, mu.ndarray): ret = um.true_divide( ret, rcount, out=ret, casting='unsafe', subok=False) - else: + elif hasattr(ret, 'dtype'): ret = ret.dtype.type(ret / rcount) + else: + ret = ret / rcount return ret @@ -107,8 +109,10 @@ def _var(a, axis=None, dtype=None, out=None, ddof=0, keepdims=False): if isinstance(ret, mu.ndarray): ret = um.true_divide( ret, rcount, out=ret, casting='unsafe', subok=False) - else: + elif hasattr(ret, 'dtype'): ret = ret.dtype.type(ret / rcount) + else: + ret = ret / rcount return ret @@ -118,7 +122,9 @@ def _std(a, axis=None, dtype=None, out=None, ddof=0, keepdims=False): if isinstance(ret, mu.ndarray): ret = um.sqrt(ret, out=ret) - else: + elif hasattr(ret, 'dtype'): ret = ret.dtype.type(um.sqrt(ret)) + else: + ret = um.sqrt(ret) return ret diff --git a/numpy/core/tests/test_multiarray.py b/numpy/core/tests/test_multiarray.py index d888f721a..e6051bd99 100644 --- a/numpy/core/tests/test_multiarray.py +++ b/numpy/core/tests/test_multiarray.py @@ -10,6 +10,8 @@ if sys.version_info[0] >= 3: import builtins else: import __builtin__ as builtins +from decimal import Decimal + import numpy as np from nose import SkipTest @@ -2833,6 +2835,8 @@ class TestStats(TestCase): np.random.seed(range(3)) self.rmat = np.random.random((4, 5)) self.cmat = self.rmat + 1j * self.rmat + self.omat = np.array([Decimal(repr(r)) for r in self.rmat.flat]) + self.omat = self.omat.reshape(4, 5) def test_keepdims(self): mat = np.eye(3) @@ -2859,9 +2863,20 @@ class TestStats(TestCase): assert_raises(ValueError, f, mat, axis=1, out=out) def test_dtype_from_input(self): + icodes = np.typecodes['AllInteger'] fcodes = np.typecodes['AllFloat'] + # object type + for f in self.funcs: + mat = np.array([[Decimal(1)]*3]*3) + tgt = mat.dtype.type + res = f(mat, axis=1).dtype.type + assert_(res is tgt) + # scalar case + res = type(f(mat, axis=None)) + assert_(res is Decimal) + # integer types for f in self.funcs: for c in icodes: @@ -2872,6 +2887,7 @@ class TestStats(TestCase): # scalar case res = f(mat, axis=None).dtype.type assert_(res is tgt) + # mean for float types for f in [_mean]: for c in fcodes: @@ -2882,10 +2898,12 @@ class TestStats(TestCase): # scalar case res = f(mat, axis=None).dtype.type assert_(res is tgt) + # var, std for float types for f in [_var, _std]: for c in fcodes: mat = np.eye(3, dtype=c) + # deal with complex types tgt = mat.real.dtype.type res = f(mat, axis=1).dtype.type assert_(res is tgt) @@ -2961,7 +2979,7 @@ class TestStats(TestCase): assert_equal(f(A, axis=axis), np.zeros([])) def test_mean_values(self): - for mat in [self.rmat, self.cmat]: + for mat in [self.rmat, self.cmat, self.omat]: for axis in [0, 1]: tgt = mat.sum(axis=axis) res = _mean(mat, axis=axis) * mat.shape[axis] @@ -2972,16 +2990,16 @@ class TestStats(TestCase): assert_almost_equal(res, tgt) def test_var_values(self): - for mat in [self.rmat, self.cmat]: + for mat in [self.rmat, self.cmat, self.omat]: for axis in [0, 1, None]: msqr = _mean(mat * mat.conj(), axis=axis) mean = _mean(mat, axis=axis) - tgt = msqr - mean * mean.conj() + tgt = msqr - mean * mean.conjugate() res = _var(mat, axis=axis) assert_almost_equal(res, tgt) def test_std_values(self): - for mat in [self.rmat, self.cmat]: + for mat in [self.rmat, self.cmat, self.omat]: for axis in [0, 1, None]: tgt = np.sqrt(_var(mat, axis=axis)) res = _std(mat, axis=axis) @@ -3109,12 +3127,6 @@ class TestChoose(TestCase): A = np.choose(self.ind, (self.x, self.y2)) assert_equal(A, [[2, 2, 3], [2, 2, 3]]) -def can_use_decimal(): - try: - from decimal import Decimal - return True - except ImportError: - return False # TODO: test for multidimensional NEIGH_MODE = {'zero': 0, 'one': 1, 'constant': 2, 'circular': 3, 'mirror': 4} @@ -3150,11 +3162,7 @@ class TestNeighborhoodIter(TestCase): def test_simple2d(self): self._test_simple2d(np.float) - @dec.skipif(not can_use_decimal(), - "Skip neighborhood iterator tests for decimal objects " \ - "(decimal module not available") def test_simple2d_object(self): - from decimal import Decimal self._test_simple2d(Decimal) def _test_mirror2d(self, dt): @@ -3170,11 +3178,7 @@ class TestNeighborhoodIter(TestCase): def test_mirror2d(self): self._test_mirror2d(np.float) - @dec.skipif(not can_use_decimal(), - "Skip neighborhood iterator tests for decimal objects " \ - "(decimal module not available") def test_mirror2d_object(self): - from decimal import Decimal self._test_mirror2d(Decimal) # Simple, 1d tests @@ -3196,11 +3200,7 @@ class TestNeighborhoodIter(TestCase): def test_simple_float(self): self._test_simple(np.float) - @dec.skipif(not can_use_decimal(), - "Skip neighborhood iterator tests for decimal objects " \ - "(decimal module not available") def test_simple_object(self): - from decimal import Decimal self._test_simple(Decimal) # Test mirror modes @@ -3215,11 +3215,7 @@ class TestNeighborhoodIter(TestCase): def test_mirror(self): self._test_mirror(np.float) - @dec.skipif(not can_use_decimal(), - "Skip neighborhood iterator tests for decimal objects " \ - "(decimal module not available") def test_mirror_object(self): - from decimal import Decimal self._test_mirror(Decimal) # Circular mode @@ -3233,11 +3229,7 @@ class TestNeighborhoodIter(TestCase): def test_circular(self): self._test_circular(np.float) - @dec.skipif(not can_use_decimal(), - "Skip neighborhood iterator tests for decimal objects " \ - "(decimal module not available") def test_circular_object(self): - from decimal import Decimal self._test_circular(Decimal) # Test stacking neighborhood iterators |