diff options
Diffstat (limited to 'numpy/lib/_datasource.py')
-rw-r--r-- | numpy/lib/_datasource.py | 195 |
1 files changed, 163 insertions, 32 deletions
diff --git a/numpy/lib/_datasource.py b/numpy/lib/_datasource.py
index c528de608..816f7624e 100644
--- a/numpy/lib/_datasource.py
+++ b/numpy/lib/_datasource.py
@@ -15,36 +15,129 @@ DataSource files can originate locally or remotely:
 - URLs (http, ftp, ...) : 'http://www.scipy.org/not/real/data.txt'
 
 DataSource files can also be compressed or uncompressed. Currently only
-gzip and bz2 are supported.
+gzip, bz2 and xz are supported.
 
 Example::
 
     >>> # Create a DataSource, use os.curdir (default) for local storage.
-    >>> ds = datasource.DataSource()
+    >>> from numpy import DataSource
+    >>> ds = DataSource()
     >>>
     >>> # Open a remote file.
     >>> # DataSource downloads the file, stores it locally in:
     >>> # './www.google.com/index.html'
     >>> # opens the file and returns a file object.
-    >>> fp = ds.open('http://www.google.com/index.html')
+    >>> fp = ds.open('http://www.google.com/') # doctest: +SKIP
     >>>
     >>> # Use the file as you normally would
-    >>> fp.read()
-    >>> fp.close()
+    >>> fp.read() # doctest: +SKIP
+    >>> fp.close() # doctest: +SKIP
 
 """
 from __future__ import division, absolute_import, print_function
 
 import os
 import sys
+import warnings
 import shutil
+import io
+
+from numpy.core.overrides import set_module
+
 
 _open = open
 
 
+def _check_mode(mode, encoding, newline):
+    """Check mode and that encoding and newline are compatible.
+
+    Parameters
+    ----------
+    mode : str
+        File open mode.
+    encoding : str
+        File encoding.
+    newline : str
+        Newline for text files.
+
+    """
+    if "t" in mode:
+        if "b" in mode:
+            raise ValueError("Invalid mode: %r" % (mode,))
+    else:
+        if encoding is not None:
+            raise ValueError("Argument 'encoding' not supported in binary mode")
+        if newline is not None:
+            raise ValueError("Argument 'newline' not supported in binary mode")
+
+
+def _python2_bz2open(fn, mode, encoding, newline):
+    """Wrapper to open bz2 in text mode.
+
+    Parameters
+    ----------
+    fn : str
+        File name
+    mode : {'r', 'w'}
+        File mode. Note that bz2 Text files are not supported.
+    encoding : str
+        Ignored, text bz2 files not supported in Python2.
+    newline : str
+        Ignored, text bz2 files not supported in Python2.
+    """
+    import bz2
+
+    _check_mode(mode, encoding, newline)
+
+    if "t" in mode:
+        # BZ2File is missing necessary functions for TextIOWrapper
+        warnings.warn("Assuming latin1 encoding for bz2 text file in Python2",
+                      RuntimeWarning, stacklevel=5)
+        mode = mode.replace("t", "")
+    return bz2.BZ2File(fn, mode)
+
+def _python2_gzipopen(fn, mode, encoding, newline):
+    """ Wrapper to open gzip in text mode.
+
+    Parameters
+    ----------
+    fn : str, bytes, file
+        File path or opened file.
+    mode : str
+        File mode. The actual files are opened as binary, but will decoded
+        using the specified `encoding` and `newline`.
+    encoding : str
+        Encoding to be used when reading/writing as text.
+    newline : str
+        Newline to be used when reading/writing as text.
+
+    """
+    import gzip
+    # gzip is lacking read1 needed for TextIOWrapper
+    class GzipWrap(gzip.GzipFile):
+        def read1(self, n):
+            return self.read(n)
+
+    _check_mode(mode, encoding, newline)
+
+    gz_mode = mode.replace("t", "")
+
+    if isinstance(fn, (str, bytes)):
+        binary_file = GzipWrap(fn, gz_mode)
+    elif hasattr(fn, "read") or hasattr(fn, "write"):
+        binary_file = GzipWrap(None, gz_mode, fileobj=fn)
+    else:
+        raise TypeError("filename must be a str or bytes object, or a file")
+
+    if "t" in mode:
+        return io.TextIOWrapper(binary_file, encoding, newline=newline)
+    else:
+        return binary_file
+
+
 # Using a class instead of a module-level dictionary
-# to reduce the inital 'import numpy' overhead by
-# deferring the import of bz2 and gzip until needed
+# to reduce the initial 'import numpy' overhead by
+# deferring the import of lzma, bz2 and gzip until needed
 
 # TODO: .zip support, .tar support?
 class _FileOpeners(object):
@@ -55,7 +148,7 @@ class _FileOpeners(object):
     supported file format. Attribute lookup is implemented in such a way
     that an instance of `_FileOpeners` itself can be indexed with the keys
     of that dictionary. Currently uncompressed files as well as files
-    compressed with ``gzip`` or ``bz2`` compression are supported.
+    compressed with ``gzip``, ``bz2`` or ``xz`` compression are supported.
 
     Notes
     -----
@@ -64,8 +157,9 @@ class _FileOpeners(object):
 
     Examples
     --------
+    >>> import gzip
     >>> np.lib._datasource._file_openers.keys()
-    [None, '.bz2', '.gz']
+    [None, '.bz2', '.gz', '.xz', '.lzma']
     >>> np.lib._datasource._file_openers['.gz'] is gzip.open
     True
 
@@ -73,21 +167,39 @@ class _FileOpeners(object):
 
     def __init__(self):
         self._loaded = False
-        self._file_openers = {None: open}
+        self._file_openers = {None: io.open}
 
     def _load(self):
         if self._loaded:
             return
+
         try:
             import bz2
-            self._file_openers[".bz2"] = bz2.BZ2File
+            if sys.version_info[0] >= 3:
+                self._file_openers[".bz2"] = bz2.open
+            else:
+                self._file_openers[".bz2"] = _python2_bz2open
         except ImportError:
             pass
+
         try:
             import gzip
-            self._file_openers[".gz"] = gzip.open
+            if sys.version_info[0] >= 3:
+                self._file_openers[".gz"] = gzip.open
+            else:
+                self._file_openers[".gz"] = _python2_gzipopen
         except ImportError:
             pass
+
+        try:
+            import lzma
+            self._file_openers[".xz"] = lzma.open
+            self._file_openers[".lzma"] = lzma.open
+        except (ImportError, AttributeError):
+            # There are incompatible backports of lzma that do not have the
+            # lzma.open attribute, so catch that as well as ImportError.
+            pass
+
         self._loaded = True
 
     def keys(self):
@@ -102,7 +214,7 @@ class _FileOpeners(object):
         -------
         keys : list
             The keys are None for uncompressed files and the file extension
-            strings (i.e. ``'.gz'``, ``'.bz2'``) for supported compression
+            strings (i.e. ``'.gz'``, ``'.xz'``) for supported compression
             methods.
 
         """
@@ -115,7 +227,7 @@ class _FileOpeners(object):
 
 _file_openers = _FileOpeners()
 
-def open(path, mode='r', destpath=os.curdir):
+def open(path, mode='r', destpath=os.curdir, encoding=None, newline=None):
     """
     Open `path` with `mode` and return the file object.
 
@@ -134,6 +246,11 @@ def open(path, mode='r', destpath=os.curdir):
         Path to the directory where the source file gets downloaded to for
         use. If `destpath` is None, a temporary directory will be created.
         The default path is the current directory.
+    encoding : {None, str}, optional
+        Open text file with given encoding. The default encoding will be
+        what `io.open` uses.
+    newline : {None, str}, optional
+        Newline to use when reading text file.
 
     Returns
     -------
@@ -148,10 +265,11 @@ def open(path, mode='r', destpath=os.curdir):
 
     """
     ds = DataSource(destpath)
-    return ds.open(path, mode)
+    return ds.open(path, mode, encoding=encoding, newline=newline)
 
 
-class DataSource (object):
+@set_module('numpy')
+class DataSource(object):
     """
     DataSource(destpath='.')
 
@@ -174,7 +292,7 @@ class DataSource (object):
     URLs require a scheme string (``http://``) to be used, without it they
     will fail::
 
-        >>> repos = DataSource()
+        >>> repos = np.DataSource()
         >>> repos.exists('www.google.com/index.html')
         False
         >>> repos.exists('http://www.google.com/index.html')
@@ -186,17 +304,17 @@ class DataSource (object):
     --------
     ::
 
-        >>> ds = DataSource('/home/guido')
-        >>> urlname = 'http://www.google.com/index.html'
-        >>> gfile = ds.open('http://www.google.com/index.html') # remote file
+        >>> ds = np.DataSource('/home/guido')
+        >>> urlname = 'http://www.google.com/'
+        >>> gfile = ds.open('http://www.google.com/')
         >>> ds.abspath(urlname)
-        '/home/guido/www.google.com/site/index.html'
+        '/home/guido/www.google.com/index.html'
 
-        >>> ds = DataSource(None) # use with temporary file
+        >>> ds = np.DataSource(None) # use with temporary file
         >>> ds.open('/home/guido/foobar.txt')
         <open file '/home/guido.foobar.txt', mode 'r' at 0x91d4430>
         >>> ds.abspath('/home/guido/foobar.txt')
-        '/tmp/tmpy4pgsP/home/guido/foobar.txt'
+        '/tmp/.../home/guido/foobar.txt'
 
     """
 
@@ -212,7 +330,7 @@ class DataSource (object):
 
     def __del__(self):
         # Remove temp directories
-        if self._istmpdest:
+        if hasattr(self, '_istmpdest') and self._istmpdest:
             shutil.rmtree(self._destpath)
 
     def _iszip(self, filename):
@@ -429,6 +547,11 @@ class DataSource (object):
         is accessible if it exists in either location.
 
         """
+
+        # First test for local path
+        if os.path.exists(path):
+            return True
+
         # We import this here because importing urllib2 is slow and
         # a significant fraction of numpy's total import time.
         if sys.version_info[0] >= 3:
@@ -438,10 +561,6 @@ class DataSource (object):
             from urllib2 import urlopen
             from urllib2 import URLError
 
-        # Test local path
-        if os.path.exists(path):
-            return True
-
         # Test cached url
         upath = self.abspath(path)
         if os.path.exists(upath):
@@ -458,7 +577,7 @@ class DataSource (object):
                 return False
         return False
 
-    def open(self, path, mode='r'):
+    def open(self, path, mode='r', encoding=None, newline=None):
         """
         Open and return file-like object.
 
@@ -473,6 +592,11 @@ class DataSource (object):
             Mode to open `path`. Mode 'r' for reading, 'w' for writing, 'a' to
             append. Available modes depend on the type of object specified by
             `path`. Default is 'r'.
+        encoding : {None, str}, optional
+            Open text file with given encoding. The default encoding will be
+            what `io.open` uses.
+        newline : {None, str}, optional
+            Newline to use when reading text file.
 
         Returns
         -------
@@ -496,7 +620,8 @@ class DataSource (object):
             _fname, ext = self._splitzipext(found)
             if ext == 'bz2':
                 mode.replace("+", "")
-            return _file_openers[ext](found, mode=mode)
+            return _file_openers[ext](found, mode=mode,
+                                      encoding=encoding, newline=newline)
         else:
             raise IOError("%s not found." % path)
 
@@ -619,7 +744,7 @@ class Repository (DataSource):
         """
         return DataSource.exists(self, self._fullpath(path))
 
-    def open(self, path, mode='r'):
+    def open(self, path, mode='r', encoding=None, newline=None):
         """
         Open and return file-like object prepending Repository base URL.
 
@@ -636,6 +761,11 @@ class Repository (DataSource):
             Mode to open `path`. Mode 'r' for reading, 'w' for writing, 'a' to
             append. Available modes depend on the type of object specified by
             `path`. Default is 'r'.
+        encoding : {None, str}, optional
+            Open text file with given encoding. The default encoding will be
+            what `io.open` uses.
+        newline : {None, str}, optional
+            Newline to use when reading text file.
 
         Returns
         -------
@@ -643,7 +773,8 @@ class Repository (DataSource):
             File object.
 
         """
-        return DataSource.open(self, self._fullpath(path), mode)
+        return DataSource.open(self, self._fullpath(path), mode,
+                               encoding=encoding, newline=newline)
 
     def listdir(self):
         """