summaryrefslogtreecommitdiff
path: root/numpy
diff options
context:
space:
mode:
authorBas van Beek <43369155+BvB93@users.noreply.github.com>2021-09-02 16:09:41 +0200
committerBas van Beek <43369155+BvB93@users.noreply.github.com>2021-09-02 16:28:16 +0200
commit0275d624cf2d980faca5829c2a4147cf4c8becf7 (patch)
tree571263865c70e75af1928868288ccd6c400103d9 /numpy
parent3f6aaa9cb625af3f7ff2106c29afce1e7f9247cc (diff)
downloadnumpy-0275d624cf2d980faca5829c2a4147cf4c8becf7.tar.gz
ENH: Use custom file-like protocols instead of `typing.IO`
Diffstat (limited to 'numpy')
-rw-r--r--numpy/__init__.pyi13
-rw-r--r--numpy/core/multiarray.pyi9
-rw-r--r--numpy/lib/npyio.pyi49
3 files changed, 49 insertions, 22 deletions
diff --git a/numpy/__init__.pyi b/numpy/__init__.pyi
index ca13cffb8..2e7d90d68 100644
--- a/numpy/__init__.pyi
+++ b/numpy/__init__.pyi
@@ -621,6 +621,14 @@ from numpy.matrixlib import (
bmat as bmat,
)
+# Protocol for representing file-like-objects accepted
+# by `ndarray.tofile` and `fromfile`
+class _IOProtocol(Protocol):
+ def flush(self) -> object: ...
+ def fileno(self) -> int: ...
+ def tell(self) -> SupportsIndex: ...
+ def seek(self, offset: int, whence: int, /) -> object: ...
+
__all__: List[str]
__path__: List[str]
__version__: str
@@ -1225,7 +1233,10 @@ class _ArrayOrScalarCommon:
# NOTE: `tostring()` is deprecated and therefore excluded
# def tostring(self, order=...): ...
def tofile(
- self, fid: Union[IO[bytes], str, bytes, os.PathLike[Any]], sep: str = ..., format: str = ...
+ self,
+ fid: str | bytes | os.PathLike[str] | os.PathLike[bytes] | _IOProtocol,
+ sep: str = ...,
+ format: str = ...,
) -> None: ...
# generics and 0d arrays return builtin scalars
def tolist(self) -> Any: ...
diff --git a/numpy/core/multiarray.pyi b/numpy/core/multiarray.pyi
index 3e2873cb3..cad6047c9 100644
--- a/numpy/core/multiarray.pyi
+++ b/numpy/core/multiarray.pyi
@@ -6,7 +6,6 @@ from typing import (
Literal as L,
Any,
Callable,
- IO,
Iterable,
Optional,
overload,
@@ -19,6 +18,7 @@ from typing import (
SupportsIndex,
final,
Final,
+ Protocol,
)
from numpy import (
@@ -50,6 +50,7 @@ from numpy import (
_CastingKind,
_ModeKind,
_SupportsBuffer,
+ _IOProtocol,
)
from numpy.typing import (
@@ -642,7 +643,7 @@ def frompyfunc(
@overload
def fromfile(
- file: str | bytes | os.PathLike[Any] | IO[Any],
+ file: str | bytes | os.PathLike[Any] | _IOProtocol,
dtype: None = ...,
count: SupportsIndex = ...,
sep: str = ...,
@@ -652,7 +653,7 @@ def fromfile(
) -> NDArray[float64]: ...
@overload
def fromfile(
- file: str | bytes | os.PathLike[Any] | IO[Any],
+ file: str | bytes | os.PathLike[Any] | _IOProtocol,
dtype: _DTypeLike[_SCT],
count: SupportsIndex = ...,
sep: str = ...,
@@ -662,7 +663,7 @@ def fromfile(
) -> NDArray[_SCT]: ...
@overload
def fromfile(
- file: str | bytes | os.PathLike[Any] | IO[Any],
+ file: str | bytes | os.PathLike[Any] | _IOProtocol,
dtype: DTypeLike,
count: SupportsIndex = ...,
sep: str = ...,
diff --git a/numpy/lib/npyio.pyi b/numpy/lib/npyio.pyi
index 1fa689bbe..edf3daf07 100644
--- a/numpy/lib/npyio.pyi
+++ b/numpy/lib/npyio.pyi
@@ -18,6 +18,7 @@ from typing import (
Callable,
Pattern,
Protocol,
+ Iterable,
)
from numpy import (
@@ -42,6 +43,8 @@ _T = TypeVar("_T")
_T_contra = TypeVar("_T_contra", contravariant=True)
_T_co = TypeVar("_T_co", covariant=True)
_SCT = TypeVar("_SCT", bound=generic)
+_CharType_co = TypeVar("_CharType_co", str, bytes, covariant=True)
+_CharType_contra = TypeVar("_CharType_contra", str, bytes, contravariant=True)
_DTypeLike = Union[
Type[_SCT],
@@ -52,6 +55,16 @@ _DTypeLike = Union[
class _SupportsGetItem(Protocol[_T_contra, _T_co]):
def __getitem__(self, key: _T_contra) -> _T_co: ...
+class _SupportsRead(Protocol[_CharType_co]):
+ def read(self) -> _CharType_co: ...
+
+class _SupportsReadSeek(Protocol[_CharType_co]):
+ def read(self, n: int, /) -> _CharType_co: ...
+ def seek(self, offset: int, whence: int, /) -> object: ...
+
+class _SupportsWrite(Protocol[_CharType_contra]):
+ def write(self, s: _CharType_contra, /) -> object: ...
+
__all__: List[str]
class BagObj(Generic[_T_co]):
@@ -94,7 +107,7 @@ class NpzFile(Mapping[str, NDArray[Any]]):
# NOTE: Returns a `NpzFile` if file is a zip file;
# returns an `ndarray`/`memmap` otherwise
def load(
- file: str | bytes | os.PathLike[Any] | IO[bytes],
+ file: str | bytes | os.PathLike[Any] | _SupportsReadSeek[bytes],
mmap_mode: L[None, "r+", "r", "w+", "c"] = ...,
allow_pickle: bool = ...,
fix_imports: bool = ...,
@@ -102,27 +115,29 @@ def load(
) -> Any: ...
def save(
- file: str | os.PathLike[str] | IO[bytes],
+ file: str | os.PathLike[str] | _SupportsWrite[bytes],
arr: ArrayLike,
allow_pickle: bool = ...,
fix_imports: bool = ...,
) -> None: ...
def savez(
- file: str | os.PathLike[str] | IO[bytes],
+ file: str | os.PathLike[str] | _SupportsWrite[bytes],
*args: ArrayLike,
**kwds: ArrayLike,
) -> None: ...
def savez_compressed(
- file: str | os.PathLike[str] | IO[bytes],
+ file: str | os.PathLike[str] | _SupportsWrite[bytes],
*args: ArrayLike,
**kwds: ArrayLike,
) -> None: ...
+# File-like objects only have to implement `__iter__` and,
+# optionally, `encoding`
@overload
def loadtxt(
- fname: str | os.PathLike[str] | IO[Any],
+ fname: str | os.PathLike[str] | Iterable[str] | Iterable[bytes],
dtype: None = ...,
comments: str | Sequence[str] = ...,
delimiter: None | str = ...,
@@ -138,7 +153,7 @@ def loadtxt(
) -> NDArray[float64]: ...
@overload
def loadtxt(
- fname: str | os.PathLike[str] | IO[Any],
+ fname: str | os.PathLike[str] | Iterable[str] | Iterable[bytes],
dtype: _DTypeLike[_SCT],
comments: str | Sequence[str] = ...,
delimiter: None | str = ...,
@@ -154,7 +169,7 @@ def loadtxt(
) -> NDArray[_SCT]: ...
@overload
def loadtxt(
- fname: str | os.PathLike[str] | IO[Any],
+ fname: str | os.PathLike[str] | Iterable[str] | Iterable[bytes],
dtype: DTypeLike,
comments: str | Sequence[str] = ...,
delimiter: None | str = ...,
@@ -170,7 +185,7 @@ def loadtxt(
) -> NDArray[Any]: ...
def savetxt(
- fname: str | os.PathLike[str] | IO[Any],
+ fname: str | os.PathLike[str] | _SupportsWrite[str] | _SupportsWrite[bytes],
X: ArrayLike,
fmt: str | Sequence[str] = ...,
delimiter: str = ...,
@@ -183,14 +198,14 @@ def savetxt(
@overload
def fromregex(
- file: str | os.PathLike[str] | IO[Any],
+ file: str | os.PathLike[str] | _SupportsRead[str] | _SupportsRead[bytes],
regexp: str | bytes | Pattern[Any],
dtype: _DTypeLike[_SCT],
encoding: None | str = ...
) -> NDArray[_SCT]: ...
@overload
def fromregex(
- file: str | os.PathLike[str] | IO[Any],
+ file: str | os.PathLike[str] | _SupportsRead[str] | _SupportsRead[bytes],
regexp: str | bytes | Pattern[Any],
dtype: DTypeLike,
encoding: None | str = ...
@@ -199,21 +214,21 @@ def fromregex(
# TODO: Sort out arguments
@overload
def genfromtxt(
- fname: str | os.PathLike[str] | IO[Any],
+ fname: str | os.PathLike[str] | Iterable[str] | Iterable[bytes],
dtype: None = ...,
*args: Any,
**kwargs: Any,
) -> NDArray[float64]: ...
@overload
def genfromtxt(
- fname: str | os.PathLike[str] | IO[Any],
+ fname: str | os.PathLike[str] | Iterable[str] | Iterable[bytes],
dtype: _DTypeLike[_SCT],
*args: Any,
**kwargs: Any,
) -> NDArray[_SCT]: ...
@overload
def genfromtxt(
- fname: str | os.PathLike[str] | IO[Any],
+ fname: str | os.PathLike[str] | Iterable[str] | Iterable[bytes],
dtype: DTypeLike,
*args: Any,
**kwargs: Any,
@@ -221,14 +236,14 @@ def genfromtxt(
@overload
def recfromtxt(
- fname: str | os.PathLike[str] | IO[Any],
+ fname: str | os.PathLike[str] | Iterable[str] | Iterable[bytes],
*,
usemask: L[False] = ...,
**kwargs: Any,
) -> recarray[Any, dtype[void]]: ...
@overload
def recfromtxt(
- fname: str | os.PathLike[str] | IO[Any],
+ fname: str | os.PathLike[str] | Iterable[str] | Iterable[bytes],
*,
usemask: L[True],
**kwargs: Any,
@@ -236,14 +251,14 @@ def recfromtxt(
@overload
def recfromcsv(
- fname: str | os.PathLike[str] | IO[Any],
+ fname: str | os.PathLike[str] | Iterable[str] | Iterable[bytes],
*,
usemask: L[False] = ...,
**kwargs: Any,
) -> recarray[Any, dtype[void]]: ...
@overload
def recfromcsv(
- fname: str | os.PathLike[str] | IO[Any],
+ fname: str | os.PathLike[str] | Iterable[str] | Iterable[bytes],
*,
usemask: L[True],
**kwargs: Any,