summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJulian Taylor <juliantaylor108@gmail.com>2014-02-11 19:38:42 +0100
committerJulian Taylor <juliantaylor108@gmail.com>2014-02-11 19:38:42 +0100
commite2addbd77fbed5aa64a07f9b08e217c63e15467f (patch)
treed9dc3bdf81d77e822f8e660b2f2d277e03d63b17
parentbd5894b29b897f16da8a3d64e0df94e93d6b2d4a (diff)
parent4d686a1e44462c34b60de9049dd499aed60de643 (diff)
downloadnumpy-e2addbd77fbed5aa64a07f9b08e217c63e15467f.tar.gz
Merge pull request #4276 from charris/fix-stat-methods
BUG: Fix mean, var, std methods for object arrays.
-rw-r--r--numpy/core/_methods.py12
-rw-r--r--numpy/core/tests/test_multiarray.py52
2 files changed, 31 insertions, 33 deletions
diff --git a/numpy/core/_methods.py b/numpy/core/_methods.py
index a064f70c7..5f58940fd 100644
--- a/numpy/core/_methods.py
+++ b/numpy/core/_methods.py
@@ -63,8 +63,10 @@ def _mean(a, axis=None, dtype=None, out=None, keepdims=False):
if isinstance(ret, mu.ndarray):
ret = um.true_divide(
ret, rcount, out=ret, casting='unsafe', subok=False)
- else:
+ elif hasattr(ret, 'dtype'):
ret = ret.dtype.type(ret / rcount)
+ else:
+ ret = ret / rcount
return ret
@@ -107,8 +109,10 @@ def _var(a, axis=None, dtype=None, out=None, ddof=0, keepdims=False):
if isinstance(ret, mu.ndarray):
ret = um.true_divide(
ret, rcount, out=ret, casting='unsafe', subok=False)
- else:
+ elif hasattr(ret, 'dtype'):
ret = ret.dtype.type(ret / rcount)
+ else:
+ ret = ret / rcount
return ret
@@ -118,7 +122,9 @@ def _std(a, axis=None, dtype=None, out=None, ddof=0, keepdims=False):
if isinstance(ret, mu.ndarray):
ret = um.sqrt(ret, out=ret)
- else:
+ elif hasattr(ret, 'dtype'):
ret = ret.dtype.type(um.sqrt(ret))
+ else:
+ ret = um.sqrt(ret)
return ret
diff --git a/numpy/core/tests/test_multiarray.py b/numpy/core/tests/test_multiarray.py
index d888f721a..e6051bd99 100644
--- a/numpy/core/tests/test_multiarray.py
+++ b/numpy/core/tests/test_multiarray.py
@@ -10,6 +10,8 @@ if sys.version_info[0] >= 3:
import builtins
else:
import __builtin__ as builtins
+from decimal import Decimal
+
import numpy as np
from nose import SkipTest
@@ -2833,6 +2835,8 @@ class TestStats(TestCase):
np.random.seed(range(3))
self.rmat = np.random.random((4, 5))
self.cmat = self.rmat + 1j * self.rmat
+ self.omat = np.array([Decimal(repr(r)) for r in self.rmat.flat])
+ self.omat = self.omat.reshape(4, 5)
def test_keepdims(self):
mat = np.eye(3)
@@ -2859,9 +2863,20 @@ class TestStats(TestCase):
assert_raises(ValueError, f, mat, axis=1, out=out)
def test_dtype_from_input(self):
+
icodes = np.typecodes['AllInteger']
fcodes = np.typecodes['AllFloat']
+ # object type
+ for f in self.funcs:
+ mat = np.array([[Decimal(1)]*3]*3)
+ tgt = mat.dtype.type
+ res = f(mat, axis=1).dtype.type
+ assert_(res is tgt)
+ # scalar case
+ res = type(f(mat, axis=None))
+ assert_(res is Decimal)
+
# integer types
for f in self.funcs:
for c in icodes:
@@ -2872,6 +2887,7 @@ class TestStats(TestCase):
# scalar case
res = f(mat, axis=None).dtype.type
assert_(res is tgt)
+
# mean for float types
for f in [_mean]:
for c in fcodes:
@@ -2882,10 +2898,12 @@ class TestStats(TestCase):
# scalar case
res = f(mat, axis=None).dtype.type
assert_(res is tgt)
+
# var, std for float types
for f in [_var, _std]:
for c in fcodes:
mat = np.eye(3, dtype=c)
+ # deal with complex types
tgt = mat.real.dtype.type
res = f(mat, axis=1).dtype.type
assert_(res is tgt)
@@ -2961,7 +2979,7 @@ class TestStats(TestCase):
assert_equal(f(A, axis=axis), np.zeros([]))
def test_mean_values(self):
- for mat in [self.rmat, self.cmat]:
+ for mat in [self.rmat, self.cmat, self.omat]:
for axis in [0, 1]:
tgt = mat.sum(axis=axis)
res = _mean(mat, axis=axis) * mat.shape[axis]
@@ -2972,16 +2990,16 @@ class TestStats(TestCase):
assert_almost_equal(res, tgt)
def test_var_values(self):
- for mat in [self.rmat, self.cmat]:
+ for mat in [self.rmat, self.cmat, self.omat]:
for axis in [0, 1, None]:
msqr = _mean(mat * mat.conj(), axis=axis)
mean = _mean(mat, axis=axis)
- tgt = msqr - mean * mean.conj()
+ tgt = msqr - mean * mean.conjugate()
res = _var(mat, axis=axis)
assert_almost_equal(res, tgt)
def test_std_values(self):
- for mat in [self.rmat, self.cmat]:
+ for mat in [self.rmat, self.cmat, self.omat]:
for axis in [0, 1, None]:
tgt = np.sqrt(_var(mat, axis=axis))
res = _std(mat, axis=axis)
@@ -3109,12 +3127,6 @@ class TestChoose(TestCase):
A = np.choose(self.ind, (self.x, self.y2))
assert_equal(A, [[2, 2, 3], [2, 2, 3]])
-def can_use_decimal():
- try:
- from decimal import Decimal
- return True
- except ImportError:
- return False
# TODO: test for multidimensional
NEIGH_MODE = {'zero': 0, 'one': 1, 'constant': 2, 'circular': 3, 'mirror': 4}
@@ -3150,11 +3162,7 @@ class TestNeighborhoodIter(TestCase):
def test_simple2d(self):
self._test_simple2d(np.float)
- @dec.skipif(not can_use_decimal(),
- "Skip neighborhood iterator tests for decimal objects " \
- "(decimal module not available")
def test_simple2d_object(self):
- from decimal import Decimal
self._test_simple2d(Decimal)
def _test_mirror2d(self, dt):
@@ -3170,11 +3178,7 @@ class TestNeighborhoodIter(TestCase):
def test_mirror2d(self):
self._test_mirror2d(np.float)
- @dec.skipif(not can_use_decimal(),
- "Skip neighborhood iterator tests for decimal objects " \
- "(decimal module not available")
def test_mirror2d_object(self):
- from decimal import Decimal
self._test_mirror2d(Decimal)
# Simple, 1d tests
@@ -3196,11 +3200,7 @@ class TestNeighborhoodIter(TestCase):
def test_simple_float(self):
self._test_simple(np.float)
- @dec.skipif(not can_use_decimal(),
- "Skip neighborhood iterator tests for decimal objects " \
- "(decimal module not available")
def test_simple_object(self):
- from decimal import Decimal
self._test_simple(Decimal)
# Test mirror modes
@@ -3215,11 +3215,7 @@ class TestNeighborhoodIter(TestCase):
def test_mirror(self):
self._test_mirror(np.float)
- @dec.skipif(not can_use_decimal(),
- "Skip neighborhood iterator tests for decimal objects " \
- "(decimal module not available")
def test_mirror_object(self):
- from decimal import Decimal
self._test_mirror(Decimal)
# Circular mode
@@ -3233,11 +3229,7 @@ class TestNeighborhoodIter(TestCase):
def test_circular(self):
self._test_circular(np.float)
- @dec.skipif(not can_use_decimal(),
- "Skip neighborhood iterator tests for decimal objects " \
- "(decimal module not available")
def test_circular_object(self):
- from decimal import Decimal
self._test_circular(Decimal)
# Test stacking neighborhood iterators