summaryrefslogtreecommitdiff
path: root/numpy
diff options
context:
space:
mode:
Diffstat (limited to 'numpy')
-rw-r--r--numpy/linalg/linalg.py14
1 files changed, 3 insertions, 11 deletions
diff --git a/numpy/linalg/linalg.py b/numpy/linalg/linalg.py
index ff9877549..f11f905f7 100644
--- a/numpy/linalg/linalg.py
+++ b/numpy/linalg/linalg.py
@@ -1892,15 +1892,7 @@ def _multi_svd_norm(x, row_axis, col_axis, op):
if row_axis > col_axis:
row_axis -= 1
y = rollaxis(rollaxis(x, col_axis, x.ndim), row_axis, -1)
- if x.ndim > 3:
- z = y.reshape((-1,) + y.shape[-2:])
- else:
- z = y
- if x.ndim == 2:
- result = op(svd(z, compute_uv=0))
- else:
- result = array([op(svd(m, compute_uv=0)) for m in z])
- result.shape = y.shape[:-2]
+ result = op(svd(y, compute_uv=0), axis=-1)
return result
@@ -2026,9 +2018,9 @@ def norm(x, ord=None, axis=None):
Using the `axis` argument to compute matrix norms:
>>> m = np.arange(8).reshape(2,2,2)
- >>> norm(m, axis=(1,2))
+ >>> LA.norm(m, axis=(1,2))
array([ 3.74165739, 11.22497216])
- >>> norm(m[0]), norm(m[1])
+ >>> LA.norm(m[0, :, :]), LA.norm(m[1, :, :])
(3.7416573867739413, 11.224972160321824)
"""