From 70026c4dde47d89d6a7a4916bfac045e714a5b4f Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Wed, 15 Jun 2022 18:44:15 -0600 Subject: MAINT: Fix computation of numpy.array_api.linalg.vector_norm (#21084) * Fix computation of numpy.array_api.linalg.vector_norm Various pieces were incorrect due to a lack of complete coverage of this function in the array API test suite. * Fix the output dtype nonstandard vector norm() Previously it would always give float64 because an internal calculation involved a NumPy scalar and a Python float. The fix is to use a 0-D array instead of a NumPy scalar so that it type promotes with the float correctly. Fixes #21083 I don't have a test for this yet because I'm unclear how exactly to test it. * Clean up the numpy.array_api.linalg.vector_norm code a little bit --- numpy/array_api/linalg.py | 39 +++++++++++++++++++++++++++++++-------- 1 file changed, 31 insertions(+), 8 deletions(-) (limited to 'numpy/array_api') diff --git a/numpy/array_api/linalg.py b/numpy/array_api/linalg.py index f422e1c27..a4a2f23e4 100644 --- a/numpy/array_api/linalg.py +++ b/numpy/array_api/linalg.py @@ -1,8 +1,11 @@ from __future__ import annotations from ._dtypes import _floating_dtypes, _numeric_dtypes +from ._manipulation_functions import reshape from ._array_object import Array +from ..core.numeric import normalize_axis_tuple + from typing import TYPE_CHECKING if TYPE_CHECKING: from ._typing import Literal, Optional, Sequence, Tuple, Union @@ -395,18 +398,38 @@ def vector_norm(x: Array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = No if x.dtype not in _floating_dtypes: raise TypeError('Only floating-point dtypes are allowed in norm') + # np.linalg.norm tries to do a matrix norm whenever axis is a 2-tuple or + # when axis=None and the input is 2-D, so to force a vector norm, we make + # it so the input is 1-D (for axis=None), or reshape so that norm is done + # on a single dimension. a = x._array if axis is None: - a = a.flatten() - axis = 0 + # Note: np.linalg.norm() doesn't handle 0-D arrays + a = a.ravel() + _axis = 0 elif isinstance(axis, tuple): - # Note: The axis argument supports any number of axes, whereas norm() - # only supports a single axis for vector norm. - rest = tuple(i for i in range(a.ndim) if i not in axis) + # Note: The axis argument supports any number of axes, whereas + # np.linalg.norm() only supports a single axis for vector norm. + normalized_axis = normalize_axis_tuple(axis, x.ndim) + rest = tuple(i for i in range(a.ndim) if i not in normalized_axis) newshape = axis + rest - a = np.transpose(a, newshape).reshape((np.prod([a.shape[i] for i in axis]), *[a.shape[i] for i in rest])) - axis = 0 - return Array._new(np.linalg.norm(a, axis=axis, keepdims=keepdims, ord=ord)) + a = np.transpose(a, newshape).reshape( + (np.prod([a.shape[i] for i in axis], dtype=int), *[a.shape[i] for i in rest])) + _axis = 0 + else: + _axis = axis + + res = Array._new(np.linalg.norm(a, axis=_axis, ord=ord)) + + if keepdims: + # We can't reuse np.linalg.norm(keepdims) because of the reshape hacks + # above to avoid matrix norm logic. + shape = list(x.shape) + _axis = normalize_axis_tuple(range(x.ndim) if axis is None else axis, x.ndim) + for i in _axis: + shape[i] = 1 + res = reshape(res, tuple(shape)) + return res __all__ = ['cholesky', 'cross', 'det', 'diagonal', 'eigh', 'eigvalsh', 'inv', 'matmul', 'matrix_norm', 'matrix_power', 'matrix_rank', 'matrix_transpose', 'outer', 'pinv', 'qr', 'slogdet', 'solve', 'svd', 'svdvals', 'tensordot', 'trace', 'vecdot', 'vector_norm'] -- cgit v1.2.1