diff options
| author | Aaron Meurer <asmeurer@gmail.com> | 2021-07-16 15:01:09 -0600 |
|---|---|---|
| committer | Aaron Meurer <asmeurer@gmail.com> | 2021-07-16 15:01:09 -0600 |
| commit | 185d06d45cbc21ad06d7719ff265b1132d1c41ff (patch) | |
| tree | 33031ef4df3f846f716d36d5119d8c463c5039ce /numpy/_array_api | |
| parent | 3cab20ee117d39c74abd2a28f142529e379844b1 (diff) | |
| download | numpy-185d06d45cbc21ad06d7719ff265b1132d1c41ff.tar.gz | |
Guard against non-array API inputs in the array API result_type()
Diffstat (limited to 'numpy/_array_api')
| -rw-r--r-- | numpy/_array_api/_data_type_functions.py | 11 |
1 files changed, 10 insertions, 1 deletions
diff --git a/numpy/_array_api/_data_type_functions.py b/numpy/_array_api/_data_type_functions.py index fe5c7557f..7e0c35a65 100644 --- a/numpy/_array_api/_data_type_functions.py +++ b/numpy/_array_api/_data_type_functions.py @@ -1,6 +1,7 @@ from __future__ import annotations from ._array_object import Array +from ._dtypes import _all_dtypes from dataclasses import dataclass from typing import TYPE_CHECKING, List, Tuple, Union @@ -93,4 +94,12 @@ def result_type(*arrays_and_dtypes: Sequence[Union[Array, Dtype]]) -> Dtype: See its docstring for more information. """ - return np.result_type(*(a._array if isinstance(a, Array) else a for a in arrays_and_dtypes)) + A = [] + for a in arrays_and_dtypes: + if isinstance(a, Array): + a = a._array + elif isinstance(a, np.ndarray) or a not in _all_dtypes: + raise TypeError("result_type() inputs must be array_api arrays or dtypes") + A.append(a) + + return np.result_type(*A) |
