diff options
-rw-r--r-- | numpy/__init__.pyi | 1 | ||||
-rw-r--r-- | numpy/_globals.py | 2 | ||||
-rw-r--r-- | numpy/array_api/_creation_functions.py | 6 | ||||
-rw-r--r-- | numpy/array_api/tests/test_creation_functions.py | 2 | ||||
-rw-r--r-- | numpy/core/multiarray.pyi | 7 |
5 files changed, 6 insertions, 12 deletions
diff --git a/numpy/__init__.pyi b/numpy/__init__.pyi index dafedeb56..30b0944fd 100644 --- a/numpy/__init__.pyi +++ b/numpy/__init__.pyi @@ -3689,7 +3689,6 @@ trunc: _UFunc_Nin1_Nout1[L['trunc'], L[7], None] abs = absolute class _CopyMode(enum.Enum): - ALWAYS: L[1] IF_NEEDED: L[0] NEVER: L[2] diff --git a/numpy/_globals.py b/numpy/_globals.py index 133ab11cc..d458fc9c4 100644 --- a/numpy/_globals.py +++ b/numpy/_globals.py @@ -108,7 +108,7 @@ class _CopyMode(enum.Enum): if self == _CopyMode.IF_NEEDED: return False - raise TypeError(f"{self} is neither True nor False.") + raise ValueError(f"{self} is neither True nor False.") _CopyMode.__module__ = 'numpy' diff --git a/numpy/array_api/_creation_functions.py b/numpy/array_api/_creation_functions.py index c15d54db1..2d6cf4414 100644 --- a/numpy/array_api/_creation_functions.py +++ b/numpy/array_api/_creation_functions.py @@ -43,7 +43,7 @@ def asarray( *, dtype: Optional[Dtype] = None, device: Optional[Device] = None, - copy: Optional[Union[bool | np._CopyMode]] = None, + copy: Optional[Union[bool, np._CopyMode]] = None, ) -> Array: """ Array API compatible wrapper for :py:func:`np.asarray <numpy.asarray>`. @@ -57,11 +57,11 @@ def asarray( _check_valid_dtype(dtype) if device not in ["cpu", None]: raise ValueError(f"Unsupported device {device!r}") - if copy is False: + if copy in (False, np._CopyMode.IF_NEEDED): # Note: copy=False is not yet implemented in np.asarray raise NotImplementedError("copy=False is not yet implemented") if isinstance(obj, Array) and (dtype is None or obj.dtype == dtype): - if copy is True: + if copy in (True, np._CopyMode.ALWAYS): return Array._new(np.array(obj._array, copy=True, dtype=dtype)) return obj if dtype is None and isinstance(obj, int) and (obj > 2 ** 64 or obj < -(2 ** 63)): diff --git a/numpy/array_api/tests/test_creation_functions.py b/numpy/array_api/tests/test_creation_functions.py index 7182209dc..2ee23a47b 100644 --- a/numpy/array_api/tests/test_creation_functions.py +++ b/numpy/array_api/tests/test_creation_functions.py @@ -67,7 +67,7 @@ def test_asarray_copy(): assert all(b[0] == 0) assert_raises(NotImplementedError, lambda: asarray(a, copy=False)) assert_raises(NotImplementedError, - lambda: asarray(a, copy=np._CopyMode.NEVER)) + lambda: asarray(a, copy=np._CopyMode.IF_NEEDED)) def test_arange_errors(): diff --git a/numpy/core/multiarray.pyi b/numpy/core/multiarray.pyi index 97e9c3498..501e55634 100644 --- a/numpy/core/multiarray.pyi +++ b/numpy/core/multiarray.pyi @@ -51,6 +51,7 @@ from numpy import ( _ModeKind, _SupportsBuffer, _IOProtocol, + _CopyMode ) from numpy.typing import ( @@ -1012,9 +1013,3 @@ class flagsobj: def owndata(self) -> bool: ... def __getitem__(self, key: _GetItemKeys) -> bool: ... def __setitem__(self, key: _SetItemKeys, value: bool) -> None: ... - -class _CopyMode(enum.Enum): - - ALWAYS: L[1] - IF_NEEDED: L[0] - NEVER: L[2]
\ No newline at end of file |