summaryrefslogtreecommitdiff
path: root/numpy/typing/_array_like.py
blob: c562f3c1fc30dffbeb6e80d3be7010fd4cb85334 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
from __future__ import annotations

import sys
from typing import Any, Sequence, TYPE_CHECKING, Union, TypeVar, Generic
from numpy import (
    ndarray,
    dtype,
    generic,
    bool_,
    unsignedinteger,
    integer,
    floating,
    complexfloating,
    number,
    timedelta64,
    datetime64,
    object_,
    void,
    str_,
    bytes_,
)
from . import _HAS_TYPING_EXTENSIONS

if sys.version_info >= (3, 8):
    from typing import Protocol
elif _HAS_TYPING_EXTENSIONS:
    from typing_extensions import Protocol

# Unconstrained type variable.
_T = TypeVar("_T")
# Type variable restricted to subclasses of `numpy.generic`.
_ScalarType = TypeVar("_ScalarType", bound=generic)
# Type variable restricted to `numpy.dtype` instances; the bound is given
# as a string so it is not evaluated at runtime.
_DType = TypeVar("_DType", bound="dtype[Any]")
# Covariant counterpart of `_DType`; parametrizes `_SupportsArray`.
_DType_co = TypeVar("_DType_co", covariant=True, bound="dtype[Any]")

# `_SupportsArray` is only defined as a real `Protocol` while type checking,
# when `_HAS_TYPING_EXTENSIONS` is set, or on Python >= 3.8.
if TYPE_CHECKING or _HAS_TYPING_EXTENSIONS or sys.version_info >= (3, 8):
    # The `_SupportsArray` protocol only cares about the default dtype
    # (i.e. `dtype=None` or no `dtype` parameter at all) of the to-be returned
    # array.
    # Concrete implementations of the protocol are responsible for adding
    # any and all remaining overloads
    class _SupportsArray(Protocol[_DType_co]):
        # Parameterless `__array__`; the dtype of the returned array is what
        # binds `_DType_co`.
        def __array__(self) -> ndarray[Any, _DType_co]: ...
else:
    # Fallback when none of the conditions above hold: a member-less generic
    # stand-in, so that `_SupportsArray[...]` can still be subscripted by the
    # aliases defined further down in this module.
    class _SupportsArray(Generic[_DType_co]): ...

# TODO: Wait for support for recursive types
# Until then the nesting is spelled out by hand: a bare `_T`, or `_T` wrapped
# in one to four levels of `Sequence`.
_NestedSequence = Union[
    _T,  # depth 0: the item itself
    Sequence[_T],  # depth 1
    Sequence[Sequence[_T]],  # depth 2
    Sequence[Sequence[Sequence[_T]]],  # depth 3
    Sequence[Sequence[Sequence[Sequence[_T]]]],  # depth 4
]
# Five levels of `Sequence` around `Any`, i.e. one level deeper than the
# deepest member of `_NestedSequence`, with no constraint on the item type.
_RecursiveSequence = Sequence[Sequence[Sequence[Sequence[Sequence[Any]]]]]

# Generic array-like union, parametrized by two type variables:
#   * `_DType`: the dtype of objects implementing `_SupportsArray`
#   * `_T`:     every other accepted item type
# Both halves may be arbitrarily nested as allowed by `_NestedSequence`.
_ArrayLike = Union[_NestedSequence[_SupportsArray[_DType]], _NestedSequence[_T]]

# TODO: support buffer protocols once
#
# https://bugs.python.org/issue27501
#
# is resolved. See also the `python/typing` issue:
#
# https://github.com/python/typing/issues/593
# Public array-like alias: any `_SupportsArray` object (of any dtype) or
# builtin scalar, optionally nested, plus the `_RecursiveSequence` catch-all.
ArrayLike = Union[
    _RecursiveSequence,
    _ArrayLike[dtype, Union[bool, int, float, complex, str, bytes]],
]

# `ArrayLike<X>_co` aliases: array-like objects that can be coerced into `X`
# under the `same_kind` casting rules.  Each alias pairs the accepted
# dtype(s) with the builtin scalar type(s) accepted alongside them.

# Booleans only.
_ArrayLikeBool_co = _ArrayLike["dtype[bool_]", bool]

# Booleans and unsigned integers; the builtin side is limited to `bool`.
_ArrayLikeUInt_co = _ArrayLike[
    "dtype[Union[bool_, unsignedinteger[Any]]]", bool
]

# Booleans and (signed or unsigned) integers.
_ArrayLikeInt_co = _ArrayLike[
    "dtype[Union[bool_, integer[Any]]]", Union[bool, int]
]

# Booleans, integers and floats.
_ArrayLikeFloat_co = _ArrayLike[
    "dtype[Union[bool_, integer[Any], floating[Any]]]", Union[bool, int, float]
]

# Booleans, integers, floats and complex floats.
_ArrayLikeComplex_co = _ArrayLike[
    "dtype[Union[bool_, integer[Any], floating[Any], complexfloating[Any, Any]]]",
    Union[bool, int, float, complex]
]

# Booleans and any `number` subclass.
_ArrayLikeNumber_co = _ArrayLike[
    "dtype[Union[bool_, number[Any]]]", Union[bool, int, float, complex]
]

# Booleans, integers and timedeltas.
_ArrayLikeTD64_co = _ArrayLike[
    "dtype[Union[bool_, integer[Any], timedelta64]]", Union[bool, int]
]
# The datetime, object and void variants are built directly from
# `_NestedSequence[_SupportsArray[...]]`: they carry no builtin scalar half.
_ArrayLikeDT64_co = _NestedSequence[
    _SupportsArray["dtype[datetime64]"]
]
_ArrayLikeObject_co = _NestedSequence[
    _SupportsArray["dtype[object_]"]
]

_ArrayLikeVoid_co = _NestedSequence[
    _SupportsArray["dtype[void]"]
]
# Unicode strings: `str_` arrays and builtin `str`.
_ArrayLikeStr_co = _ArrayLike["dtype[str_]", str]
# Byte strings: `bytes_` arrays and builtin `bytes`.
_ArrayLikeBytes_co = _ArrayLike["dtype[bytes_]", bytes]

# Integer array-likes without the boolean members that `_ArrayLikeInt_co`
# additionally accepts: `integer` arrays and builtin `int` only.
_ArrayLikeInt = _ArrayLike["dtype[integer[Any]]", int]