summaryrefslogtreecommitdiff
path: root/numpy
diff options
context:
space:
mode:
authorLars Buitinck <larsmans@gmail.com>2013-12-26 14:12:53 +0100
committerLars Buitinck <larsmans@gmail.com>2013-12-26 14:12:53 +0100
commitdff072619643c6d668446b229d44b7e4d45e6ec0 (patch)
tree81150b775318a245c5c85aa4b41e615eb39c70ae /numpy
parent61998c22df08e7ec0938bedef931ac824cfb634a (diff)
downloadnumpy-dff072619643c6d668446b229d44b7e4d45e6ec0.tar.gz
BUG: linalg: norm fails on longdouble, signed int
This fixes the following bug with longdouble: >>> x = np.arange(10, dtype=np.longdouble) >>> np.linalg.norm(x, ord=3) Traceback (most recent call last): File "<ipython-input-5-7ee53a8ac142>", line 1, in <module> np.linalg.norm(x, ord=3) File "/tmp/v/lib/python2.7/site-packages/numpy/linalg/linalg.py", line 2090, in norm return add.reduce(absx**ord, axis=axis)**(1.0/ord) UnboundLocalError: local variable 'absx' referenced before assignment As well as the handling of minimal values for signed integers: >>> x = np.array([-2**31], dtype=np.int32) >>> np.linalg.norm(x, ord=3) /tmp/v/lib/python2.7/site-packages/numpy/linalg/linalg.py:2090: RuntimeWarning: invalid value encountered in double_scalars return add.reduce(absx**ord, axis=axis)**(1.0/ord) nan
Diffstat (limited to 'numpy')
-rw-r--r--numpy/linalg/linalg.py16
-rw-r--r--numpy/linalg/tests/test_linalg.py12
2 files changed, 24 insertions, 4 deletions
diff --git a/numpy/linalg/linalg.py b/numpy/linalg/linalg.py
index 1b82c7cc0..dd586827e 100644
--- a/numpy/linalg/linalg.py
+++ b/numpy/linalg/linalg.py
@@ -22,7 +22,7 @@ from numpy.core import (
array, asarray, zeros, empty, empty_like, transpose, intc, single, double,
csingle, cdouble, inexact, complexfloating, newaxis, ravel, all, Inf, dot,
add, multiply, sqrt, maximum, fastCopyAndTranspose, sum, isfinite, size,
- finfo, errstate, geterrobj, longdouble, rollaxis, amin, amax, product,
+ finfo, errstate, geterrobj, longdouble, rollaxis, amin, amax, product, abs,
broadcast
)
from numpy.lib import triu, asfarray
@@ -2082,12 +2082,20 @@ def norm(x, ord=None, axis=None):
ord + 1
except TypeError:
raise ValueError("Invalid norm order for vectors.")
- if x.dtype.type is not longdouble:
+ if x.dtype.type is longdouble:
# Convert to a float type, so integer arrays give
# float results. Don't apply asfarray to longdouble arrays,
# because it will downcast to float64.
- absx = asfarray(abs(x))
- return add.reduce(absx**ord, axis=axis)**(1.0/ord)
+ absx = abs(x)
+ else:
+ absx = asfarray(x)
+ if absx.dtype is x.dtype:
+ absx = abs(absx)
+ else:
+ # if the type changed, we can safely overwrite absx
+ abs(absx, out=absx)
+ absx **= ord
+ return add.reduce(absx, axis=axis) ** (1.0 / ord)
elif len(axis) == 2:
row_axis, col_axis = axis
if not (-nd <= row_axis < nd and -nd <= col_axis < nd):
diff --git a/numpy/linalg/tests/test_linalg.py b/numpy/linalg/tests/test_linalg.py
index dd4cbcc4f..feb4c8224 100644
--- a/numpy/linalg/tests/test_linalg.py
+++ b/numpy/linalg/tests/test_linalg.py
@@ -909,6 +909,18 @@ class _TestNorm(object):
assert_raises(ValueError, norm, B, None, (2, 3))
assert_raises(ValueError, norm, B, None, (0, 1, 2))
+ def test_longdouble_norm(self):
+ # Non-regression test: p-norm of longdouble would previously raise
+ # UnboundLocalError.
+ x = np.arange(10, dtype=np.longdouble)
+ old_assert_almost_equal(norm(x, ord=3), 12.65, decimal=2)
+
+ def test_intmin(self):
+ # Non-regression test: p-norm of signed integer would previously do
+ # float cast and abs in the wrong order.
+ x = np.array([-2 ** 31], dtype=np.int32)
+ old_assert_almost_equal(norm(x, ord=3), 2 ** 31, decimal=5)
+
class TestNormDouble(_TestNorm):
dt = np.double