summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorCharles Harris <charlesr.harris@gmail.com>2016-01-22 18:28:35 -0700
committerCharles Harris <charlesr.harris@gmail.com>2016-01-22 18:28:35 -0700
commit66156a049aa833a06e6747dbe88676a61e034e8d (patch)
treea63f92ebed409e2ad084642d1ad1f1fba9441a79
parentd5cef016b336659a2288cefe8aa6f60cf340d35c (diff)
parent75d5b59bca181ee7e5ba872999014006c4b6c3f3 (diff)
downloadnumpy-66156a049aa833a06e6747dbe88676a61e034e8d.tar.gz
Merge pull request #7088 from jakirkham/cast_float_linalg
BUG: Have `norm` cast non-floating point arrays to 64-bit float arrays
-rw-r--r--numpy/linalg/linalg.py21
-rw-r--r--numpy/linalg/tests/test_linalg.py93
2 files changed, 105 insertions, 9 deletions
diff --git a/numpy/linalg/linalg.py b/numpy/linalg/linalg.py
index 9dc879d31..9d486d2a5 100644
--- a/numpy/linalg/linalg.py
+++ b/numpy/linalg/linalg.py
@@ -2060,22 +2060,22 @@ def norm(x, ord=None, axis=None, keepdims=False):
>>> LA.norm(b, 'fro')
7.745966692414834
>>> LA.norm(a, np.inf)
- 4
+ 4.0
>>> LA.norm(b, np.inf)
- 9
+ 9.0
>>> LA.norm(a, -np.inf)
- 0
+ 0.0
>>> LA.norm(b, -np.inf)
- 2
+ 2.0
>>> LA.norm(a, 1)
- 20
+ 20.0
>>> LA.norm(b, 1)
- 7
+ 7.0
>>> LA.norm(a, -1)
-4.6566128774142013e-010
>>> LA.norm(b, -1)
- 6
+ 6.0
>>> LA.norm(a, 2)
7.745966692414834
>>> LA.norm(b, 2)
@@ -2099,7 +2099,7 @@ def norm(x, ord=None, axis=None, keepdims=False):
>>> LA.norm(c, axis=1)
array([ 3.74165739, 4.24264069])
>>> LA.norm(c, ord=1, axis=1)
- array([6, 6])
+ array([ 6., 6.])
Using the `axis` argument to compute matrix norms:
@@ -2112,6 +2112,9 @@ def norm(x, ord=None, axis=None, keepdims=False):
"""
x = asarray(x)
+ if not issubclass(x.dtype.type, inexact):
+ x = x.astype(float)
+
# Immediately handle some default, simple, fast, and common cases.
if axis is None:
ndim = x.ndim
@@ -2147,7 +2150,7 @@ def norm(x, ord=None, axis=None, keepdims=False):
return abs(x).min(axis=axis, keepdims=keepdims)
elif ord == 0:
# Zero norm
- return (x != 0).sum(axis=axis, keepdims=keepdims)
+ return (x != 0).astype(float).sum(axis=axis, keepdims=keepdims)
elif ord == 1:
# special case for speedup
return add.reduce(abs(x), axis=axis, keepdims=keepdims)
diff --git a/numpy/linalg/tests/test_linalg.py b/numpy/linalg/tests/test_linalg.py
index fc139be19..60486d4ce 100644
--- a/numpy/linalg/tests/test_linalg.py
+++ b/numpy/linalg/tests/test_linalg.py
@@ -7,6 +7,7 @@ import os
import sys
import itertools
import traceback
+import warnings
import numpy as np
from numpy import array, single, double, csingle, cdouble, dot, identity
@@ -845,6 +846,98 @@ class _TestNorm(object):
assert_equal(norm(array([], dtype=self.dt)), 0.0)
assert_equal(norm(atleast_2d(array([], dtype=self.dt))), 0.0)
+ def test_vector_return_type(self):
+ a = np.array([1, 0, 1])
+
+ exact_types = np.typecodes['AllInteger']
+ inexact_types = np.typecodes['AllFloat']
+
+ all_types = exact_types + inexact_types
+
+ for each_inexact_types in all_types:
+ at = a.astype(each_inexact_types)
+
+ an = norm(at, -np.inf)
+ assert_(issubclass(an.dtype.type, np.floating))
+ assert_almost_equal(an, 0.0)
+
+ with warnings.catch_warnings():
+ warnings.simplefilter("ignore", RuntimeWarning)
+ an = norm(at, -1)
+ assert_(issubclass(an.dtype.type, np.floating))
+ assert_almost_equal(an, 0.0)
+
+ an = norm(at, 0)
+ assert_(issubclass(an.dtype.type, np.floating))
+ assert_almost_equal(an, 2)
+
+ an = norm(at, 1)
+ assert_(issubclass(an.dtype.type, np.floating))
+ assert_almost_equal(an, 2.0)
+
+ an = norm(at, 2)
+ assert_(issubclass(an.dtype.type, np.floating))
+ assert_almost_equal(an, 2.0**(1.0/2.0))
+
+ an = norm(at, 4)
+ assert_(issubclass(an.dtype.type, np.floating))
+ assert_almost_equal(an, 2.0**(1.0/4.0))
+
+ an = norm(at, np.inf)
+ assert_(issubclass(an.dtype.type, np.floating))
+ assert_almost_equal(an, 1.0)
+
+ def test_matrix_return_type(self):
+ a = np.array([[1, 0, 1], [0, 1, 1]])
+
+ exact_types = np.typecodes['AllInteger']
+
+ # float32, complex64, float64, complex128 types are the only types
+ # allowed by `linalg`, which performs the matrix operations used
+ # within `norm`.
+ inexact_types = 'fdFD'
+
+ all_types = exact_types + inexact_types
+
+ for each_inexact_types in all_types:
+ at = a.astype(each_inexact_types)
+
+ an = norm(at, -np.inf)
+ assert_(issubclass(an.dtype.type, np.floating))
+ assert_almost_equal(an, 2.0)
+
+ with warnings.catch_warnings():
+ warnings.simplefilter("ignore", RuntimeWarning)
+ an = norm(at, -1)
+ assert_(issubclass(an.dtype.type, np.floating))
+ assert_almost_equal(an, 1.0)
+
+ an = norm(at, 1)
+ assert_(issubclass(an.dtype.type, np.floating))
+ assert_almost_equal(an, 2.0)
+
+ an = norm(at, 2)
+ assert_(issubclass(an.dtype.type, np.floating))
+ assert_almost_equal(an, 3.0**(1.0/2.0))
+
+ an = norm(at, -2)
+ assert_(issubclass(an.dtype.type, np.floating))
+ assert_almost_equal(an, 1.0)
+
+ an = norm(at, np.inf)
+ assert_(issubclass(an.dtype.type, np.floating))
+ assert_almost_equal(an, 2.0)
+
+ an = norm(at, 'fro')
+ assert_(issubclass(an.dtype.type, np.floating))
+ assert_almost_equal(an, 2.0)
+
+ an = norm(at, 'nuc')
+ assert_(issubclass(an.dtype.type, np.floating))
+ # Lower bar needed to support low precision floats.
+ # They end up being off by 1 in the 7th place.
+ old_assert_almost_equal(an, 2.7320508075688772, decimal=6)
+
def test_vector(self):
a = [1, 2, 3, 4]
b = [-1, -2, -3, -4]