summaryrefslogtreecommitdiff
path: root/numpy/array_api/_creation_functions.py
blob: acf78056a27110b1e529e74ecbecb84abb755b57 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
from __future__ import annotations


from typing import TYPE_CHECKING, List, Optional, Tuple, Union
if TYPE_CHECKING:
    from ._typing import (Array, Device, Dtype, NestedSequence,
                          SupportsDLPack, SupportsBufferProtocol)
    from collections.abc import Sequence
from ._dtypes import _all_dtypes

import numpy as np

def _check_valid_dtype(dtype):
    """
    Raise ``ValueError`` unless ``dtype`` is ``None`` or one of ``_all_dtypes``.

    Only dtypes spelled as the dtype objects themselves are accepted.
    """
    # Identity comparison is deliberate: a plain "dtype in _all_dtypes" would
    # go through the dtype objects' equality, which also matches the sorts of
    # spellings we want to disallow.
    accepted = (None,) + _all_dtypes
    if not any(dtype is candidate for candidate in accepted):
        raise ValueError("dtype must be one of the supported dtypes")

def asarray(obj: Union[Array, bool, int, float, NestedSequence[bool|int|float], SupportsDLPack, SupportsBufferProtocol], /, *, dtype: Optional[Dtype] = None, device: Optional[Device] = None, copy: Optional[bool] = None) -> Array:
    """
    Array API compatible wrapper for :py:func:`np.asarray <numpy.asarray>`.

    See its docstring for more information.

    Parameters
    ----------
    obj
        The object to convert into an ``Array``.
    dtype
        One of the supported dtype objects, or ``None`` to let NumPy infer it.
    device
        Only ``'cpu'`` and ``None`` are accepted.
    copy
        ``True`` forces a copy of the data; ``None`` leaves the decision to
        NumPy; ``False`` is not implemented.

    Returns
    -------
    Array
        ``obj`` itself when it is already an ``Array`` of the requested dtype
        and no copy was requested, otherwise a newly wrapped array.

    Raises
    ------
    ValueError
        If ``dtype`` is not a supported dtype object or ``device`` is not
        supported.
    NotImplementedError
        If ``copy`` is ``False``.
    OverflowError
        If ``obj`` is a Python int outside ``[-2**63, 2**64 - 1]`` and no
        ``dtype`` was given.
    """
    # _array_object imports in this file are inside the functions to avoid
    # circular imports
    from ._array_object import Array

    _check_valid_dtype(dtype)
    if device not in ['cpu', None]:
        raise ValueError(f"Unsupported device {device!r}")
    if copy is False:
        # Note: copy=False is not yet implemented in np.asarray
        raise NotImplementedError("copy=False is not yet implemented")
    if isinstance(obj, Array) and (dtype is None or obj.dtype == dtype):
        if copy is True:
            return Array._new(np.array(obj._array, copy=True, dtype=dtype))
        return obj
    if dtype is None and isinstance(obj, int) and not (-2**63 <= obj < 2**64):
        # Give a better error message in this case. NumPy would convert this
        # to an object array. The largest representable integer is
        # 2**64 - 1 (uint64), so 2**64 itself must be rejected as well.
        # TODO: This won't handle large integers in lists.
        raise OverflowError("Integer out of bounds for array dtypes")
    if copy is True:
        # np.asarray may hand back an alias of an existing ndarray/buffer;
        # copy=True must always produce fresh data.
        res = np.array(obj, dtype=dtype, copy=True)
    else:
        res = np.asarray(obj, dtype=dtype)
    return Array._new(res)

def arange(start: Union[int, float], /, stop: Optional[Union[int, float]] = None, step: Union[int, float] = 1, *, dtype: Optional[Dtype] = None, device: Optional[Device] = None) -> Array:
    """
    Build an ``Array`` from :py:func:`np.arange <numpy.arange>` (array API
    compatible wrapper).

    Refer to the NumPy docstring for the meaning of ``start``, ``stop`` and
    ``step``. ``dtype`` must pass ``_check_valid_dtype`` and ``device`` must
    be ``'cpu'`` or ``None``; otherwise ``ValueError`` is raised.
    """
    from ._array_object import Array

    _check_valid_dtype(dtype)
    if device not in ('cpu', None):
        raise ValueError(f"Unsupported device {device!r}")
    values = np.arange(start, stop=stop, step=step, dtype=dtype)
    return Array._new(values)

def empty(shape: Union[int, Tuple[int, ...]], *, dtype: Optional[Dtype] = None, device: Optional[Device] = None) -> Array:
    """
    Build an ``Array`` from :py:func:`np.empty <numpy.empty>` (array API
    compatible wrapper).

    Refer to the NumPy docstring for details. ``dtype`` must pass
    ``_check_valid_dtype`` and ``device`` must be ``'cpu'`` or ``None``;
    otherwise ``ValueError`` is raised.
    """
    from ._array_object import Array

    _check_valid_dtype(dtype)
    if device not in ('cpu', None):
        raise ValueError(f"Unsupported device {device!r}")
    uninitialized = np.empty(shape, dtype=dtype)
    return Array._new(uninitialized)

def empty_like(x: Array, /, *, dtype: Optional[Dtype] = None, device: Optional[Device] = None) -> Array:
    """
    Build an ``Array`` from :py:func:`np.empty_like <numpy.empty_like>`
    (array API compatible wrapper).

    The underlying NumPy array of ``x`` is used as the prototype. ``dtype``
    must pass ``_check_valid_dtype`` and ``device`` must be ``'cpu'`` or
    ``None``; otherwise ``ValueError`` is raised.
    """
    from ._array_object import Array

    _check_valid_dtype(dtype)
    if device not in ('cpu', None):
        raise ValueError(f"Unsupported device {device!r}")
    prototype = x._array
    return Array._new(np.empty_like(prototype, dtype=dtype))

def eye(n_rows: int, n_cols: Optional[int] = None, /, *, k: Optional[int] = 0, dtype: Optional[Dtype] = None, device: Optional[Device] = None) -> Array:
    """
    Build an ``Array`` from :py:func:`np.eye <numpy.eye>` (array API
    compatible wrapper).

    ``n_cols`` is forwarded to NumPy as ``M``; refer to the NumPy docstring
    for details. ``dtype`` must pass ``_check_valid_dtype`` and ``device``
    must be ``'cpu'`` or ``None``; otherwise ``ValueError`` is raised.
    """
    from ._array_object import Array

    _check_valid_dtype(dtype)
    if device not in ('cpu', None):
        raise ValueError(f"Unsupported device {device!r}")
    identity = np.eye(n_rows, M=n_cols, k=k, dtype=dtype)
    return Array._new(identity)

def from_dlpack(x: object, /) -> Array:
    """
    Placeholder for the array API ``from_dlpack`` function.

    Always raises ``NotImplementedError``: DLPack support is not yet
    implemented on Array.
    """
    message = "DLPack support is not yet implemented"
    raise NotImplementedError(message)

def full(shape: Union[int, Tuple[int, ...]], fill_value: Union[int, float], *, dtype: Optional[Dtype] = None, device: Optional[Device] = None) -> Array:
    """
    Build an ``Array`` from :py:func:`np.full <numpy.full>` (array API
    compatible wrapper).

    A zero-dimensional ``Array`` passed as ``fill_value`` is unwrapped to its
    underlying NumPy array first. ``dtype`` must pass ``_check_valid_dtype``
    and ``device`` must be ``'cpu'`` or ``None``; otherwise ``ValueError`` is
    raised. ``TypeError`` is raised when the resulting dtype is not one of
    the supported dtypes.
    """
    from ._array_object import Array

    _check_valid_dtype(dtype)
    if device not in ('cpu', None):
        raise ValueError(f"Unsupported device {device!r}")

    value = fill_value
    if isinstance(value, Array) and value.ndim == 0:
        value = value._array

    filled = np.full(shape, value, dtype=dtype)
    if filled.dtype in _all_dtypes:
        return Array._new(filled)
    # Reaching this point means the fill value is not something that NumPy
    # coerces to one of the acceptable dtypes.
    raise TypeError("Invalid input to full")

def full_like(x: Array, /, fill_value: Union[int, float], *, dtype: Optional[Dtype] = None, device: Optional[Device] = None) -> Array:
    """
    Build an ``Array`` from :py:func:`np.full_like <numpy.full_like>` (array
    API compatible wrapper).

    The underlying NumPy array of ``x`` is used as the prototype. ``dtype``
    must pass ``_check_valid_dtype`` and ``device`` must be ``'cpu'`` or
    ``None``; otherwise ``ValueError`` is raised. ``TypeError`` is raised
    when the resulting dtype is not one of the supported dtypes.
    """
    from ._array_object import Array

    _check_valid_dtype(dtype)
    if device not in ('cpu', None):
        raise ValueError(f"Unsupported device {device!r}")

    filled = np.full_like(x._array, fill_value, dtype=dtype)
    if filled.dtype in _all_dtypes:
        return Array._new(filled)
    # Reaching this point means the fill value is not something that NumPy
    # coerces to one of the acceptable dtypes.
    raise TypeError("Invalid input to full_like")

def linspace(start: Union[int, float], stop: Union[int, float], /, num: int, *, dtype: Optional[Dtype] = None, device: Optional[Device] = None, endpoint: bool = True) -> Array:
    """
    Build an ``Array`` from :py:func:`np.linspace <numpy.linspace>` (array
    API compatible wrapper).

    Refer to the NumPy docstring for details. ``dtype`` must pass
    ``_check_valid_dtype`` and ``device`` must be ``'cpu'`` or ``None``;
    otherwise ``ValueError`` is raised.
    """
    from ._array_object import Array

    _check_valid_dtype(dtype)
    if device not in ('cpu', None):
        raise ValueError(f"Unsupported device {device!r}")
    samples = np.linspace(start, stop, num, dtype=dtype, endpoint=endpoint)
    return Array._new(samples)

def meshgrid(*arrays: Array, indexing: str = 'xy') -> List[Array]:
    """
    Array API compatible wrapper for :py:func:`np.meshgrid <numpy.meshgrid>`.

    See its docstring for more information.

    Parameters
    ----------
    *arrays
        The input ``Array`` objects; each one's underlying NumPy array is
        passed on to ``np.meshgrid``.
    indexing
        Forwarded unchanged to ``np.meshgrid``.

    Returns
    -------
    list of Array
        One ``Array`` per array returned by ``np.meshgrid``, in the same
        order.
    """
    from ._array_object import Array

    # Unwrap to raw NumPy arrays, delegate, then re-wrap each result.
    grids = np.meshgrid(*[a._array for a in arrays], indexing=indexing)
    return [Array._new(grid) for grid in grids]

def ones(shape: Union[int, Tuple[int, ...]], *, dtype: Optional[Dtype] = None, device: Optional[Device] = None) -> Array:
    """
    Build an ``Array`` from :py:func:`np.ones <numpy.ones>` (array API
    compatible wrapper).

    Refer to the NumPy docstring for details. ``dtype`` must pass
    ``_check_valid_dtype`` and ``device`` must be ``'cpu'`` or ``None``;
    otherwise ``ValueError`` is raised.
    """
    from ._array_object import Array

    _check_valid_dtype(dtype)
    if device not in ('cpu', None):
        raise ValueError(f"Unsupported device {device!r}")
    filled = np.ones(shape, dtype=dtype)
    return Array._new(filled)

def ones_like(x: Array, /, *, dtype: Optional[Dtype] = None, device: Optional[Device] = None) -> Array:
    """
    Build an ``Array`` from :py:func:`np.ones_like <numpy.ones_like>` (array
    API compatible wrapper).

    The underlying NumPy array of ``x`` is used as the prototype. ``dtype``
    must pass ``_check_valid_dtype`` and ``device`` must be ``'cpu'`` or
    ``None``; otherwise ``ValueError`` is raised.
    """
    from ._array_object import Array

    _check_valid_dtype(dtype)
    if device not in ('cpu', None):
        raise ValueError(f"Unsupported device {device!r}")
    prototype = x._array
    return Array._new(np.ones_like(prototype, dtype=dtype))

def zeros(shape: Union[int, Tuple[int, ...]], *, dtype: Optional[Dtype] = None, device: Optional[Device] = None) -> Array:
    """
    Build an ``Array`` from :py:func:`np.zeros <numpy.zeros>` (array API
    compatible wrapper).

    Refer to the NumPy docstring for details. ``dtype`` must pass
    ``_check_valid_dtype`` and ``device`` must be ``'cpu'`` or ``None``;
    otherwise ``ValueError`` is raised.
    """
    from ._array_object import Array

    _check_valid_dtype(dtype)
    if device not in ('cpu', None):
        raise ValueError(f"Unsupported device {device!r}")
    zeroed = np.zeros(shape, dtype=dtype)
    return Array._new(zeroed)

def zeros_like(x: Array, /, *, dtype: Optional[Dtype] = None, device: Optional[Device] = None) -> Array:
    """
    Build an ``Array`` from :py:func:`np.zeros_like <numpy.zeros_like>`
    (array API compatible wrapper).

    The underlying NumPy array of ``x`` is used as the prototype. ``dtype``
    must pass ``_check_valid_dtype`` and ``device`` must be ``'cpu'`` or
    ``None``; otherwise ``ValueError`` is raised.
    """
    from ._array_object import Array

    _check_valid_dtype(dtype)
    if device not in ('cpu', None):
        raise ValueError(f"Unsupported device {device!r}")
    prototype = x._array
    return Array._new(np.zeros_like(prototype, dtype=dtype))