summaryrefslogtreecommitdiff
path: root/numpy/_array_api
diff options
context:
space:
mode:
authorAaron Meurer <asmeurer@gmail.com>2021-07-16 15:01:09 -0600
committerAaron Meurer <asmeurer@gmail.com>2021-07-16 15:01:09 -0600
commit185d06d45cbc21ad06d7719ff265b1132d1c41ff (patch)
tree33031ef4df3f846f716d36d5119d8c463c5039ce /numpy/_array_api
parent3cab20ee117d39c74abd2a28f142529e379844b1 (diff)
downloadnumpy-185d06d45cbc21ad06d7719ff265b1132d1c41ff.tar.gz
Guard against non-array API inputs in the array API result_type()
Diffstat (limited to 'numpy/_array_api')
-rw-r--r--numpy/_array_api/_data_type_functions.py11
1 files changed, 10 insertions, 1 deletions
diff --git a/numpy/_array_api/_data_type_functions.py b/numpy/_array_api/_data_type_functions.py
index fe5c7557f..7e0c35a65 100644
--- a/numpy/_array_api/_data_type_functions.py
+++ b/numpy/_array_api/_data_type_functions.py
@@ -1,6 +1,7 @@
from __future__ import annotations
from ._array_object import Array
+from ._dtypes import _all_dtypes
from dataclasses import dataclass
from typing import TYPE_CHECKING, List, Tuple, Union
@@ -93,4 +94,12 @@ def result_type(*arrays_and_dtypes: Sequence[Union[Array, Dtype]]) -> Dtype:
See its docstring for more information.
"""
- return np.result_type(*(a._array if isinstance(a, Array) else a for a in arrays_and_dtypes))
+ A = []
+ for a in arrays_and_dtypes:
+ if isinstance(a, Array):
+ a = a._array
+ elif isinstance(a, np.ndarray) or a not in _all_dtypes:
+ raise TypeError("result_type() inputs must be array_api arrays or dtypes")
+ A.append(a)
+
+ return np.result_type(*A)