summaryrefslogtreecommitdiff
path: root/numpy/typing/_array_like.py
blob: ef6c061d1aa326a2254d8db89e39bc57deb75c37 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
from __future__ import annotations

import sys
from typing import Any, overload, Sequence, TYPE_CHECKING, Union, TypeVar

from numpy import (
    ndarray,
    dtype,
    generic,
    bool_,
    unsignedinteger,
    integer,
    floating,
    complexfloating,
    number,
    timedelta64,
    datetime64,
    object_,
    void,
    str_,
    bytes_,
)
from ._dtype_like import DTypeLike

# Locate a usable ``Protocol``: the stdlib one on Python >= 3.8, otherwise
# the ``typing_extensions`` backport if it is installed.  ``HAVE_PROTOCOL``
# tells the rest of the module whether either source was found, so that a
# runtime fallback can be chosen when neither is available.
HAVE_PROTOCOL = True
if sys.version_info < (3, 8):
    try:
        from typing_extensions import Protocol
    except ImportError:
        HAVE_PROTOCOL = False
else:
    from typing import Protocol

# Element type of the non-dtype half of `_ArrayLike` (e.g. builtin scalars).
_T = TypeVar("_T")
# A numpy scalar type (any subclass of `generic`).
_ScalarType = TypeVar("_ScalarType", bound=generic)
# A `dtype` type; the bound is quoted so `dtype[Any]` is never evaluated
# at runtime.
_DType = TypeVar("_DType", bound="dtype[Any]")
# Covariant counterpart of `_DType`: it only appears in the return type of
# `_SupportsArray.__array__`, which is why it may (and must) be covariant.
_DType_co = TypeVar("_DType_co", covariant=True, bound="dtype[Any]")

if TYPE_CHECKING or HAVE_PROTOCOL:
    # The `_SupportsArray` protocol only cares about the default dtype
    # (i.e. `dtype=None`) of the to-be returned array.
    # Concrete implementations of the protocol are responsible for adding
    # any and all remaining overloads
    class _SupportsArray(Protocol[_DType_co]):
        def __array__(self, dtype: None = ...) -> ndarray[Any, _DType_co]: ...
else:
    # No `Protocol` is available (Python 3.7 without `typing_extensions`).
    # The aliases further down evaluate `_SupportsArray[...]` at import
    # time, and `typing.Any` cannot be subscripted (it raises `TypeError`),
    # so a plain subscriptable generic class is used as the placeholder.
    from typing import Generic

    class _SupportsArray(Generic[_DType_co]):
        pass

# `_T` itself, or sequences of `_T` nested one to four levels deep.
# Recursive type aliases are not available (see the TODO below), so the
# nesting is spelled out explicitly and capped at a fixed depth.
# TODO: Wait for support for recursive types
_NestedSequence = Union[
    _T,
    Sequence[_T],
    Sequence[Sequence[_T]],
    Sequence[Sequence[Sequence[_T]]],
    Sequence[Sequence[Sequence[Sequence[_T]]]],
]
# Catch-all for sequences nested (at least) five levels deep, i.e. one level
# beyond what `_NestedSequence` spells out; the innermost element type is
# `Any` and therefore unchecked.
_RecursiveSequence = Sequence[Sequence[Sequence[Sequence[Sequence[Any]]]]]

# A generic union representing array-like objects, parametrized by two
# type variables:
# * `_DType`: types that can be parametrized w.r.t. `np.dtype`, i.e. objects
#   implementing `__array__` (`_SupportsArray`) with that dtype;
# * `_T`: everything else (e.g. builtin Python scalars).
# Either alternative may also appear inside nested sequences
# (`_NestedSequence`). Subscript as `_ArrayLike[<dtype>, <other types>]`.
_ArrayLike = Union[
    _NestedSequence[_SupportsArray[_DType]],
    _NestedSequence[_T],
]

# `ArrayLike`: the general-purpose array-like annotation. It accepts objects
# implementing `__array__` with any dtype, builtin bool/int/float/complex/
# str/bytes scalars, sequences of either nested up to four levels deep
# (via `_NestedSequence`), and sequences nested five or more levels deep
# with unchecked innermost elements (via `_RecursiveSequence`).
#
# TODO: support buffer protocols once
#
# https://bugs.python.org/issue27501
#
# is resolved. See also the python/typing issue:
#
# https://github.com/python/typing/issues/593
ArrayLike = Union[
    _RecursiveSequence,
    _ArrayLike[
        "dtype[Any]",
        Union[bool, int, float, complex, str, bytes]
    ],
]

# `ArrayLike<X>_co`: array-like objects that can be coerced into `X`
# given the casting rules `same_kind`
#
# Each `_ArrayLike[...]` alias below takes two arguments: the permitted
# dtype(s) of objects implementing `__array__` (quoted, so `dtype[...]` is
# never evaluated at runtime) and the permitted builtin Python scalar
# type(s). Both may also appear inside nested sequences.
_ArrayLikeBool_co = _ArrayLike[
    "dtype[bool_]",
    bool,
]
# Note: only builtin `bool` is listed here, not builtin `int`.
_ArrayLikeUInt_co = _ArrayLike[
    "dtype[Union[bool_, unsignedinteger[Any]]]",
    bool,
]
_ArrayLikeInt_co = _ArrayLike[
    "dtype[Union[bool_, integer[Any]]]",
    Union[bool, int],
]
_ArrayLikeFloat_co = _ArrayLike[
    "dtype[Union[bool_, integer[Any], floating[Any]]]",
    Union[bool, int, float],
]
_ArrayLikeComplex_co = _ArrayLike[
    "dtype[Union[bool_, integer[Any], floating[Any], complexfloating[Any, Any]]]",
    Union[bool, int, float, complex],
]
_ArrayLikeNumber_co = _ArrayLike[
    "dtype[Union[bool_, number[Any]]]",
    Union[bool, int, float, complex],
]
# timedelta64 arrays also accept boolean and integer inputs.
_ArrayLikeTD64_co = _ArrayLike[
    "dtype[Union[bool_, integer[Any], timedelta64]]",
    Union[bool, int],
]
# These have no builtin-scalar counterpart: only objects implementing
# `__array__` with the given dtype (optionally in nested sequences) match.
_ArrayLikeDT64_co = _NestedSequence[_SupportsArray["dtype[datetime64]"]]
_ArrayLikeObject_co = _NestedSequence[_SupportsArray["dtype[object_]"]]

# Likewise `__array__`-only, for `void` dtypes.
_ArrayLikeVoid_co = _NestedSequence[_SupportsArray["dtype[void]"]]
_ArrayLikeStr_co = _ArrayLike[
    "dtype[str_]",
    str,
]
_ArrayLikeBytes_co = _ArrayLike[
    "dtype[bytes_]",
    bytes,
]

# Non-`_co` variant: lists only `integer` dtypes and builtin `int`
# (no explicit `bool_`/`bool`, unlike `_ArrayLikeInt_co`).
_ArrayLikeInt = _ArrayLike[
    "dtype[integer[Any]]",
    int,
]

if not TYPE_CHECKING:
    # Runtime placeholders: keeps the module importable without evaluating
    # the subscripted `ndarray[...]` / `dtype[...]` expressions below.
    _ArrayND = Any
    _ArrayOrScalar = Any
else:
    # An array whose scalar type is `_ScalarType` (shape left as `Any`).
    _ArrayND = ndarray[Any, dtype[_ScalarType]]
    # Either a single `_ScalarType` scalar or an array of them.
    _ArrayOrScalar = Union[_ScalarType, _ArrayND[_ScalarType]]