summaryrefslogtreecommitdiff
path: root/numpy/array_api/_array_object.py
diff options
context:
space:
mode:
Diffstat (limited to 'numpy/array_api/_array_object.py')
-rw-r--r--numpy/array_api/_array_object.py7
1 files changed, 4 insertions, 3 deletions
diff --git a/numpy/array_api/_array_object.py b/numpy/array_api/_array_object.py
index 50906642d..364b88f89 100644
--- a/numpy/array_api/_array_object.py
+++ b/numpy/array_api/_array_object.py
@@ -19,7 +19,8 @@ import operator
from enum import IntEnum
from ._creation_functions import asarray
from ._dtypes import (_all_dtypes, _boolean_dtypes, _integer_dtypes,
- _integer_or_boolean_dtypes, _floating_dtypes, _numeric_dtypes)
+ _integer_or_boolean_dtypes, _floating_dtypes,
+ _numeric_dtypes, _result_type, _dtype_categories)
from typing import TYPE_CHECKING, Optional, Tuple, Union
if TYPE_CHECKING:
@@ -27,6 +28,8 @@ if TYPE_CHECKING:
import numpy as np
+from numpy import array_api
+
class Array:
"""
n-d array object for the array API namespace.
@@ -98,7 +101,6 @@ class Array:
if other is NotImplemented:
return other
"""
- from ._dtypes import _result_type, _dtype_categories
if self.dtype not in _dtype_categories[dtype_category]:
raise TypeError(f'Only {dtype_category} dtypes are allowed in {op}')
@@ -338,7 +340,6 @@ class Array:
def __array_namespace__(self: Array, /, *, api_version: Optional[str] = None) -> object:
if api_version is not None and not api_version.startswith('2021.'):
raise ValueError(f"Unrecognized array API version: {api_version!r}")
- from numpy import array_api
return array_api
def __bool__(self: Array, /) -> bool: