summaryrefslogtreecommitdiff
path: root/numpy/array_api/linalg.py
diff options
context:
space:
mode:
Diffstat (limited to 'numpy/array_api/linalg.py')
-rw-r--r--numpy/array_api/linalg.py39
1 files changed, 31 insertions, 8 deletions
diff --git a/numpy/array_api/linalg.py b/numpy/array_api/linalg.py
index f422e1c27..a4a2f23e4 100644
--- a/numpy/array_api/linalg.py
+++ b/numpy/array_api/linalg.py
@@ -1,8 +1,11 @@
from __future__ import annotations
from ._dtypes import _floating_dtypes, _numeric_dtypes
+from ._manipulation_functions import reshape
from ._array_object import Array
+from ..core.numeric import normalize_axis_tuple
+
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from ._typing import Literal, Optional, Sequence, Tuple, Union
@@ -395,18 +398,38 @@ def vector_norm(x: Array, /, *, axis: Optional[Union[int, Tuple[int, ...]]] = No
if x.dtype not in _floating_dtypes:
raise TypeError('Only floating-point dtypes are allowed in norm')
+ # np.linalg.norm tries to do a matrix norm whenever axis is a 2-tuple or
+ # when axis=None and the input is 2-D, so to force a vector norm, we make
+ # it so the input is 1-D (for axis=None), or reshape so that norm is done
+ # on a single dimension.
a = x._array
if axis is None:
- a = a.flatten()
- axis = 0
+ # Note: np.linalg.norm() doesn't handle 0-D arrays
+ a = a.ravel()
+ _axis = 0
elif isinstance(axis, tuple):
- # Note: The axis argument supports any number of axes, whereas norm()
- # only supports a single axis for vector norm.
- rest = tuple(i for i in range(a.ndim) if i not in axis)
+ # Note: The axis argument supports any number of axes, whereas
+ # np.linalg.norm() only supports a single axis for vector norm.
+ normalized_axis = normalize_axis_tuple(axis, x.ndim)
+ rest = tuple(i for i in range(a.ndim) if i not in normalized_axis)
newshape = axis + rest
- a = np.transpose(a, newshape).reshape((np.prod([a.shape[i] for i in axis]), *[a.shape[i] for i in rest]))
- axis = 0
- return Array._new(np.linalg.norm(a, axis=axis, keepdims=keepdims, ord=ord))
+ a = np.transpose(a, newshape).reshape(
+ (np.prod([a.shape[i] for i in axis], dtype=int), *[a.shape[i] for i in rest]))
+ _axis = 0
+ else:
+ _axis = axis
+
+ res = Array._new(np.linalg.norm(a, axis=_axis, ord=ord))
+
+ if keepdims:
+ # We can't reuse np.linalg.norm(keepdims) because of the reshape hacks
+ # above to avoid matrix norm logic.
+ shape = list(x.shape)
+ _axis = normalize_axis_tuple(range(x.ndim) if axis is None else axis, x.ndim)
+ for i in _axis:
+ shape[i] = 1
+ res = reshape(res, tuple(shape))
+ return res
__all__ = ['cholesky', 'cross', 'det', 'diagonal', 'eigh', 'eigvalsh', 'inv', 'matmul', 'matrix_norm', 'matrix_power', 'matrix_rank', 'matrix_transpose', 'outer', 'pinv', 'qr', 'slogdet', 'solve', 'svd', 'svdvals', 'tensordot', 'trace', 'vecdot', 'vector_norm']