diff options
Diffstat (limited to 'numpy/array_api/linalg.py')
-rw-r--r-- | numpy/array_api/linalg.py | 39 |
1 files changed, 31 insertions, 8 deletions
diff --git a/numpy/array_api/linalg.py b/numpy/array_api/linalg.py index f422e1c27..a4a2f23e4 100644 --- a/numpy/array_api/linalg.py +++ b/numpy/array_api/linalg.py @@ -1,8 +1,11 @@ from __future__ import annotations from ._dtypes import _floating_dtypes, _numeric_dtypes +from ._manipulation_functions import reshape from ._array_object import Array +from ..core.numeric import normalize_axis_tuple + from typing import TYPE_CHECKING if TYPE_CHECKING: from ._typing import Literal, Optional, Sequence, Tuple, Union @@ -395,18 +398,38 @@ def vector_norm(x: Array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = No if x.dtype not in _floating_dtypes: raise TypeError('Only floating-point dtypes are allowed in norm') + # np.linalg.norm tries to do a matrix norm whenever axis is a 2-tuple or + # when axis=None and the input is 2-D, so to force a vector norm, we make + # it so the input is 1-D (for axis=None), or reshape so that norm is done + # on a single dimension. a = x._array if axis is None: - a = a.flatten() - axis = 0 + # Note: np.linalg.norm() doesn't handle 0-D arrays + a = a.ravel() + _axis = 0 elif isinstance(axis, tuple): - # Note: The axis argument supports any number of axes, whereas norm() - # only supports a single axis for vector norm. - rest = tuple(i for i in range(a.ndim) if i not in axis) + # Note: The axis argument supports any number of axes, whereas + # np.linalg.norm() only supports a single axis for vector norm. + normalized_axis = normalize_axis_tuple(axis, x.ndim) + rest = tuple(i for i in range(a.ndim) if i not in normalized_axis) newshape = axis + rest - a = np.transpose(a, newshape).reshape((np.prod([a.shape[i] for i in axis]), *[a.shape[i] for i in rest])) - axis = 0 - return Array._new(np.linalg.norm(a, axis=axis, keepdims=keepdims, ord=ord)) + a = np.transpose(a, newshape).reshape( + (np.prod([a.shape[i] for i in axis], dtype=int), *[a.shape[i] for i in rest])) + _axis = 0 + else: + _axis = axis + + res = Array._new(np.linalg.norm(a, axis=_axis, ord=ord)) + + if keepdims: + # We can't reuse np.linalg.norm(keepdims) because of the reshape hacks + # above to avoid matrix norm logic. + shape = list(x.shape) + _axis = normalize_axis_tuple(range(x.ndim) if axis is None else axis, x.ndim) + for i in _axis: + shape[i] = 1 + res = reshape(res, tuple(shape)) + return res __all__ = ['cholesky', 'cross', 'det', 'diagonal', 'eigh', 'eigvalsh', 'inv', 'matmul', 'matrix_norm', 'matrix_power', 'matrix_rank', 'matrix_transpose', 'outer', 'pinv', 'qr', 'slogdet', 'solve', 'svd', 'svdvals', 'tensordot', 'trace', 'vecdot', 'vector_norm'] |