summaryrefslogtreecommitdiff
path: root/numpy/core
diff options
context:
space:
mode:
authorCharles Harris <charlesr.harris@gmail.com>2013-07-14 09:57:54 -0600
committerCharles Harris <charlesr.harris@gmail.com>2013-08-12 22:33:56 -0600
commit6a02d054d2f74dc675497144a4819ce6be7835cb (patch)
treeeef0daa7f0f880a02b4778377573e68e161a2ddc /numpy/core
parent02f5258125debef5a0e5f6072805a8c72e4a1bde (diff)
downloadnumpy-6a02d054d2f74dc675497144a4819ce6be7835cb.tar.gz
TST: Add more extensive tests for the mean, var, and std methods.
Test the mean, var, and std methods more extensively. In addition, add tests that: Check that scalar return types ar what they should be. Check that ValueError is raised when ddof is too big.
Diffstat (limited to 'numpy/core')
-rw-r--r--numpy/core/tests/test_multiarray.py172
-rw-r--r--numpy/core/tests/test_numeric.py2
2 files changed, 173 insertions, 1 deletions
diff --git a/numpy/core/tests/test_multiarray.py b/numpy/core/tests/test_multiarray.py
index b7235d05f..d37bcd331 100644
--- a/numpy/core/tests/test_multiarray.py
+++ b/numpy/core/tests/test_multiarray.py
@@ -2399,7 +2399,178 @@ class TestView(TestCase):
assert_array_equal(y, z)
assert_array_equal(y, [67305985, 134678021])
+def _mean(a, **args):
+ return a.mean(**args)
+
+def _var(a, **args):
+ return a.var(**args)
+
+def _std(a, **args):
+ return a.std(**args)
+
class TestStats(TestCase):
+
+ funcs = [_mean, _var, _std]
+
+ def setUp(self):
+ np.random.seed(range(3))
+ self.rmat = np.random.random((4, 5))
+ self.cmat = self.rmat + 1j * self.rmat
+
+ def test_keepdims(self):
+ mat = np.eye(3)
+ for f in self.funcs:
+ for axis in [0, 1]:
+ res = f(mat, axis=axis, keepdims=True)
+ assert_(res.ndim == mat.ndim)
+ assert_(res.shape[axis] == 1)
+ for axis in [None]:
+ res = f(mat, axis=axis, keepdims=True)
+ assert_(res.shape == (1, 1))
+
+ def test_out(self):
+ mat = np.eye(3)
+ for f in self.funcs:
+ out = np.zeros(3)
+ tgt = f(mat, axis=1)
+ res = f(mat, axis=1, out=out)
+ assert_almost_equal(res, out)
+ assert_almost_equal(res, tgt)
+ out = np.empty(2)
+ assert_raises(ValueError, f, mat, axis=1, out=out)
+ out = np.empty((2, 2))
+ assert_raises(ValueError, f, mat, axis=1, out=out)
+
+ def test_dtype_from_input(self):
+ icodes = np.typecodes['AllInteger']
+ fcodes = np.typecodes['AllFloat']
+
+ # integer types
+ for f in self.funcs:
+ for c in icodes:
+ mat = np.eye(3, dtype=c)
+ tgt = np.float64
+ res = f(mat, axis=1).dtype.type
+ assert_(res is tgt)
+ # scalar case
+ res = f(mat, axis=None).dtype.type
+ assert_(res is tgt)
+ # mean for float types
+ for f in [_mean]:
+ for c in fcodes:
+ mat = np.eye(3, dtype=c)
+ tgt = mat.dtype.type
+ res = f(mat, axis=1).dtype.type
+ assert_(res is tgt)
+ # scalar case
+ res = f(mat, axis=None).dtype.type
+ assert_(res is tgt)
+ # var, std for float types
+ for f in [_var, _std]:
+ for c in fcodes:
+ mat = np.eye(3, dtype=c)
+ tgt = mat.real.dtype.type
+ res = f(mat, axis=1).dtype.type
+ assert_(res is tgt)
+ # scalar case
+ res = f(mat, axis=None).dtype.type
+ assert_(res is tgt)
+
+ def test_dtype_from_dtype(self):
+ icodes = np.typecodes['AllInteger']
+ fcodes = np.typecodes['AllFloat']
+ mat = np.eye(3)
+
+ # stats for integer types
+ # fixme:
+ # this needs definition as there are lots places along the line
+ # where type casting may take place.
+ #for f in self.funcs:
+ #for c in icodes:
+ #tgt = np.dtype(c).type
+ #res = f(mat, axis=1, dtype=c).dtype.type
+ #assert_(res is tgt)
+ ## scalar case
+ #res = f(mat, axis=None, dtype=c).dtype.type
+ #assert_(res is tgt)
+
+ # stats for float types
+ for f in self.funcs:
+ for c in fcodes:
+ tgt = np.dtype(c).type
+ res = f(mat, axis=1, dtype=c).dtype.type
+ assert_(res is tgt)
+ # scalar case
+ res = f(mat, axis=None, dtype=c).dtype.type
+ assert_(res is tgt)
+
+ def test_ddof(self):
+ for f in [_var]:
+ for ddof in range(3):
+ dim = self.rmat.shape[1]
+ tgt = f(self.rmat, axis=1) * dim
+ res = f(self.rmat, axis=1, ddof=ddof) * (dim - ddof)
+ for f in [_std]:
+ for ddof in range(3):
+ dim = self.rmat.shape[1]
+ tgt = f(self.rmat, axis=1) * np.sqrt(dim)
+ res = f(self.rmat, axis=1, ddof=ddof) * np.sqrt(dim - ddof)
+ assert_almost_equal(res, tgt)
+ assert_almost_equal(res, tgt)
+
+ def test_ddof_too_big(self):
+ dim = self.rmat.shape[1]
+ for f in [_var, _std]:
+ for ddof in range(dim, dim + 2):
+ with warnings.catch_warnings(record=True) as w:
+ warnings.simplefilter('always')
+ res = f(self.rmat, axis=1, ddof=ddof)
+ assert_(not (res < 0).any())
+ assert_(len(w) > 0)
+ assert_(issubclass(w[0].category, RuntimeWarning))
+
+ def test_empty(self):
+ A = np.zeros((0,3))
+ for f in self.funcs:
+ for axis in [0, None]:
+ with warnings.catch_warnings(record=True) as w:
+ warnings.simplefilter('always')
+ assert_(np.isnan(f(A, axis=axis)).all())
+ assert_(len(w) > 0)
+ assert_(issubclass(w[0].category, RuntimeWarning))
+ for axis in [1]:
+ with warnings.catch_warnings(record=True) as w:
+ warnings.simplefilter('always')
+ assert_equal(f(A, axis=axis), np.zeros([]))
+
+ def test_mean_values(self):
+ for mat in [self.rmat, self.cmat]:
+ for axis in [0, 1]:
+ tgt = mat.sum(axis=axis)
+ res = _mean(mat, axis=axis) * mat.shape[axis]
+ assert_almost_equal(res, tgt)
+ for axis in [None]:
+ tgt = mat.sum(axis=axis)
+ res = _mean(mat, axis=axis) * np.prod(mat.shape)
+ assert_almost_equal(res, tgt)
+
+ def test_var_values(self):
+ for mat in [self.rmat, self.cmat]:
+ for axis in [0, 1, None]:
+ msqr = _mean(mat * mat.conj(), axis=axis)
+ mean = _mean(mat, axis=axis)
+ tgt = msqr - mean * mean.conj()
+ res = _var(mat, axis=axis)
+ assert_almost_equal(res, tgt)
+
+ def test_std_values(self):
+ for mat in [self.rmat, self.cmat]:
+ for axis in [0, 1, None]:
+ tgt = np.sqrt(_var(mat, axis=axis))
+ res = _std(mat, axis=axis)
+ assert_almost_equal(res, tgt)
+
+
def test_subclass(self):
class TestArray(np.ndarray):
def __new__(cls, data, info):
@@ -2409,6 +2580,7 @@ class TestStats(TestCase):
return result
def __array_finalize__(self, obj):
self.info = getattr(obj, "info", '')
+
dat = TestArray([[1,2,3,4],[5,6,7,8]], 'jubba')
res = dat.mean(1)
assert_(res.info == dat.info)
diff --git a/numpy/core/tests/test_numeric.py b/numpy/core/tests/test_numeric.py
index 2b6b07721..341884f04 100644
--- a/numpy/core/tests/test_numeric.py
+++ b/numpy/core/tests/test_numeric.py
@@ -202,7 +202,7 @@ class TestNonarrayArgs(TestCase):
with warnings.catch_warnings(record=True) as w:
warnings.filterwarnings('always', '', RuntimeWarning)
- assert_(isnan(mean([])))
+ assert_(isnan(var([])))
assert_(w[0].category is RuntimeWarning)