summaryrefslogtreecommitdiff
path: root/numpy/core
diff options
context:
space:
mode:
authorCharles Harris <charlesr.harris@gmail.com>2013-10-05 09:03:37 -0700
committerCharles Harris <charlesr.harris@gmail.com>2013-10-05 09:03:37 -0700
commit0cfa4ed4ee39aaa94e4059c6394a4ed75a8e3d6c (patch)
treeec3cf1089baae1b9b0838957d4e44769b3583109 /numpy/core
parentc2dc2cdb73530805b77a75efdd106d7633f2fff3 (diff)
parent2f77e1e6e6b91a9cd11c422342c69e8fd68ee803 (diff)
downloadnumpy-0cfa4ed4ee39aaa94e4059c6394a4ed75a8e3d6c.tar.gz
Merge pull request #3866 from charris/refactor-1.9-nanfunctions
Refactor 1.9 nanfunctions
Diffstat (limited to 'numpy/core')
-rw-r--r--numpy/core/_methods.py5
-rw-r--r--numpy/core/tests/test_numeric.py8
2 files changed, 11 insertions, 2 deletions
diff --git a/numpy/core/_methods.py b/numpy/core/_methods.py
index 8f0027616..c8a968c97 100644
--- a/numpy/core/_methods.py
+++ b/numpy/core/_methods.py
@@ -91,8 +91,9 @@ def _var(a, axis=None, dtype=None, out=None, ddof=0, keepdims=False):
arrmean = arrmean.dtype.type(arrmean / rcount)
# Compute sum of squared deviations from mean
- # Note that x may not be inexact
- x = arr - arrmean
+ # Note that x may not be inexact and that we need it to be an array,
+ # not a scalar.
+ x = asanyarray(arr - arrmean)
if issubclass(arr.dtype.type, nt.complexfloating):
x = um.multiply(x, um.conjugate(x), out=x).real
else:
diff --git a/numpy/core/tests/test_numeric.py b/numpy/core/tests/test_numeric.py
index 629cb6090..8dc2ebd71 100644
--- a/numpy/core/tests/test_numeric.py
+++ b/numpy/core/tests/test_numeric.py
@@ -1473,6 +1473,10 @@ class TestStdVar(TestCase):
assert_almost_equal(var(self.A), self.real_var)
assert_almost_equal(std(self.A)**2, self.real_var)
+ def test_scalars(self):
+ assert_equal(var(1), 0)
+ assert_equal(std(1), 0)
+
def test_ddof1(self):
assert_almost_equal(var(self.A, ddof=1),
self.real_var*len(self.A)/float(len(self.A)-1))
@@ -1492,6 +1496,10 @@ class TestStdVarComplex(TestCase):
assert_almost_equal(var(A), real_var)
assert_almost_equal(std(A)**2, real_var)
+ def test_scalars(self):
+ assert_equal(var(1j), 0)
+ assert_equal(std(1j), 0)
+
class TestCreationFuncs(TestCase):
#Test ones, zeros, empty and filled