diff options
Diffstat (limited to 'Lib/shutil.py')
| -rw-r--r-- | Lib/shutil.py | 463 | 
1 files changed, 404 insertions, 59 deletions
| diff --git a/Lib/shutil.py b/Lib/shutil.py index ef29ae2303..5dc311e70d 100644 --- a/Lib/shutil.py +++ b/Lib/shutil.py @@ -15,6 +15,7 @@ import tarfile  try:      import bz2 +    del bz2      _BZ2_SUPPORTED = True  except ImportError:      _BZ2_SUPPORTED = False @@ -34,7 +35,9 @@ __all__ = ["copyfileobj", "copyfile", "copymode", "copystat", "copy", "copy2",             "ExecError", "make_archive", "get_archive_formats",             "register_archive_format", "unregister_archive_format",             "get_unpack_formats", "register_unpack_format", -           "unregister_unpack_format", "unpack_archive", "ignore_patterns"] +           "unregister_unpack_format", "unpack_archive", +           "ignore_patterns", "chown", "which"] +           # disk_usage is added later, if available on the platform  class Error(EnvironmentError):      pass @@ -79,8 +82,13 @@ def _samefile(src, dst):      return (os.path.normcase(os.path.abspath(src)) ==              os.path.normcase(os.path.abspath(dst))) -def copyfile(src, dst): -    """Copy data from src to dst""" +def copyfile(src, dst, *, follow_symlinks=True): +    """Copy data from src to dst. + +    If follow_symlinks is not set and src is a symbolic link, a new +    symlink will be created instead of copying the file it points to. + +    """      if _samefile(src, dst):          raise Error("`%s` and `%s` are the same file" % (src, dst)) @@ -95,56 +103,140 @@ def copyfile(src, dst):              if stat.S_ISFIFO(st.st_mode):                  raise SpecialFileError("`%s` is a named pipe" % fn) -    with open(src, 'rb') as fsrc: -        with open(dst, 'wb') as fdst: -            copyfileobj(fsrc, fdst) +    if not follow_symlinks and os.path.islink(src): +        os.symlink(os.readlink(src), dst) +    else: +        with open(src, 'rb') as fsrc: +            with open(dst, 'wb') as fdst: +                copyfileobj(fsrc, fdst) +    return dst + +def copymode(src, dst, *, follow_symlinks=True): +    """Copy mode bits from src to dst. + +    If follow_symlinks is not set, symlinks aren't followed if and only +    if both `src` and `dst` are symlinks.  If `lchmod` isn't available +    (e.g. Linux) this method does nothing. + +    """ +    if not follow_symlinks and os.path.islink(src) and os.path.islink(dst): +        if hasattr(os, 'lchmod'): +            stat_func, chmod_func = os.lstat, os.lchmod +        else: +            return +    elif hasattr(os, 'chmod'): +        stat_func, chmod_func = os.stat, os.chmod +    else: +        return + +    st = stat_func(src) +    chmod_func(dst, stat.S_IMODE(st.st_mode)) + +if hasattr(os, 'listxattr'): +    def _copyxattr(src, dst, *, follow_symlinks=True): +        """Copy extended filesystem attributes from `src` to `dst`. + +        Overwrite existing attributes. + +        If `follow_symlinks` is false, symlinks won't be followed. -def copymode(src, dst): -    """Copy mode bits from src to dst""" -    if hasattr(os, 'chmod'): -        st = os.stat(src) -        mode = stat.S_IMODE(st.st_mode) -        os.chmod(dst, mode) +        """ -def copystat(src, dst): -    """Copy all stat info (mode bits, atime, mtime, flags) from src to dst""" -    st = os.stat(src) +        for name in os.listxattr(src, follow_symlinks=follow_symlinks): +            try: +                value = os.getxattr(src, name, follow_symlinks=follow_symlinks) +                os.setxattr(dst, name, value, follow_symlinks=follow_symlinks) +            except OSError as e: +                if e.errno not in (errno.EPERM, errno.ENOTSUP, errno.ENODATA): +                    raise +else: +    def _copyxattr(*args, **kwargs): +        pass + +def copystat(src, dst, *, follow_symlinks=True): +    """Copy all stat info (mode bits, atime, mtime, flags) from src to dst. + +    If the optional flag `follow_symlinks` is not set, symlinks aren't followed if and +    only if both `src` and `dst` are symlinks. + +    """ +    def _nop(*args, ns=None, follow_symlinks=None): +        pass + +    # follow symlinks (aka don't not follow symlinks) +    follow = follow_symlinks or not (os.path.islink(src) and os.path.islink(dst)) +    if follow: +        # use the real function if it exists +        def lookup(name): +            return getattr(os, name, _nop) +    else: +        # use the real function only if it exists +        # *and* it supports follow_symlinks +        def lookup(name): +            fn = getattr(os, name, _nop) +            if fn in os.supports_follow_symlinks: +                return fn +            return _nop + +    st = lookup("stat")(src, follow_symlinks=follow)      mode = stat.S_IMODE(st.st_mode) -    if hasattr(os, 'utime'): -        os.utime(dst, (st.st_atime, st.st_mtime)) -    if hasattr(os, 'chmod'): -        os.chmod(dst, mode) -    if hasattr(os, 'chflags') and hasattr(st, 'st_flags'): +    lookup("utime")(dst, ns=(st.st_atime_ns, st.st_mtime_ns), +        follow_symlinks=follow) +    try: +        lookup("chmod")(dst, mode, follow_symlinks=follow) +    except NotImplementedError: +        # if we got a NotImplementedError, it's because +        #   * follow_symlinks=False, +        #   * lchown() is unavailable, and +        #   * either +        #       * fchownat() is unvailable or +        #       * fchownat() doesn't implement AT_SYMLINK_NOFOLLOW. +        #         (it returned ENOSUP.) +        # therefore we're out of options--we simply cannot chown the +        # symlink.  give up, suppress the error. +        # (which is what shutil always did in this circumstance.) +        pass +    if hasattr(st, 'st_flags'):          try: -            os.chflags(dst, st.st_flags) +            lookup("chflags")(dst, st.st_flags, follow_symlinks=follow)          except OSError as why:              for err in 'EOPNOTSUPP', 'ENOTSUP':                  if hasattr(errno, err) and why.errno == getattr(errno, err):                      break              else:                  raise +    _copyxattr(src, dst, follow_symlinks=follow) -def copy(src, dst): -    """Copy data and mode bits ("cp src dst"). +def copy(src, dst, *, follow_symlinks=True): +    """Copy data and mode bits ("cp src dst"). Return the file's destination.      The destination may be a directory. +    If follow_symlinks is false, symlinks won't be followed. This +    resembles GNU's "cp -P src dst". +      """      if os.path.isdir(dst):          dst = os.path.join(dst, os.path.basename(src)) -    copyfile(src, dst) -    copymode(src, dst) +    copyfile(src, dst, follow_symlinks=follow_symlinks) +    copymode(src, dst, follow_symlinks=follow_symlinks) +    return dst -def copy2(src, dst): -    """Copy data and all stat info ("cp -p src dst"). +def copy2(src, dst, *, follow_symlinks=True): +    """Copy data and all stat info ("cp -p src dst"). Return the file's +    destination."      The destination may be a directory. +    If follow_symlinks is false, symlinks won't be followed. This +    resembles GNU's "cp -P src dst". +      """      if os.path.isdir(dst):          dst = os.path.join(dst, os.path.basename(src)) -    copyfile(src, dst) -    copystat(src, dst) +    copyfile(src, dst, follow_symlinks=follow_symlinks) +    copystat(src, dst, follow_symlinks=follow_symlinks) +    return dst  def ignore_patterns(*patterns):      """Function that can be used as copytree() ignore parameter. @@ -211,7 +303,11 @@ def copytree(src, dst, symlinks=False, ignore=None, copy_function=copy2,              if os.path.islink(srcname):                  linkto = os.readlink(srcname)                  if symlinks: +                    # We can't just leave it to `copy_function` because legacy +                    # code with a custom `copy_function` may rely on copytree +                    # doing the right thing.                      os.symlink(linkto, dstname) +                    copystat(srcname, dstname, follow_symlinks=not symlinks)                  else:                      # ignore dangling symlink if the flag is on                      if not os.path.exists(linkto) and ignore_dangling_symlinks: @@ -239,24 +335,10 @@ def copytree(src, dst, symlinks=False, ignore=None, copy_function=copy2,              errors.append((src, dst, str(why)))      if errors:          raise Error(errors) +    return dst -def rmtree(path, ignore_errors=False, onerror=None): -    """Recursively delete a directory tree. - -    If ignore_errors is set, errors are ignored; otherwise, if onerror -    is set, it is called to handle the error with arguments (func, -    path, exc_info) where func is os.listdir, os.remove, or os.rmdir; -    path is the argument to that function that caused it to fail; and -    exc_info is a tuple returned by sys.exc_info().  If ignore_errors -    is false and onerror is None, an exception is raised. - -    """ -    if ignore_errors: -        def onerror(*args): -            pass -    elif onerror is None: -        def onerror(*args): -            raise +# version vulnerable to race conditions +def _rmtree_unsafe(path, onerror):      try:          if os.path.islink(path):              # symlinks to directories are forbidden, see bug #1669 @@ -268,7 +350,7 @@ def rmtree(path, ignore_errors=False, onerror=None):      names = []      try:          names = os.listdir(path) -    except os.error as err: +    except os.error:          onerror(os.listdir, path, sys.exc_info())      for name in names:          fullname = os.path.join(path, name) @@ -277,17 +359,109 @@ def rmtree(path, ignore_errors=False, onerror=None):          except os.error:              mode = 0          if stat.S_ISDIR(mode): -            rmtree(fullname, ignore_errors, onerror) +            _rmtree_unsafe(fullname, onerror)          else:              try: -                os.remove(fullname) -            except os.error as err: -                onerror(os.remove, fullname, sys.exc_info()) +                os.unlink(fullname) +            except os.error: +                onerror(os.unlink, fullname, sys.exc_info())      try:          os.rmdir(path)      except os.error:          onerror(os.rmdir, path, sys.exc_info()) +# Version using fd-based APIs to protect against races +def _rmtree_safe_fd(topfd, path, onerror): +    names = [] +    try: +        names = os.listdir(topfd) +    except os.error: +        onerror(os.listdir, path, sys.exc_info()) +    for name in names: +        fullname = os.path.join(path, name) +        try: +            orig_st = os.stat(name, dir_fd=topfd, follow_symlinks=False) +            mode = orig_st.st_mode +        except os.error: +            mode = 0 +        if stat.S_ISDIR(mode): +            try: +                dirfd = os.open(name, os.O_RDONLY, dir_fd=topfd) +            except os.error: +                onerror(os.open, fullname, sys.exc_info()) +            else: +                try: +                    if os.path.samestat(orig_st, os.fstat(dirfd)): +                        _rmtree_safe_fd(dirfd, fullname, onerror) +                        try: +                            os.rmdir(name, dir_fd=topfd) +                        except os.error: +                            onerror(os.rmdir, fullname, sys.exc_info()) +                finally: +                    os.close(dirfd) +        else: +            try: +                os.unlink(name, dir_fd=topfd) +            except os.error: +                onerror(os.unlink, fullname, sys.exc_info()) + +_use_fd_functions = ({os.open, os.stat, os.unlink, os.rmdir} <= +                     os.supports_dir_fd and +                     os.listdir in os.supports_fd and +                     os.stat in os.supports_follow_symlinks) + +def rmtree(path, ignore_errors=False, onerror=None): +    """Recursively delete a directory tree. + +    If ignore_errors is set, errors are ignored; otherwise, if onerror +    is set, it is called to handle the error with arguments (func, +    path, exc_info) where func is platform and implementation dependent; +    path is the argument to that function that caused it to fail; and +    exc_info is a tuple returned by sys.exc_info().  If ignore_errors +    is false and onerror is None, an exception is raised. + +    """ +    if ignore_errors: +        def onerror(*args): +            pass +    elif onerror is None: +        def onerror(*args): +            raise +    if _use_fd_functions: +        # While the unsafe rmtree works fine on bytes, the fd based does not. +        if isinstance(path, bytes): +            path = os.fsdecode(path) +        # Note: To guard against symlink races, we use the standard +        # lstat()/open()/fstat() trick. +        try: +            orig_st = os.lstat(path) +        except Exception: +            onerror(os.lstat, path, sys.exc_info()) +            return +        try: +            fd = os.open(path, os.O_RDONLY) +        except Exception: +            onerror(os.lstat, path, sys.exc_info()) +            return +        try: +            if (stat.S_ISDIR(orig_st.st_mode) and +                os.path.samestat(orig_st, os.fstat(fd))): +                _rmtree_safe_fd(fd, path, onerror) +                try: +                    os.rmdir(path) +                except os.error: +                    onerror(os.rmdir, path, sys.exc_info()) +            else: +                raise NotADirectoryError(20, +                                         "Not a directory: '{}'".format(path)) +        finally: +            os.close(fd) +    else: +        return _rmtree_unsafe(path, onerror) + +# Allow introspection of whether or not the hardening against symlink +# attacks is supported on the current platform +rmtree.avoids_symlink_attacks = _use_fd_functions  def _basename(path):      # A basename() variant which first strips the trailing slash, if present. @@ -296,7 +470,8 @@ def _basename(path):  def move(src, dst):      """Recursively move a file or directory to another location. This is -    similar to the Unix "mv" command. +    similar to the Unix "mv" command. Return the file or directory's +    destination.      If the destination is a directory or a symlink to a directory, the source      is moved inside the directory. The destination path must not already @@ -306,7 +481,10 @@ def move(src, dst):      overwritten depending on os.rename() semantics.      If the destination is on our current filesystem, then rename() is used. -    Otherwise, src is copied to the destination and then removed. +    Otherwise, src is copied to the destination and then removed. Symlinks are +    recreated under the new name if os.rename() fails because of cross +    filesystem renames. +      A lot more could be done here...  A look at a mv.c shows a lot of      the issues this implementation glosses over. @@ -324,8 +502,12 @@ def move(src, dst):              raise Error("Destination path '%s' already exists" % real_dst)      try:          os.rename(src, real_dst) -    except OSError as exc: -        if os.path.isdir(src): +    except OSError: +        if os.path.islink(src): +            linkto = os.readlink(src) +            os.symlink(linkto, real_dst) +            os.unlink(src) +        elif os.path.isdir(src):              if _destinsrc(src, dst):                  raise Error("Cannot move a directory '%s' into itself '%s'." % (src, dst))              copytree(src, real_dst, symlinks=True) @@ -333,6 +515,7 @@ def move(src, dst):          else:              copy2(src, real_dst)              os.unlink(src) +    return real_dst  def _destinsrc(src, dst):      src = abspath(src) @@ -391,7 +574,7 @@ def _make_tarball(base_name, base_dir, compress="gzip", verbose=0, dry_run=0,          compress_ext['bzip2'] = '.bz2'      # flags for compression program, each element of list will be an argument -    if compress is not None and compress not in compress_ext.keys(): +    if compress is not None and compress not in compress_ext:          raise ValueError("bad value for 'compress', or compression format not "                           "supported : {0}".format(compress)) @@ -496,7 +679,7 @@ def _make_zipfile(base_name, base_dir, verbose=0, dry_run=0, logger=None):  _ARCHIVE_FORMATS = {      'gztar': (_make_tarball, [('compress', 'gzip')], "gzip'ed tar-file"),      'tar':   (_make_tarball, [('compress', None)], "uncompressed tar file"), -    'zip':   (_make_zipfile, [],"ZIP file") +    'zip':   (_make_zipfile, [], "ZIP file")      }  if _BZ2_SUPPORTED: @@ -529,7 +712,7 @@ def register_archive_format(name, function, extra_args=None, description=''):      if not isinstance(extra_args, (tuple, list)):          raise TypeError('extra_args needs to be a sequence')      for element in extra_args: -        if not isinstance(element, (tuple, list)) or len(element) !=2 : +        if not isinstance(element, (tuple, list)) or len(element) !=2:              raise TypeError('extra_args elements are : (arg_name, value)')      _ARCHIVE_FORMATS[name] = (function, extra_args, description) @@ -681,7 +864,7 @@ def _unpack_zipfile(filename, extract_dir):              if not name.endswith('/'):                  # file                  data = zip.read(info.filename) -                f = open(target,'wb') +                f = open(target, 'wb')                  try:                      f.write(data)                  finally: @@ -755,3 +938,165 @@ def unpack_archive(filename, extract_dir=None, format=None):          func = _UNPACK_FORMATS[format][1]          kwargs = dict(_UNPACK_FORMATS[format][2])          func(filename, extract_dir, **kwargs) + + +if hasattr(os, 'statvfs'): + +    __all__.append('disk_usage') +    _ntuple_diskusage = collections.namedtuple('usage', 'total used free') + +    def disk_usage(path): +        """Return disk usage statistics about the given path. + +        Returned value is a named tuple with attributes 'total', 'used' and +        'free', which are the amount of total, used and free space, in bytes. +        """ +        st = os.statvfs(path) +        free = st.f_bavail * st.f_frsize +        total = st.f_blocks * st.f_frsize +        used = (st.f_blocks - st.f_bfree) * st.f_frsize +        return _ntuple_diskusage(total, used, free) + +elif os.name == 'nt': + +    import nt +    __all__.append('disk_usage') +    _ntuple_diskusage = collections.namedtuple('usage', 'total used free') + +    def disk_usage(path): +        """Return disk usage statistics about the given path. + +        Returned valus is a named tuple with attributes 'total', 'used' and +        'free', which are the amount of total, used and free space, in bytes. +        """ +        total, free = nt._getdiskusage(path) +        used = total - free +        return _ntuple_diskusage(total, used, free) + + +def chown(path, user=None, group=None): +    """Change owner user and group of the given path. + +    user and group can be the uid/gid or the user/group names, and in that case, +    they are converted to their respective uid/gid. +    """ + +    if user is None and group is None: +        raise ValueError("user and/or group must be set") + +    _user = user +    _group = group + +    # -1 means don't change it +    if user is None: +        _user = -1 +    # user can either be an int (the uid) or a string (the system username) +    elif isinstance(user, str): +        _user = _get_uid(user) +        if _user is None: +            raise LookupError("no such user: {!r}".format(user)) + +    if group is None: +        _group = -1 +    elif not isinstance(group, int): +        _group = _get_gid(group) +        if _group is None: +            raise LookupError("no such group: {!r}".format(group)) + +    os.chown(path, _user, _group) + +def get_terminal_size(fallback=(80, 24)): +    """Get the size of the terminal window. + +    For each of the two dimensions, the environment variable, COLUMNS +    and LINES respectively, is checked. If the variable is defined and +    the value is a positive integer, it is used. + +    When COLUMNS or LINES is not defined, which is the common case, +    the terminal connected to sys.__stdout__ is queried +    by invoking os.get_terminal_size. + +    If the terminal size cannot be successfully queried, either because +    the system doesn't support querying, or because we are not +    connected to a terminal, the value given in fallback parameter +    is used. Fallback defaults to (80, 24) which is the default +    size used by many terminal emulators. + +    The value returned is a named tuple of type os.terminal_size. +    """ +    # columns, lines are the working values +    try: +        columns = int(os.environ['COLUMNS']) +    except (KeyError, ValueError): +        columns = 0 + +    try: +        lines = int(os.environ['LINES']) +    except (KeyError, ValueError): +        lines = 0 + +    # only query if necessary +    if columns <= 0 or lines <= 0: +        try: +            size = os.get_terminal_size(sys.__stdout__.fileno()) +        except (NameError, OSError): +            size = os.terminal_size(fallback) +        if columns <= 0: +            columns = size.columns +        if lines <= 0: +            lines = size.lines + +    return os.terminal_size((columns, lines)) + +def which(cmd, mode=os.F_OK | os.X_OK, path=None): +    """Given a command, mode, and a PATH string, return the path which +    conforms to the given mode on the PATH, or None if there is no such +    file. + +    `mode` defaults to os.F_OK | os.X_OK. `path` defaults to the result +    of os.environ.get("PATH"), or can be overridden with a custom search +    path. + +    """ +    # Check that a given file can be accessed with the correct mode. +    # Additionally check that `file` is not a directory, as on Windows +    # directories pass the os.access check. +    def _access_check(fn, mode): +        return (os.path.exists(fn) and os.access(fn, mode) +                and not os.path.isdir(fn)) + +    # Short circuit. If we're given a full path which matches the mode +    # and it exists, we're done here. +    if _access_check(cmd, mode): +        return cmd + +    path = (path or os.environ.get("PATH", os.defpath)).split(os.pathsep) + +    if sys.platform == "win32": +        # The current directory takes precedence on Windows. +        if not os.curdir in path: +            path.insert(0, os.curdir) + +        # PATHEXT is necessary to check on Windows. +        pathext = os.environ.get("PATHEXT", "").split(os.pathsep) +        # See if the given file matches any of the expected path extensions. +        # This will allow us to short circuit when given "python.exe". +        matches = [cmd for ext in pathext if cmd.lower().endswith(ext.lower())] +        # If it does match, only test that one, otherwise we have to try +        # others. +        files = [cmd] if matches else [cmd + ext.lower() for ext in pathext] +    else: +        # On other platforms you don't have things like PATHEXT to tell you +        # what file suffixes are executable, so just pass on cmd as-is. +        files = [cmd] + +    seen = set() +    for dir in path: +        dir = os.path.normcase(dir) +        if not dir in seen: +            seen.add(dir) +            for thefile in files: +                name = os.path.join(dir, thefile) +                if _access_check(name, mode): +                    return name +    return None | 
