summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--numpy/__init__.pyi1
-rw-r--r--numpy/_globals.py2
-rw-r--r--numpy/array_api/_creation_functions.py6
-rw-r--r--numpy/array_api/tests/test_creation_functions.py2
-rw-r--r--numpy/core/multiarray.pyi7
5 files changed, 6 insertions, 12 deletions
diff --git a/numpy/__init__.pyi b/numpy/__init__.pyi
index dafedeb56..30b0944fd 100644
--- a/numpy/__init__.pyi
+++ b/numpy/__init__.pyi
@@ -3689,7 +3689,6 @@ trunc: _UFunc_Nin1_Nout1[L['trunc'], L[7], None]
abs = absolute
class _CopyMode(enum.Enum):
-
ALWAYS: L[1]
IF_NEEDED: L[0]
NEVER: L[2]
diff --git a/numpy/_globals.py b/numpy/_globals.py
index 133ab11cc..d458fc9c4 100644
--- a/numpy/_globals.py
+++ b/numpy/_globals.py
@@ -108,7 +108,7 @@ class _CopyMode(enum.Enum):
if self == _CopyMode.IF_NEEDED:
return False
- raise TypeError(f"{self} is neither True nor False.")
+ raise ValueError(f"{self} is neither True nor False.")
_CopyMode.__module__ = 'numpy'
diff --git a/numpy/array_api/_creation_functions.py b/numpy/array_api/_creation_functions.py
index c15d54db1..2d6cf4414 100644
--- a/numpy/array_api/_creation_functions.py
+++ b/numpy/array_api/_creation_functions.py
@@ -43,7 +43,7 @@ def asarray(
*,
dtype: Optional[Dtype] = None,
device: Optional[Device] = None,
- copy: Optional[Union[bool | np._CopyMode]] = None,
+ copy: Optional[Union[bool, np._CopyMode]] = None,
) -> Array:
"""
Array API compatible wrapper for :py:func:`np.asarray <numpy.asarray>`.
@@ -57,11 +57,11 @@ def asarray(
_check_valid_dtype(dtype)
if device not in ["cpu", None]:
raise ValueError(f"Unsupported device {device!r}")
- if copy is False:
+ if copy in (False, np._CopyMode.IF_NEEDED):
# Note: copy=False is not yet implemented in np.asarray
raise NotImplementedError("copy=False is not yet implemented")
if isinstance(obj, Array) and (dtype is None or obj.dtype == dtype):
- if copy is True:
+ if copy in (True, np._CopyMode.ALWAYS):
return Array._new(np.array(obj._array, copy=True, dtype=dtype))
return obj
if dtype is None and isinstance(obj, int) and (obj > 2 ** 64 or obj < -(2 ** 63)):
diff --git a/numpy/array_api/tests/test_creation_functions.py b/numpy/array_api/tests/test_creation_functions.py
index 7182209dc..2ee23a47b 100644
--- a/numpy/array_api/tests/test_creation_functions.py
+++ b/numpy/array_api/tests/test_creation_functions.py
@@ -67,7 +67,7 @@ def test_asarray_copy():
assert all(b[0] == 0)
assert_raises(NotImplementedError, lambda: asarray(a, copy=False))
assert_raises(NotImplementedError,
- lambda: asarray(a, copy=np._CopyMode.NEVER))
+ lambda: asarray(a, copy=np._CopyMode.IF_NEEDED))
def test_arange_errors():
diff --git a/numpy/core/multiarray.pyi b/numpy/core/multiarray.pyi
index 97e9c3498..501e55634 100644
--- a/numpy/core/multiarray.pyi
+++ b/numpy/core/multiarray.pyi
@@ -51,6 +51,7 @@ from numpy import (
_ModeKind,
_SupportsBuffer,
_IOProtocol,
+ _CopyMode
)
from numpy.typing import (
@@ -1012,9 +1013,3 @@ class flagsobj:
def owndata(self) -> bool: ...
def __getitem__(self, key: _GetItemKeys) -> bool: ...
def __setitem__(self, key: _SetItemKeys, value: bool) -> None: ...
-
-class _CopyMode(enum.Enum):
-
- ALWAYS: L[1]
- IF_NEEDED: L[0]
- NEVER: L[2] \ No newline at end of file