summaryrefslogtreecommitdiff
path: root/numpy/array_api
diff options
context:
space:
mode:
Diffstat (limited to 'numpy/array_api')
-rw-r--r--numpy/array_api/__init__.py12
-rw-r--r--numpy/array_api/_array_object.py2
-rw-r--r--numpy/array_api/_linear_algebra_functions.py67
-rw-r--r--numpy/array_api/_statistical_functions.py9
-rw-r--r--numpy/array_api/linalg.py408
5 files changed, 419 insertions, 79 deletions
diff --git a/numpy/array_api/__init__.py b/numpy/array_api/__init__.py
index 36e3f3ed5..bbe2fdce2 100644
--- a/numpy/array_api/__init__.py
+++ b/numpy/array_api/__init__.py
@@ -109,9 +109,6 @@ Still TODO in this module are:
- The spec is still in an RFC phase and may still have minor updates, which
will need to be reflected here.
-- The linear algebra extension in the spec will be added in a future pull
- request.
-
- Complex number support in array API spec is planned but not yet finalized,
as are the fft extension and certain linear algebra functions such as eig
that require complex dtypes.
@@ -334,12 +331,13 @@ __all__ += [
"trunc",
]
-# einsum is not yet implemented in the array API spec.
+# linalg is an extension in the array API spec, which is a sub-namespace. Only
+# a subset of functions in it are imported into the top-level namespace.
+from . import linalg
-# from ._linear_algebra_functions import einsum
-# __all__ += ['einsum']
+__all__ += ["linalg"]
-from ._linear_algebra_functions import matmul, tensordot, matrix_transpose, vecdot
+from .linalg import matmul, tensordot, matrix_transpose, vecdot
__all__ += ["matmul", "tensordot", "matrix_transpose", "vecdot"]
diff --git a/numpy/array_api/_array_object.py b/numpy/array_api/_array_object.py
index dc74bb8c5..8794c5ea5 100644
--- a/numpy/array_api/_array_object.py
+++ b/numpy/array_api/_array_object.py
@@ -1030,7 +1030,7 @@ class Array:
# Note: mT is new in array API spec (see matrix_transpose)
@property
def mT(self) -> Array:
- from ._linear_algebra_functions import matrix_transpose
+ from .linalg import matrix_transpose
return matrix_transpose(self)
@property
diff --git a/numpy/array_api/_linear_algebra_functions.py b/numpy/array_api/_linear_algebra_functions.py
deleted file mode 100644
index 7a6c9846c..000000000
--- a/numpy/array_api/_linear_algebra_functions.py
+++ /dev/null
@@ -1,67 +0,0 @@
-from __future__ import annotations
-
-from ._array_object import Array
-from ._dtypes import _numeric_dtypes, _result_type
-
-from typing import Optional, Sequence, Tuple, Union
-
-import numpy as np
-
-# einsum is not yet implemented in the array API spec.
-
-# def einsum():
-# """
-# Array API compatible wrapper for :py:func:`np.einsum <numpy.einsum>`.
-#
-# See its docstring for more information.
-# """
-# return np.einsum()
-
-
-def matmul(x1: Array, x2: Array, /) -> Array:
- """
- Array API compatible wrapper for :py:func:`np.matmul <numpy.matmul>`.
-
- See its docstring for more information.
- """
- # Note: the restriction to numeric dtypes only is different from
- # np.matmul.
- if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes:
- raise TypeError("Only numeric dtypes are allowed in matmul")
- # Call result type here just to raise on disallowed type combinations
- _result_type(x1.dtype, x2.dtype)
-
- return Array._new(np.matmul(x1._array, x2._array))
-
-
-# Note: axes must be a tuple, unlike np.tensordot where it can be an array or array-like.
-def tensordot(
- x1: Array,
- x2: Array,
- /,
- *,
- axes: Union[int, Tuple[Sequence[int], Sequence[int]]] = 2,
-) -> Array:
- # Note: the restriction to numeric dtypes only is different from
- # np.tensordot.
- if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes:
- raise TypeError("Only numeric dtypes are allowed in tensordot")
- # Call result type here just to raise on disallowed type combinations
- _result_type(x1.dtype, x2.dtype)
-
- return Array._new(np.tensordot(x1._array, x2._array, axes=axes))
-
-
-# Note: this function is new in the array API spec. Unlike transpose, it only
-# transposes the last two axes.
-def matrix_transpose(x: Array, /) -> Array:
- if x.ndim < 2:
- raise ValueError("x must be at least 2-dimensional for matrix_transpose")
- return Array._new(np.swapaxes(x._array, -1, -2))
-
-
-# Note: vecdot is not in NumPy
-def vecdot(x1: Array, x2: Array, /, *, axis: Optional[int] = None) -> Array:
- if axis is None:
- axis = -1
- return tensordot(x1, x2, axes=((axis,), (axis,)))
diff --git a/numpy/array_api/_statistical_functions.py b/numpy/array_api/_statistical_functions.py
index c5abf9468..7bee3f4db 100644
--- a/numpy/array_api/_statistical_functions.py
+++ b/numpy/array_api/_statistical_functions.py
@@ -93,11 +93,12 @@ def sum(
) -> Array:
if x.dtype not in _numeric_dtypes:
raise TypeError("Only numeric dtypes are allowed in sum")
- # Note: sum() and prod() always upcast float32 to float64 for dtype=None
- # We need to do so here before summing to avoid overflow
+ # Note: sum() and prod() always upcast integers to (u)int64 and float32 to
+ # float64 for dtype=None. `np.sum` does that too for integers, but not for
+ # float32, so we need to special-case it here
if dtype is None and x.dtype == float32:
- x = asarray(x, dtype=float64)
- return Array._new(np.sum(x._array, axis=axis, keepdims=keepdims))
+ dtype = float64
+ return Array._new(np.sum(x._array, axis=axis, dtype=dtype, keepdims=keepdims))
def var(
diff --git a/numpy/array_api/linalg.py b/numpy/array_api/linalg.py
new file mode 100644
index 000000000..8d7ba659e
--- /dev/null
+++ b/numpy/array_api/linalg.py
@@ -0,0 +1,408 @@
+from __future__ import annotations
+
+from ._dtypes import _floating_dtypes, _numeric_dtypes
+from ._array_object import Array
+
+from typing import TYPE_CHECKING
+if TYPE_CHECKING:
+ from ._typing import Literal, Optional, Sequence, Tuple, Union
+
+from typing import NamedTuple
+
+import numpy.linalg
+import numpy as np
+
+class EighResult(NamedTuple):
+ eigenvalues: Array
+ eigenvectors: Array
+
+class QRResult(NamedTuple):
+ Q: Array
+ R: Array
+
+class SlogdetResult(NamedTuple):
+ sign: Array
+ logabsdet: Array
+
+class SVDResult(NamedTuple):
+ U: Array
+ S: Array
+ Vh: Array
+
+# Note: the inclusion of the upper keyword is different from
+# np.linalg.cholesky, which does not have it.
+def cholesky(x: Array, /, *, upper: bool = False) -> Array:
+ """
+ Array API compatible wrapper for :py:func:`np.linalg.cholesky <numpy.linalg.cholesky>`.
+
+ See its docstring for more information.
+ """
+ # Note: the restriction to floating-point dtypes only is different from
+ # np.linalg.cholesky.
+ if x.dtype not in _floating_dtypes:
+ raise TypeError('Only floating-point dtypes are allowed in cholesky')
+ L = np.linalg.cholesky(x._array)
+ if upper:
+ return Array._new(L).mT
+ return Array._new(L)
+
+# Note: cross is the numpy top-level namespace, not np.linalg
+def cross(x1: Array, x2: Array, /, *, axis: int = -1) -> Array:
+ """
+ Array API compatible wrapper for :py:func:`np.cross <numpy.cross>`.
+
+ See its docstring for more information.
+ """
+ if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes:
+ raise TypeError('Only numeric dtypes are allowed in cross')
+ # Note: this is different from np.cross(), which broadcasts
+ if x1.shape != x2.shape:
+ raise ValueError('x1 and x2 must have the same shape')
+ if x1.ndim == 0:
+ raise ValueError('cross() requires arrays of dimension at least 1')
+ # Note: this is different from np.cross(), which allows dimension 2
+ if x1.shape[axis] != 3:
+ raise ValueError('cross() dimension must equal 3')
+ return Array._new(np.cross(x1._array, x2._array, axis=axis))
+
+def det(x: Array, /) -> Array:
+ """
+ Array API compatible wrapper for :py:func:`np.linalg.det <numpy.linalg.det>`.
+
+ See its docstring for more information.
+ """
+ # Note: the restriction to floating-point dtypes only is different from
+ # np.linalg.det.
+ if x.dtype not in _floating_dtypes:
+ raise TypeError('Only floating-point dtypes are allowed in det')
+ return Array._new(np.linalg.det(x._array))
+
+# Note: diagonal is the numpy top-level namespace, not np.linalg
+def diagonal(x: Array, /, *, offset: int = 0) -> Array:
+ """
+ Array API compatible wrapper for :py:func:`np.diagonal <numpy.diagonal>`.
+
+ See its docstring for more information.
+ """
+ # Note: diagonal always operates on the last two axes, whereas np.diagonal
+ # operates on the first two axes by default
+ return Array._new(np.diagonal(x._array, offset=offset, axis1=-2, axis2=-1))
+
+
+# Note: the keyword argument name upper is different from np.linalg.eigh
+def eigh(x: Array, /) -> EighResult:
+ """
+ Array API compatible wrapper for :py:func:`np.linalg.eigh <numpy.linalg.eigh>`.
+
+ See its docstring for more information.
+ """
+ # Note: the restriction to floating-point dtypes only is different from
+ # np.linalg.eigh.
+ if x.dtype not in _floating_dtypes:
+ raise TypeError('Only floating-point dtypes are allowed in eigh')
+
+ # Note: the return type here is a namedtuple, which is different from
+ # np.eigh, which only returns a tuple.
+ return EighResult(*map(Array._new, np.linalg.eigh(x._array)))
+
+
+# Note: the keyword argument name upper is different from np.linalg.eigvalsh
+def eigvalsh(x: Array, /) -> Array:
+ """
+ Array API compatible wrapper for :py:func:`np.linalg.eigvalsh <numpy.linalg.eigvalsh>`.
+
+ See its docstring for more information.
+ """
+ # Note: the restriction to floating-point dtypes only is different from
+ # np.linalg.eigvalsh.
+ if x.dtype not in _floating_dtypes:
+ raise TypeError('Only floating-point dtypes are allowed in eigvalsh')
+
+ return Array._new(np.linalg.eigvalsh(x._array))
+
+def inv(x: Array, /) -> Array:
+ """
+ Array API compatible wrapper for :py:func:`np.linalg.inv <numpy.linalg.inv>`.
+
+ See its docstring for more information.
+ """
+ # Note: the restriction to floating-point dtypes only is different from
+ # np.linalg.inv.
+ if x.dtype not in _floating_dtypes:
+ raise TypeError('Only floating-point dtypes are allowed in inv')
+
+ return Array._new(np.linalg.inv(x._array))
+
+
+# Note: matmul is the numpy top-level namespace but not in np.linalg
+def matmul(x1: Array, x2: Array, /) -> Array:
+ """
+ Array API compatible wrapper for :py:func:`np.matmul <numpy.matmul>`.
+
+ See its docstring for more information.
+ """
+ # Note: the restriction to numeric dtypes only is different from
+ # np.matmul.
+ if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes:
+ raise TypeError('Only numeric dtypes are allowed in matmul')
+
+ return Array._new(np.matmul(x1._array, x2._array))
+
+
+# Note: the name here is different from norm(). The array API norm is split
+# into matrix_norm and vector_norm().
+
+# The type for ord should be Optional[Union[int, float, Literal[np.inf,
+# -np.inf, 'fro', 'nuc']]], but Literal does not support floating-point
+# literals.
+def matrix_norm(x: Array, /, *, keepdims: bool = False, ord: Optional[Union[int, float, Literal['fro', 'nuc']]] = 'fro') -> Array:
+ """
+ Array API compatible wrapper for :py:func:`np.linalg.norm <numpy.linalg.norm>`.
+
+ See its docstring for more information.
+ """
+ # Note: the restriction to floating-point dtypes only is different from
+ # np.linalg.norm.
+ if x.dtype not in _floating_dtypes:
+ raise TypeError('Only floating-point dtypes are allowed in matrix_norm')
+
+ return Array._new(np.linalg.norm(x._array, axis=(-2, -1), keepdims=keepdims, ord=ord))
+
+
+def matrix_power(x: Array, n: int, /) -> Array:
+ """
+ Array API compatible wrapper for :py:func:`np.matrix_power <numpy.matrix_power>`.
+
+ See its docstring for more information.
+ """
+ # Note: the restriction to floating-point dtypes only is different from
+ # np.linalg.matrix_power.
+ if x.dtype not in _floating_dtypes:
+ raise TypeError('Only floating-point dtypes are allowed for the first argument of matrix_power')
+
+ # np.matrix_power already checks if n is an integer
+ return Array._new(np.linalg.matrix_power(x._array, n))
+
+# Note: the keyword argument name rtol is different from np.linalg.matrix_rank
+def matrix_rank(x: Array, /, *, rtol: Optional[Union[float, Array]] = None) -> Array:
+ """
+ Array API compatible wrapper for :py:func:`np.matrix_rank <numpy.matrix_rank>`.
+
+ See its docstring for more information.
+ """
+ # Note: this is different from np.linalg.matrix_rank, which supports 1
+ # dimensional arrays.
+ if x.ndim < 2:
+ raise np.linalg.LinAlgError("1-dimensional array given. Array must be at least two-dimensional")
+ S = np.linalg.svd(x._array, compute_uv=False)
+ if rtol is None:
+ tol = S.max(axis=-1, keepdims=True) * max(x.shape[-2:]) * np.finfo(S.dtype).eps
+ else:
+ if isinstance(rtol, Array):
+ rtol = rtol._array
+ # Note: this is different from np.linalg.matrix_rank, which does not multiply
+ # the tolerance by the largest singular value.
+ tol = S.max(axis=-1, keepdims=True)*np.asarray(rtol)[..., np.newaxis]
+ return Array._new(np.count_nonzero(S > tol, axis=-1))
+
+
+# Note: this function is new in the array API spec. Unlike transpose, it only
+# transposes the last two axes.
+def matrix_transpose(x: Array, /) -> Array:
+ if x.ndim < 2:
+ raise ValueError("x must be at least 2-dimensional for matrix_transpose")
+ return Array._new(np.swapaxes(x._array, -1, -2))
+
+# Note: outer is the numpy top-level namespace, not np.linalg
+def outer(x1: Array, x2: Array, /) -> Array:
+ """
+ Array API compatible wrapper for :py:func:`np.outer <numpy.outer>`.
+
+ See its docstring for more information.
+ """
+ # Note: the restriction to numeric dtypes only is different from
+ # np.outer.
+ if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes:
+ raise TypeError('Only numeric dtypes are allowed in outer')
+
+ # Note: the restriction to only 1-dim arrays is different from np.outer
+ if x1.ndim != 1 or x2.ndim != 1:
+ raise ValueError('The input arrays to outer must be 1-dimensional')
+
+ return Array._new(np.outer(x1._array, x2._array))
+
+# Note: the keyword argument name rtol is different from np.linalg.pinv
+def pinv(x: Array, /, *, rtol: Optional[Union[float, Array]] = None) -> Array:
+ """
+ Array API compatible wrapper for :py:func:`np.linalg.pinv <numpy.linalg.pinv>`.
+
+ See its docstring for more information.
+ """
+ # Note: the restriction to floating-point dtypes only is different from
+ # np.linalg.pinv.
+ if x.dtype not in _floating_dtypes:
+ raise TypeError('Only floating-point dtypes are allowed in pinv')
+
+ # Note: this is different from np.linalg.pinv, which does not multiply the
+ # default tolerance by max(M, N).
+ if rtol is None:
+ rtol = max(x.shape[-2:]) * np.finfo(x.dtype).eps
+ return Array._new(np.linalg.pinv(x._array, rcond=rtol))
+
+def qr(x: Array, /, *, mode: Literal['reduced', 'complete'] = 'reduced') -> QRResult:
+ """
+ Array API compatible wrapper for :py:func:`np.linalg.qr <numpy.linalg.qr>`.
+
+ See its docstring for more information.
+ """
+ # Note: the restriction to floating-point dtypes only is different from
+ # np.linalg.qr.
+ if x.dtype not in _floating_dtypes:
+ raise TypeError('Only floating-point dtypes are allowed in qr')
+
+ # Note: the return type here is a namedtuple, which is different from
+ # np.linalg.qr, which only returns a tuple.
+ return QRResult(*map(Array._new, np.linalg.qr(x._array, mode=mode)))
+
+def slogdet(x: Array, /) -> SlogdetResult:
+ """
+ Array API compatible wrapper for :py:func:`np.linalg.slogdet <numpy.linalg.slogdet>`.
+
+ See its docstring for more information.
+ """
+ # Note: the restriction to floating-point dtypes only is different from
+ # np.linalg.slogdet.
+ if x.dtype not in _floating_dtypes:
+ raise TypeError('Only floating-point dtypes are allowed in slogdet')
+
+ # Note: the return type here is a namedtuple, which is different from
+ # np.linalg.slogdet, which only returns a tuple.
+ return SlogdetResult(*map(Array._new, np.linalg.slogdet(x._array)))
+
+# Note: unlike np.linalg.solve, the array API solve() only accepts x2 as a
+# vector when it is exactly 1-dimensional. All other cases treat x2 as a stack
+# of matrices. The np.linalg.solve behavior of allowing stacks of both
+# matrices and vectors is ambiguous c.f.
+# https://github.com/numpy/numpy/issues/15349 and
+# https://github.com/data-apis/array-api/issues/285.
+
+# To workaround this, the below is the code from np.linalg.solve except
+# only calling solve1 in the exactly 1D case.
+def _solve(a, b):
+ from ..linalg.linalg import (_makearray, _assert_stacked_2d,
+ _assert_stacked_square, _commonType,
+ isComplexType, get_linalg_error_extobj,
+ _raise_linalgerror_singular)
+ from ..linalg import _umath_linalg
+
+ a, _ = _makearray(a)
+ _assert_stacked_2d(a)
+ _assert_stacked_square(a)
+ b, wrap = _makearray(b)
+ t, result_t = _commonType(a, b)
+
+ # This part is different from np.linalg.solve
+ if b.ndim == 1:
+ gufunc = _umath_linalg.solve1
+ else:
+ gufunc = _umath_linalg.solve
+
+ # This does nothing currently but is left in because it will be relevant
+ # when complex dtype support is added to the spec in 2022.
+ signature = 'DD->D' if isComplexType(t) else 'dd->d'
+ extobj = get_linalg_error_extobj(_raise_linalgerror_singular)
+ r = gufunc(a, b, signature=signature, extobj=extobj)
+
+ return wrap(r.astype(result_t, copy=False))
+
+def solve(x1: Array, x2: Array, /) -> Array:
+ """
+ Array API compatible wrapper for :py:func:`np.linalg.solve <numpy.linalg.solve>`.
+
+ See its docstring for more information.
+ """
+ # Note: the restriction to floating-point dtypes only is different from
+ # np.linalg.solve.
+ if x1.dtype not in _floating_dtypes or x2.dtype not in _floating_dtypes:
+ raise TypeError('Only floating-point dtypes are allowed in solve')
+
+ return Array._new(_solve(x1._array, x2._array))
+
+def svd(x: Array, /, *, full_matrices: bool = True) -> SVDResult:
+ """
+ Array API compatible wrapper for :py:func:`np.linalg.svd <numpy.linalg.svd>`.
+
+ See its docstring for more information.
+ """
+ # Note: the restriction to floating-point dtypes only is different from
+ # np.linalg.svd.
+ if x.dtype not in _floating_dtypes:
+ raise TypeError('Only floating-point dtypes are allowed in svd')
+
+ # Note: the return type here is a namedtuple, which is different from
+ # np.svd, which only returns a tuple.
+ return SVDResult(*map(Array._new, np.linalg.svd(x._array, full_matrices=full_matrices)))
+
+# Note: svdvals is not in NumPy (but it is in SciPy). It is equivalent to
+# np.linalg.svd(compute_uv=False).
+def svdvals(x: Array, /) -> Union[Array, Tuple[Array, ...]]:
+ return Array._new(np.linalg.svd(x._array, compute_uv=False))
+
+# Note: tensordot is the numpy top-level namespace but not in np.linalg
+
+# Note: axes must be a tuple, unlike np.tensordot where it can be an array or array-like.
+def tensordot(x1: Array, x2: Array, /, *, axes: Union[int, Tuple[Sequence[int], Sequence[int]]] = 2) -> Array:
+ # Note: the restriction to numeric dtypes only is different from
+ # np.tensordot.
+ if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes:
+ raise TypeError('Only numeric dtypes are allowed in tensordot')
+
+ return Array._new(np.tensordot(x1._array, x2._array, axes=axes))
+
+# Note: trace is the numpy top-level namespace, not np.linalg
+def trace(x: Array, /, *, offset: int = 0) -> Array:
+ """
+ Array API compatible wrapper for :py:func:`np.trace <numpy.trace>`.
+
+ See its docstring for more information.
+ """
+ # Note: trace always operates on the last two axes, whereas np.trace
+ # operates on the first two axes by default
+ return Array._new(np.asarray(np.trace(x._array, offset=offset, axis1=-2, axis2=-1)))
+
+# Note: vecdot is not in NumPy
+def vecdot(x1: Array, x2: Array, /, *, axis: int = -1) -> Array:
+ return tensordot(x1, x2, axes=((axis,), (axis,)))
+
+
+# Note: the name here is different from norm(). The array API norm is split
+# into matrix_norm and vector_norm().
+
+# The type for ord should be Optional[Union[int, float, Literal[np.inf,
+# -np.inf]]] but Literal does not support floating-point literals.
+def vector_norm(x: Array, /, *, axis: Optional[Union[int, Tuple[int, int]]] = None, keepdims: bool = False, ord: Optional[Union[int, float]] = 2) -> Array:
+ """
+ Array API compatible wrapper for :py:func:`np.linalg.norm <numpy.linalg.norm>`.
+
+ See its docstring for more information.
+ """
+ # Note: the restriction to floating-point dtypes only is different from
+ # np.linalg.norm.
+ if x.dtype not in _floating_dtypes:
+ raise TypeError('Only floating-point dtypes are allowed in norm')
+
+ a = x._array
+ if axis is None:
+ a = a.flatten()
+ axis = 0
+ elif isinstance(axis, tuple):
+ # Note: The axis argument supports any number of axes, whereas norm()
+ # only supports a single axis for vector norm.
+ rest = tuple(i for i in range(a.ndim) if i not in axis)
+ newshape = axis + rest
+ a = np.transpose(a, newshape).reshape((np.prod([a.shape[i] for i in axis]), *[a.shape[i] for i in rest]))
+ axis = 0
+ return Array._new(np.linalg.norm(a, axis=axis, keepdims=keepdims, ord=ord))
+
+
+__all__ = ['cholesky', 'cross', 'det', 'diagonal', 'eigh', 'eigvalsh', 'inv', 'matmul', 'matrix_norm', 'matrix_power', 'matrix_rank', 'matrix_transpose', 'outer', 'pinv', 'qr', 'slogdet', 'solve', 'svd', 'svdvals', 'tensordot', 'trace', 'vecdot', 'vector_norm']