diff options
Diffstat (limited to 'numpy/array_api')
-rw-r--r-- | numpy/array_api/__init__.py | 4 | ||||
-rw-r--r-- | numpy/array_api/linalg.py | 13 | ||||
-rw-r--r-- | numpy/array_api/tests/test_array_object.py | 2 |
3 files changed, 16 insertions, 3 deletions
diff --git a/numpy/array_api/__init__.py b/numpy/array_api/__init__.py index bbe2fdce2..5e58ee0a8 100644 --- a/numpy/array_api/__init__.py +++ b/numpy/array_api/__init__.py @@ -121,7 +121,9 @@ warnings.warn( "The numpy.array_api submodule is still experimental. See NEP 47.", stacklevel=2 ) -__all__ = [] +__array_api_version__ = "2021.12" + +__all__ = ["__array_api_version__"] from ._constants import e, inf, nan, pi diff --git a/numpy/array_api/linalg.py b/numpy/array_api/linalg.py index a4a2f23e4..d214046ef 100644 --- a/numpy/array_api/linalg.py +++ b/numpy/array_api/linalg.py @@ -379,7 +379,18 @@ def trace(x: Array, /, *, offset: int = 0) -> Array: def vecdot(x1: Array, x2: Array, /, *, axis: int = -1) -> Array: if x1.dtype not in _numeric_dtypes or x2.dtype not in _numeric_dtypes: raise TypeError('Only numeric dtypes are allowed in vecdot') - return tensordot(x1, x2, axes=((axis,), (axis,))) + ndim = max(x1.ndim, x2.ndim) + x1_shape = (1,)*(ndim - x1.ndim) + tuple(x1.shape) + x2_shape = (1,)*(ndim - x2.ndim) + tuple(x2.shape) + if x1_shape[axis] != x2_shape[axis]: + raise ValueError("x1 and x2 must have the same size along the given axis") + + x1_, x2_ = np.broadcast_arrays(x1._array, x2._array) + x1_ = np.moveaxis(x1_, axis, -1) + x2_ = np.moveaxis(x2_, axis, -1) + + res = x1_[..., None, :] @ x2_[..., None] + return Array._new(res[..., 0, 0]) # Note: the name here is different from norm(). The array API norm is split diff --git a/numpy/array_api/tests/test_array_object.py b/numpy/array_api/tests/test_array_object.py index ba9223532..f6efacefa 100644 --- a/numpy/array_api/tests/test_array_object.py +++ b/numpy/array_api/tests/test_array_object.py @@ -364,7 +364,7 @@ def test_array_keys_use_private_array(): in __getitem__(). This is achieved by passing array_api arrays with 0-sized dimensions, which NumPy-proper treats erroneously - not sure why! - TODO: Find and use appropiate __setitem__() case. + TODO: Find and use appropriate __setitem__() case. """ a = ones((0, 0), dtype=bool_) assert a[a].shape == (0,) |