diff options
Diffstat (limited to 'lib/git')
-rw-r--r-- | lib/git/__init__.py | 8 | ||||
-rw-r--r-- | lib/git/actor.py | 9 | ||||
-rw-r--r-- | lib/git/cmd.py | 93 | ||||
-rw-r--r-- | lib/git/config.py | 347 | ||||
-rw-r--r-- | lib/git/diff.py | 244 | ||||
-rw-r--r-- | lib/git/errors.py | 6 | ||||
-rw-r--r-- | lib/git/index.py | 987 | ||||
-rw-r--r-- | lib/git/objects/__init__.py | 1 | ||||
-rw-r--r-- | lib/git/objects/base.py | 103 | ||||
-rw-r--r-- | lib/git/objects/blob.py | 2 | ||||
-rw-r--r-- | lib/git/objects/commit.py | 140 | ||||
-rw-r--r-- | lib/git/objects/tag.py | 15 | ||||
-rw-r--r-- | lib/git/objects/tree.py | 60 | ||||
-rw-r--r-- | lib/git/objects/utils.py | 18 | ||||
-rw-r--r-- | lib/git/odict.py | 1399 | ||||
-rw-r--r-- | lib/git/refs.py | 613 | ||||
-rw-r--r-- | lib/git/remote.py | 718 | ||||
-rw-r--r-- | lib/git/repo.py | 406 | ||||
-rw-r--r-- | lib/git/utils.py | 274 |
19 files changed, 5044 insertions, 399 deletions
diff --git a/lib/git/__init__.py b/lib/git/__init__.py index 6f482128..c7efe5ea 100644 --- a/lib/git/__init__.py +++ b/lib/git/__init__.py @@ -9,17 +9,17 @@ import inspect __version__ = 'git' +from git.config import GitConfigParser from git.objects import * from git.refs import * from git.actor import Actor -from git.diff import Diff +from git.diff import * from git.errors import InvalidGitRepositoryError, NoSuchPathError, GitCommandError from git.cmd import Git from git.repo import Repo from git.stats import Stats -from git.utils import dashify -from git.utils import touch - +from git.remote import * +from git.index import * __all__ = [ name for name, obj in locals().items() if not (name.startswith('_') or inspect.ismodule(obj)) ] diff --git a/lib/git/actor.py b/lib/git/actor.py index fe4a47e5..04872f1c 100644 --- a/lib/git/actor.py +++ b/lib/git/actor.py @@ -18,6 +18,15 @@ class Actor(object): self.name = name self.email = email + def __eq__(self, other): + return self.name == other.name and self.email == other.email + + def __ne__(self, other): + return not (self == other) + + def __hash__(self): + return hash((self.name, self.email)) + def __str__(self): return self.name diff --git a/lib/git/cmd.py b/lib/git/cmd.py index 500fcd93..fb6f2998 100644 --- a/lib/git/cmd.py +++ b/lib/git/cmd.py @@ -13,13 +13,16 @@ from errors import GitCommandError GIT_PYTHON_TRACE = os.environ.get("GIT_PYTHON_TRACE", False) execute_kwargs = ('istream', 'with_keep_cwd', 'with_extended_output', - 'with_exceptions', 'with_raw_output', 'as_process', + 'with_exceptions', 'as_process', 'output_stream' ) extra = {} if sys.platform == 'win32': extra = {'shell': True} +def dashify(string): + return string.replace('_', '-') + class Git(object): """ The Git class manages communication with the Git binary. @@ -39,12 +42,16 @@ class Git(object): """ Kill/Interrupt the stored process instance once this instance goes out of scope. It is used to prevent processes piling up in case iterators stop reading. - Besides all attributes are wired through to the contained process object + Besides all attributes are wired through to the contained process object. + + The wait method was overridden to perform automatic status code checking + and possibly raise. """ - __slots__= "proc" + __slots__= ("proc", "args") - def __init__(self, proc ): + def __init__(self, proc, args ): self.proc = proc + self.args = args def __del__(self): # did the process finish already so we have a return code ? @@ -61,6 +68,20 @@ class Git(object): def __getattr__(self, attr): return getattr(self.proc, attr) + + def wait(self): + """ + Wait for the process and return its status code. + + Raise + GitCommandError if the return status is not 0 + """ + status = self.proc.wait() + if status != 0: + raise GitCommandError(self.args, status, self.proc.stderr.read()) + # END status handling + return status + def __init__(self, git_dir=None): @@ -102,7 +123,6 @@ class Git(object): with_keep_cwd=False, with_extended_output=False, with_exceptions=True, - with_raw_output=False, as_process=False, output_stream=None ): @@ -129,27 +149,33 @@ class Git(object): ``with_exceptions`` Whether to raise an exception when git returns a non-zero status. - ``with_raw_output`` - Whether to avoid stripping off trailing whitespace. - ``as_process`` Whether to return the created process instance directly from which - streams can be read on demand. This will render with_extended_output, - with_exceptions and with_raw_output ineffective - the caller will have + streams can be read on demand. This will render with_extended_output and + with_exceptions ineffective - the caller will have to deal with the details himself. It is important to note that the process will be placed into an AutoInterrupt wrapper that will interrupt the process once it goes out of scope. If you use the command in iterators, you should pass the whole process instance instead of a single stream. + ``output_stream`` If set to a file-like object, data produced by the git command will be - output to the given stream directly. - Otherwise a new file will be opened. + output to the given stream directly. + This feature only has any effect if as_process is False. Processes will + always be created with a pipe due to issues with subprocess. + This merely is a workaround as data will be copied from the + output pipe to the given output stream directly. + Returns:: - str(output) # extended_output = False (Default) + str(output) # extended_output = False (Default) tuple(int(status), str(stdout), str(stderr)) # extended_output = True + + if ouput_stream is True, the stdout value will be your output stream: + output_stream # extended_output = False + tuple(int(status), output_stream, str(stderr))# extended_output = True Raise GitCommandError @@ -167,37 +193,39 @@ class Git(object): else: cwd=self.git_dir - ostream = subprocess.PIPE - if output_stream is not None: - ostream = output_stream - # Start the process proc = subprocess.Popen(command, cwd=cwd, stdin=istream, stderr=subprocess.PIPE, - stdout=ostream, + stdout=subprocess.PIPE, **extra ) - if as_process: - return self.AutoInterrupt(proc) + return self.AutoInterrupt(proc, command) # Wait for the process to return status = 0 try: - stdout_value = proc.stdout.read() - stderr_value = proc.stderr.read() + if output_stream is None: + stdout_value = proc.stdout.read().rstrip() # strip trailing "\n" + else: + max_chunk_size = 1024*64 + while True: + chunk = proc.stdout.read(max_chunk_size) + output_stream.write(chunk) + if len(chunk) < max_chunk_size: + break + # END reading output stream + stdout_value = output_stream + # END stdout handling + stderr_value = proc.stderr.read().rstrip() # strip trailing "\n" + # waiting here should do nothing as we have finished stream reading status = proc.wait() finally: proc.stdout.close() proc.stderr.close() - # Strip off trailing whitespace by default - if not with_raw_output: - stdout_value = stdout_value.rstrip() - stderr_value = stderr_value.rstrip() - if with_exceptions and status != 0: raise GitCommandError(command, status, stderr_value) @@ -258,7 +286,9 @@ class Git(object): such as in 'ls_files' to call 'ls-files'. ``args`` - is the list of arguments + is the list of arguments. If None is included, it will be pruned. + This allows your commands to call git more conveniently as None + is realized as non-existent ``kwargs`` is a dict of keyword arguments. @@ -284,7 +314,7 @@ class Git(object): # Prepare the argument list opt_args = self.transform_kwargs(**kwargs) - ext_args = self.__unpack_args(args) + ext_args = self.__unpack_args([a for a in args if a is not None]) args = opt_args + ext_args call = ["git", dashify(method)] @@ -306,8 +336,9 @@ class Git(object): """ tokens = header_line.split() if len(tokens) != 3: - raise ValueError( "SHA named %s could not be resolved" % tokens[0] ) - + raise ValueError("SHA named %s could not be resolved" % tokens[0] ) + if len(tokens[0]) != 40: + raise ValueError("Failed to parse header: %r" % header_line) return (tokens[0], tokens[1], int(tokens[2])) def __prepare_ref(self, ref): diff --git a/lib/git/config.py b/lib/git/config.py new file mode 100644 index 00000000..ccfbae48 --- /dev/null +++ b/lib/git/config.py @@ -0,0 +1,347 @@ +# config.py +# Copyright (C) 2008, 2009 Michael Trier (mtrier@gmail.com) and contributors +# +# This module is part of GitPython and is released under +# the BSD License: http://www.opensource.org/licenses/bsd-license.php +""" +Module containing module parser implementation able to properly read and write +configuration files +""" + +import re +import os +import ConfigParser as cp +import inspect +import cStringIO + +from git.odict import OrderedDict +from git.utils import LockFile + +class _MetaParserBuilder(type): + """ + Utlity class wrapping base-class methods into decorators that assure read-only properties + """ + def __new__(metacls, name, bases, clsdict): + """ + Equip all base-class methods with a _needs_values decorator, and all non-const methods + with a _set_dirty_and_flush_changes decorator in addition to that. + """ + mutating_methods = clsdict['_mutating_methods_'] + for base in bases: + methods = ( t for t in inspect.getmembers(base, inspect.ismethod) if not t[0].startswith("_") ) + for name, method in methods: + if name in clsdict: + continue + method_with_values = _needs_values(method) + if name in mutating_methods: + method_with_values = _set_dirty_and_flush_changes(method_with_values) + # END mutating methods handling + + clsdict[name] = method_with_values + # END for each base + + new_type = super(_MetaParserBuilder, metacls).__new__(metacls, name, bases, clsdict) + return new_type + + + +def _needs_values(func): + """ + Returns method assuring we read values (on demand) before we try to access them + """ + def assure_data_present(self, *args, **kwargs): + self.read() + return func(self, *args, **kwargs) + # END wrapper method + assure_data_present.__name__ = func.__name__ + return assure_data_present + +def _set_dirty_and_flush_changes(non_const_func): + """ + Return method that checks whether given non constant function may be called. + If so, the instance will be set dirty. + Additionally, we flush the changes right to disk + """ + def flush_changes(self, *args, **kwargs): + rval = non_const_func(self, *args, **kwargs) + self.write() + return rval + # END wrapper method + flush_changes.__name__ = non_const_func.__name__ + return flush_changes + + + +class GitConfigParser(cp.RawConfigParser, LockFile): + """ + Implements specifics required to read git style configuration files. + + This variation behaves much like the git.config command such that the configuration + will be read on demand based on the filepath given during initialization. + + The changes will automatically be written once the instance goes out of scope, but + can be triggered manually as well. + + The configuration file will be locked if you intend to change values preventing other + instances to write concurrently. + + NOTE + The config is case-sensitive even when queried, hence section and option names + must match perfectly. + """ + __metaclass__ = _MetaParserBuilder + + OPTCRE = re.compile( + r'\s?(?P<option>[^:=\s][^:=]*)' # very permissive, incuding leading whitespace + r'\s*(?P<vi>[:=])\s*' # any number of space/tab, + # followed by separator + # (either : or =), followed + # by any # space/tab + r'(?P<value>.*)$' # everything up to eol + ) + + # list of RawConfigParser methods able to change the instance + _mutating_methods_ = ("add_section", "remove_section", "remove_option", "set") + __slots__ = ("_sections", "_defaults", "_file_or_files", "_read_only","_is_initialized") + + def __init__(self, file_or_files, read_only=True): + """ + Initialize a configuration reader to read the given file_or_files and to + possibly allow changes to it by setting read_only False + + ``file_or_files`` + A single file path or file objects or multiple of these + + ``read_only`` + If True, the ConfigParser may only read the data , but not change it. + If False, only a single file path or file object may be given. + """ + # initialize base with ordered dictionaries to be sure we write the same + # file back + self._sections = OrderedDict() + self._defaults = OrderedDict() + + self._file_or_files = file_or_files + self._read_only = read_only + self._is_initialized = False + + + if not read_only: + if isinstance(file_or_files, (tuple, list)): + raise ValueError("Write-ConfigParsers can operate on a single file only, multiple files have been passed") + # END single file check + + if not isinstance(file_or_files, basestring): + file_or_files = file_or_files.name + # END get filename from handle/stream + # initialize lock base - we want to write + LockFile.__init__(self, file_or_files) + + self._obtain_lock_or_raise() + # END read-only check + + + def __del__(self): + """ + Write pending changes if required and release locks + """ + # checking for the lock here makes sure we do not raise during write() + # in case an invalid parser was created who could not get a lock + if self.read_only or not self._has_lock(): + return + + try: + try: + self.write() + except IOError,e: + print "Exception during destruction of GitConfigParser: %s" % str(e) + finally: + self._release_lock() + + def optionxform(self, optionstr): + """ + Do not transform options in any way when writing + """ + return optionstr + + def _read(self, fp, fpname): + """ + A direct copy of the py2.4 version of the super class's _read method + to assure it uses ordered dicts. Had to change one line to make it work. + + Future versions have this fixed, but in fact its quite embarassing for the + guys not to have done it right in the first place ! + + Removed big comments to make it more compact. + + Made sure it ignores initial whitespace as git uses tabs + """ + cursect = None # None, or a dictionary + optname = None + lineno = 0 + e = None # None, or an exception + while True: + line = fp.readline() + if not line: + break + lineno = lineno + 1 + # comment or blank line? + if line.strip() == '' or line[0] in '#;': + continue + if line.split(None, 1)[0].lower() == 'rem' and line[0] in "rR": + # no leading whitespace + continue + else: + # is it a section header? + mo = self.SECTCRE.match(line) + if mo: + sectname = mo.group('header') + if sectname in self._sections: + cursect = self._sections[sectname] + elif sectname == cp.DEFAULTSECT: + cursect = self._defaults + else: + # THE ONLY LINE WE CHANGED ! + cursect = OrderedDict((('__name__', sectname),)) + self._sections[sectname] = cursect + # So sections can't start with a continuation line + optname = None + # no section header in the file? + elif cursect is None: + raise cp.MissingSectionHeaderError(fpname, lineno, line) + # an option line? + else: + mo = self.OPTCRE.match(line) + if mo: + optname, vi, optval = mo.group('option', 'vi', 'value') + if vi in ('=', ':') and ';' in optval: + pos = optval.find(';') + if pos != -1 and optval[pos-1].isspace(): + optval = optval[:pos] + optval = optval.strip() + if optval == '""': + optval = '' + optname = self.optionxform(optname.rstrip()) + cursect[optname] = optval + else: + if not e: + e = cp.ParsingError(fpname) + e.append(lineno, repr(line)) + # END + # END ? + # END ? + # END while reading + # if any parsing errors occurred, raise an exception + if e: + raise e + + + def read(self): + """ + Reads the data stored in the files we have been initialized with. It will + ignore files that cannot be read, possibly leaving an empty configuration + + Returns + Nothing + + Raises + IOError if a file cannot be handled + """ + if self._is_initialized: + return + + + files_to_read = self._file_or_files + if not isinstance(files_to_read, (tuple, list)): + files_to_read = [ files_to_read ] + + for file_object in files_to_read: + fp = file_object + close_fp = False + # assume a path if it is not a file-object + if not hasattr(file_object, "seek"): + try: + fp = open(file_object) + except IOError,e: + continue + close_fp = True + # END fp handling + + try: + self._read(fp, fp.name) + finally: + if close_fp: + fp.close() + # END read-handling + # END for each file object to read + self._is_initialized = True + + def _write(self, fp): + """Write an .ini-format representation of the configuration state in + git compatible format""" + def write_section(name, section_dict): + fp.write("[%s]\n" % name) + for (key, value) in section_dict.items(): + if key != "__name__": + fp.write("\t%s = %s\n" % (key, str(value).replace('\n', '\n\t'))) + # END if key is not __name__ + # END section writing + + if self._defaults: + write_section(cp.DEFAULTSECT, self._defaults) + map(lambda t: write_section(t[0],t[1]), self._sections.items()) + + + @_needs_values + def write(self): + """ + Write changes to our file, if there are changes at all + + Raise + IOError if this is a read-only writer instance or if we could not obtain + a file lock + """ + self._assure_writable("write") + self._obtain_lock_or_raise() + + + fp = self._file_or_files + close_fp = False + + if not hasattr(fp, "seek"): + fp = open(self._file_or_files, "w") + close_fp = True + else: + fp.seek(0) + + # WRITE DATA + try: + self._write(fp) + finally: + if close_fp: + fp.close() + # END data writing + + # we do not release the lock - it will be done automatically once the + # instance vanishes + + def _assure_writable(self, method_name): + if self.read_only: + raise IOError("Cannot execute non-constant method %s.%s" % (self, method_name)) + + @_needs_values + @_set_dirty_and_flush_changes + def add_section(self, section): + """ + Assures added options will stay in order + """ + super(GitConfigParser, self).add_section(section) + self._sections[section] = OrderedDict() + + @property + def read_only(self): + """ + Returns + True if this instance may change the configuration file + """ + return self._read_only diff --git a/lib/git/diff.py b/lib/git/diff.py index 0db83b4f..a43d3725 100644 --- a/lib/git/diff.py +++ b/lib/git/diff.py @@ -6,10 +6,152 @@ import re import objects.blob as blob +from errors import GitCommandError + +class Diffable(object): + """ + Common interface for all object that can be diffed against another object of compatible type. + + NOTE: + Subclasses require a repo member as it is the case for Object instances, for practical + reasons we do not derive from Object. + """ + __slots__ = tuple() + + # standin indicating you want to diff against the index + class Index(object): + pass + + def _process_diff_args(self, args): + """ + Returns + possibly altered version of the given args list. + Method is called right before git command execution. + Subclasses can use it to alter the behaviour of the superclass + """ + return args + + def diff(self, other=Index, paths=None, create_patch=False, **kwargs): + """ + Creates diffs between two items being trees, trees and index or an + index and the working tree. + + ``other`` + Is the item to compare us with. + If None, we will be compared to the working tree. + If Index ( type ), it will be compared against the index. + It defaults to Index to assure the method will not by-default fail + on bare repositories. + + ``paths`` + is a list of paths or a single path to limit the diff to. + It will only include at least one of the givne path or paths. + + ``create_patch`` + If True, the returned Diff contains a detailed patch that if applied + makes the self to other. Patches are somwhat costly as blobs have to be read + and diffed. + + ``kwargs`` + Additional arguments passed to git-diff, such as + R=True to swap both sides of the diff. + + Returns + git.DiffIndex + + Note + Rename detection will only work if create_patch is True. + + On a bare repository, 'other' needs to be provided as Index or as + as Tree/Commit, or a git command error will occour + """ + args = list() + args.append( "--abbrev=40" ) # we need full shas + args.append( "--full-index" ) # get full index paths, not only filenames + + if create_patch: + args.append("-p") + args.append("-M") # check for renames + else: + args.append("--raw") + + if paths is not None and not isinstance(paths, (tuple,list)): + paths = [ paths ] + + if other is not None and other is not self.Index: + args.insert(0, other) + if other is self.Index: + args.insert(0, "--cached") + + args.insert(0,self) + + # paths is list here or None + if paths: + args.append("--") + args.extend(paths) + # END paths handling + + kwargs['as_process'] = True + proc = self.repo.git.diff(*self._process_diff_args(args), **kwargs) + + diff_method = Diff._index_from_raw_format + if create_patch: + diff_method = Diff._index_from_patch_format + index = diff_method(self.repo, proc.stdout) + + status = proc.wait() + return index + + +class DiffIndex(list): + """ + Implements an Index for diffs, allowing a list of Diffs to be queried by + the diff properties. + + The class improves the diff handling convenience + """ + # change type invariant identifying possible ways a blob can have changed + # A = Added + # D = Deleted + # R = Renamed + # M = modified + change_type = ("A", "D", "R", "M") + + + def iter_change_type(self, change_type): + """ + Return + iterator yieling Diff instances that match the given change_type + + ``change_type`` + Member of DiffIndex.change_type, namely + + 'A' for added paths + + 'D' for deleted paths + + 'R' for renamed paths + + 'M' for paths with modified data + """ + if change_type not in self.change_type: + raise ValueError( "Invalid change type: %s" % change_type ) + + for diff in self: + if change_type == "A" and diff.new_file: + yield diff + elif change_type == "D" and diff.deleted_file: + yield diff + elif change_type == "R" and diff.renamed: + yield diff + elif change_type == "M" and diff.a_blob and diff.b_blob and diff.a_blob != diff.b_blob: + yield diff + # END for each diff + class Diff(object): """ - A Diff contains diff information between two commits. + A Diff contains diff information between two Trees. It contains two sides a and b of the diff, members are prefixed with "a" and "b" respectively to inidcate that. @@ -27,7 +169,7 @@ class Diff(object): ``Deleted File``:: b_mode is None - b_blob is NOne + b_blob is None """ # precompiled regex @@ -46,7 +188,7 @@ class Diff(object): """, re.VERBOSE | re.MULTILINE) re_is_null_hexsha = re.compile( r'^0{40}$' ) __slots__ = ("a_blob", "b_blob", "a_mode", "b_mode", "new_file", "deleted_file", - "rename_from", "rename_to", "renamed", "diff") + "rename_from", "rename_to", "diff") def __init__(self, repo, a_path, b_path, a_blob_id, b_blob_id, a_mode, b_mode, new_file, deleted_file, rename_from, @@ -54,39 +196,67 @@ class Diff(object): if not a_blob_id or self.re_is_null_hexsha.search(a_blob_id): self.a_blob = None else: - self.a_blob = blob.Blob(repo, id=a_blob_id, mode=a_mode, path=a_path) + self.a_blob = blob.Blob(repo, a_blob_id, mode=a_mode, path=a_path) if not b_blob_id or self.re_is_null_hexsha.search(b_blob_id): self.b_blob = None else: - self.b_blob = blob.Blob(repo, id=b_blob_id, mode=b_mode, path=b_path) + self.b_blob = blob.Blob(repo, b_blob_id, mode=b_mode, path=b_path) self.a_mode = a_mode self.b_mode = b_mode + if self.a_mode: self.a_mode = blob.Blob._mode_str_to_int( self.a_mode ) if self.b_mode: self.b_mode = blob.Blob._mode_str_to_int( self.b_mode ) + self.new_file = new_file self.deleted_file = deleted_file - self.rename_from = rename_from - self.rename_to = rename_to - self.renamed = rename_from != rename_to + + # be clear and use None instead of empty strings + self.rename_from = rename_from or None + self.rename_to = rename_to or None + self.diff = diff + + def __eq__(self, other): + for name in self.__slots__: + if getattr(self, name) != getattr(other, name): + return False + # END for each name + return True + + def __ne__(self, other): + return not ( self == other ) + + def __hash__(self): + return hash(tuple(getattr(self,n) for n in self.__slots__)) + + @property + def renamed(self): + """ + Returns: + True if the blob of our diff has been renamed + """ + return self.rename_from != self.rename_to + @classmethod - def _list_from_string(cls, repo, text): + def _index_from_patch_format(cls, repo, stream): """ - Create a new diff object from the given text + Create a new DiffIndex from the given text which must be in patch format ``repo`` is the repository we are operating on - it is required - ``text`` - result of 'git diff' between two commits or one commit and the index + ``stream`` + result of 'git diff' as a stream (supporting file protocol) Returns - git.Diff[] + git.DiffIndex """ - diffs = [] + # for now, we have to bake the stream + text = stream.read() + index = DiffIndex() diff_header = cls.re_header.match for diff in ('\n' + text).split('\ndiff --git')[1:]: @@ -97,9 +267,51 @@ class Diff(object): a_blob_id, b_blob_id, b_mode = header.groups() new_file, deleted_file = bool(new_file_mode), bool(deleted_file_mode) - diffs.append(Diff(repo, a_path, b_path, a_blob_id, b_blob_id, + index.append(Diff(repo, a_path, b_path, a_blob_id, b_blob_id, old_mode or deleted_file_mode, new_mode or new_file_mode or b_mode, new_file, deleted_file, rename_from, rename_to, diff[header.end():])) - return diffs + return index + + @classmethod + def _index_from_raw_format(cls, repo, stream): + """ + Create a new DiffIndex from the given stream which must be in raw format. + + NOTE: + This format is inherently incapable of detecting renames, hence we only + modify, delete and add files + + Returns + git.DiffIndex + """ + # handles + # :100644 100644 6870991011cc8d9853a7a8a6f02061512c6a8190 37c5e30c879213e9ae83b21e9d11e55fc20c54b7 M .gitignore + index = DiffIndex() + for line in stream: + if not line.startswith(":"): + continue + # END its not a valid diff line + old_mode, new_mode, a_blob_id, b_blob_id, change_type, path = line[1:].split() + a_path = path + b_path = path + deleted_file = False + new_file = False + + # NOTE: We cannot conclude from the existance of a blob to change type + # as diffs with the working do not have blobs yet + if change_type == 'D': + b_path = None + deleted_file = True + elif change_type == 'A': + a_path = None + new_file = True + # END add/remove handling + + diff = Diff(repo, a_path, b_path, a_blob_id, b_blob_id, old_mode, new_mode, + new_file, deleted_file, None, None, '') + index.append(diff) + # END for each line + + return index diff --git a/lib/git/errors.py b/lib/git/errors.py index e9a637c0..cde2798a 100644 --- a/lib/git/errors.py +++ b/lib/git/errors.py @@ -25,8 +25,8 @@ class GitCommandError(Exception): self.stderr = stderr self.status = status self.command = command - + def __str__(self): - return repr("%s returned exit status %d" % - (str(self.command), self.status)) + return ("'%s' returned exit status %i: %s" % + (' '.join(str(i) for i in self.command), self.status, self.stderr)) diff --git a/lib/git/index.py b/lib/git/index.py new file mode 100644 index 00000000..e368f531 --- /dev/null +++ b/lib/git/index.py @@ -0,0 +1,987 @@ +# index.py +# Copyright (C) 2008, 2009 Michael Trier (mtrier@gmail.com) and contributors +# +# This module is part of GitPython and is released under +# the BSD License: http://www.opensource.org/licenses/bsd-license.php +""" +Module containing Index implementation, allowing to perform all kinds of index +manipulations such as querying and merging. +""" +import struct +import binascii +import mmap +import objects +import tempfile +import os +import sys +import stat +import subprocess +import git.diff as diff + +from git.objects import Blob, Tree, Object, Commit +from git.utils import SHA1Writer, LazyMixin, ConcurrentWriteOperation + + +class _TemporaryFileSwap(object): + """ + Utility class moving a file to a temporary location within the same directory + and moving it back on to where on object deletion. + """ + __slots__ = ("file_path", "tmp_file_path") + + def __init__(self, file_path): + self.file_path = file_path + self.tmp_file_path = self.file_path + tempfile.mktemp('','','') + os.rename(self.file_path, self.tmp_file_path) + + def __del__(self): + if os.path.isfile(self.tmp_file_path): + if sys.platform == "win32" and os.path.exists(self.file_path): + os.remove(self.file_path) + os.rename(self.tmp_file_path, self.file_path) + # END temp file exists + + +class BaseIndexEntry(tuple): + """ + + Small Brother of an index entry which can be created to describe changes + done to the index in which case plenty of additional information is not requried. + + As the first 4 data members match exactly to the IndexEntry type, methods + expecting a BaseIndexEntry can also handle full IndexEntries even if they + use numeric indices for performance reasons. + """ + + def __str__(self): + return "%o %s %i\t%s\n" % (self.mode, self.sha, self.stage, self.path) + + @property + def mode(self): + """ + File Mode, compatible to stat module constants + """ + return self[0] + + @property + def sha(self): + """ + hex sha of the blob + """ + return self[1] + + @property + def stage(self): + """ + Stage of the entry, either: + 0 = default stage + 1 = stage before a merge or common ancestor entry in case of a 3 way merge + 2 = stage of entries from the 'left' side of the merge + 3 = stage of entries from the right side of the merge + Note: + For more information, see http://www.kernel.org/pub/software/scm/git/docs/git-read-tree.html + """ + return self[2] + + @property + def path(self): + return self[3] + + @classmethod + def from_blob(cls, blob, stage = 0): + """ + Returns + Fully equipped BaseIndexEntry at the given stage + """ + return cls((blob.mode, blob.sha, stage, blob.path)) + + +class IndexEntry(BaseIndexEntry): + """ + Allows convenient access to IndexEntry data without completely unpacking it. + + Attributes usully accessed often are cached in the tuple whereas others are + unpacked on demand. + + See the properties for a mapping between names and tuple indices. + """ + @property + def ctime(self): + """ + Returns + Tuple(int_time_seconds_since_epoch, int_nano_seconds) of the + file's creation time + """ + return struct.unpack(">LL", self[4]) + + @property + def mtime(self): + """ + See ctime property, but returns modification time + """ + return struct.unpack(">LL", self[5]) + + @property + def dev(self): + """ + Device ID + """ + return self[6] + + @property + def inode(self): + """ + Inode ID + """ + return self[7] + + @property + def uid(self): + """ + User ID + """ + return self[8] + + @property + def gid(self): + """ + Group ID + """ + return self[9] + + @property + def size(self): + """ + Uncompressed size of the blob + + Note + Will be 0 if the stage is not 0 ( hence it is an unmerged entry ) + """ + return self[10] + + @classmethod + def from_blob(cls, blob): + """ + Returns + Minimal entry resembling the given blob objecft + """ + time = struct.pack(">LL", 0, 0) + return IndexEntry((blob.mode, blob.sha, 0, blob.path, time, time, 0, 0, 0, 0, blob.size)) + + +def clear_cache(func): + """ + Decorator for functions that alter the index using the git command. This would + invalidate our possibly existing entries dictionary which is why it must be + deleted to allow it to be lazily reread later. + + Note + This decorator will not be required once all functions are implemented + natively which in fact is possible, but probably not feasible performance wise. + """ + def clear_cache_if_not_raised(self, *args, **kwargs): + rval = func(self, *args, **kwargs) + del(self.entries) + return rval + + # END wrapper method + clear_cache_if_not_raised.__name__ = func.__name__ + return clear_cache_if_not_raised + + +def default_index(func): + """ + Decorator assuring the wrapped method may only run if we are the default + repository index. This is as we rely on git commands that operate + on that index only. + """ + def check_default_index(self, *args, **kwargs): + if self._file_path != self._index_path(): + raise AssertionError( "Cannot call %r on indices that do not represent the default git index" % func.__name__ ) + return func(self, *args, **kwargs) + # END wrpaper method + + check_default_index.__name__ = func.__name__ + return check_default_index + + +class IndexFile(LazyMixin, diff.Diffable): + """ + Implements an Index that can be manipulated using a native implementation in + order to save git command function calls wherever possible. + + It provides custom merging facilities allowing to merge without actually changing + your index or your working tree. This way you can perform own test-merges based + on the index only without having to deal with the working copy. This is useful + in case of partial working trees. + + ``Entries`` + The index contains an entries dict whose keys are tuples of type IndexEntry + to facilitate access. + + You may only read the entries dict or manipulate it through designated methods. + Otherwise changes to it will be lost when changing the index using its methods. + """ + __slots__ = ( "repo", "version", "entries", "_extension_data", "_file_path" ) + _VERSION = 2 # latest version we support + S_IFGITLINK = 0160000 + + def __init__(self, repo, file_path=None): + """ + Initialize this Index instance, optionally from the given ``file_path``. + If no file_path is given, we will be created from the current index file. + + If a stream is not given, the stream will be initialized from the current + repository's index on demand. + """ + self.repo = repo + self.version = self._VERSION + self._extension_data = '' + self._file_path = file_path or self._index_path() + + def _set_cache_(self, attr): + if attr == "entries": + # read the current index + # try memory map for speed + fp = open(self._file_path, "r") + stream = fp + try: + stream = mmap.mmap(fp.fileno(), 0, access=mmap.ACCESS_READ) + except Exception: + pass + # END memory mapping + + try: + self._read_from_stream(stream) + finally: + fp.close() + # END read from default index on demand + else: + super(IndexFile, self)._set_cache_(attr) + + def _index_path(self): + return os.path.join(self.repo.path, "index") + + + @property + def path(self): + """ + Returns + Path to the index file we are representing + """ + return self._file_path + + @classmethod + def _read_entry(cls, stream): + """Return: One entry of the given stream""" + beginoffset = stream.tell() + ctime = struct.unpack(">8s", stream.read(8))[0] + mtime = struct.unpack(">8s", stream.read(8))[0] + (dev, ino, mode, uid, gid, size, sha, flags) = \ + struct.unpack(">LLLLLL20sH", stream.read(20 + 4 * 6 + 2)) + path_size = flags & 0x0fff + path = stream.read(path_size) + + real_size = ((stream.tell() - beginoffset + 8) & ~7) + data = stream.read((beginoffset + real_size) - stream.tell()) + return IndexEntry((mode, binascii.hexlify(sha), flags >> 12, path, ctime, mtime, dev, ino, uid, gid, size)) + + @classmethod + def _read_header(cls, stream): + """Return tuple(version_long, num_entries) from the given stream""" + type_id = stream.read(4) + if type_id != "DIRC": + raise AssertionError("Invalid index file header: %r" % type_id) + version, num_entries = struct.unpack(">LL", stream.read(4 * 2)) + assert version in (1, 2) + return version, num_entries + + def _read_from_stream(self, stream): + """ + Initialize this instance with index values read from the given stream + """ + self.version, num_entries = self._read_header(stream) + count = 0 + self.entries = dict() + while count < num_entries: + entry = self._read_entry(stream) + self.entries[(entry.path, entry.stage)] = entry + count += 1 + # END for each entry + + # the footer contains extension data and a sha on the content so far + # Keep the extension footer,and verify we have a sha in the end + self._extension_data = stream.read(~0) + assert len(self._extension_data) > 19, "Index Footer was not at least a sha on content as it was only %i bytes in size" % len(self._extension_data) + + content_sha = self._extension_data[-20:] + + # truncate the sha in the end as we will dynamically create it anyway + self._extension_data = self._extension_data[:-20] + + + @classmethod + def _write_cache_entry(cls, stream, entry): + """ + Write an IndexEntry to a stream + """ + beginoffset = stream.tell() + stream.write(entry[4]) # ctime + stream.write(entry[5]) # mtime + path = entry[3] + plen = len(path) & 0x0fff # path length + assert plen == len(path), "Path %s too long to fit into index" % entry[3] + flags = plen | (entry[2] << 12)# stage and path length are 2 byte flags + stream.write(struct.pack(">LLLLLL20sH", entry[6], entry[7], entry[0], + entry[8], entry[9], entry[10], binascii.unhexlify(entry[1]), flags)) + stream.write(path) + real_size = ((stream.tell() - beginoffset + 8) & ~7) + stream.write("\0" * ((beginoffset + real_size) - stream.tell())) + + def write(self, file_path = None): + """ + Write the current state to our file path or to the given one + + ``file_path`` + If None, we will write to our stored file path from which we have + been initialized. Otherwise we write to the given file path. + Please note that this will change the file_path of this index to + the one you gave. + + Returns + self + + Note + Index writing based on the dulwich implementation + """ + write_op = ConcurrentWriteOperation(file_path or self._file_path) + stream = write_op._begin_writing() + + stream = SHA1Writer(stream) + + # header + stream.write("DIRC") + stream.write(struct.pack(">LL", self.version, len(self.entries))) + + # body + entries_sorted = self.entries.values() + entries_sorted.sort(key=lambda e: (e[3], e[2])) # use path/stage as sort key + for entry in entries_sorted: + self._write_cache_entry(stream, entry) + # END for each entry + + # write previously cached extensions data + stream.write(self._extension_data) + + # write the sha over the content + stream.write_sha() + write_op._end_writing() + + # make sure we represent what we have written + if file_path is not None: + self._file_path = file_path + + @classmethod + def from_tree(cls, repo, *treeish, **kwargs): + """ + Merge the given treeish revisions into a new index which is returned. + The original index will remain unaltered + + ``repo`` + The repository treeish are located in. + + ``*treeish`` + One, two or three Tree Objects or Commits. The result changes according to the + amount of trees. + If 1 Tree is given, it will just be read into a new index + If 2 Trees are given, they will be merged into a new index using a + two way merge algorithm. Tree 1 is the 'current' tree, tree 2 is the 'other' + one. It behaves like a fast-forward. + If 3 Trees are given, a 3-way merge will be performed with the first tree + being the common ancestor of tree 2 and tree 3. Tree 2 is the 'current' tree, + tree 3 is the 'other' one + + ``**kwargs`` + Additional arguments passed to git-read-tree + + Returns + New IndexFile instance. It will point to a temporary index location which + does not exist anymore. If you intend to write such a merged Index, supply + an alternate file_path to its 'write' method. + + Note: + In the three-way merge case, --aggressive will be specified to automatically + resolve more cases in a commonly correct manner. Specify trivial=True as kwarg + to override that. + + As the underlying git-read-tree command takes into account the current index, + it will be temporarily moved out of the way to assure there are no unsuspected + interferences. + """ + if len(treeish) == 0 or len(treeish) > 3: + raise ValueError("Please specify between 1 and 3 treeish, got %i" % len(treeish)) + + arg_list = list() + # ignore that working tree and index possibly are out of date + if len(treeish)>1: + # drop unmerged entries when reading our index and merging + arg_list.append("--reset") + # handle non-trivial cases the way a real merge does + arg_list.append("--aggressive") + # END merge handling + + # tmp file created in git home directory to be sure renaming + # works - /tmp/ dirs could be on another device + tmp_index = tempfile.mktemp('','',repo.path) + arg_list.append("--index-output=%s" % tmp_index) + arg_list.extend(treeish) + + # move current index out of the way - otherwise the merge may fail + # as it considers existing entries. moving it essentially clears the index. + # Unfortunately there is no 'soft' way to do it. + # The _TemporaryFileSwap assure the original file get put back + index_handler = _TemporaryFileSwap(os.path.join(repo.path, 'index')) + try: + repo.git.read_tree(*arg_list, **kwargs) + index = cls(repo, tmp_index) + index.entries # force it to read the file + finally: + if os.path.exists(tmp_index): + os.remove(tmp_index) + # END index merge handling + + return index + + @classmethod + def _index_mode_to_tree_index_mode(cls, index_mode): + """ + Cleanup a index_mode value. + This will return a index_mode that can be stored in a tree object. + + ``index_mode`` + Index_mode to clean up. + """ + if stat.S_ISLNK(index_mode): + return stat.S_IFLNK + elif stat.S_ISDIR(index_mode): + return stat.S_IFDIR + elif stat.S_IFMT(index_mode) == cls.S_IFGITLINK: + return cls.S_IFGITLINK + ret = stat.S_IFREG | 0644 + ret |= (index_mode & 0111) + return ret + + @classmethod + def _tree_mode_to_index_mode(cls, tree_mode): + """ + Convert a tree mode to index mode as good as possible + """ + + def iter_blobs(self, predicate = lambda t: True): + """ + Returns + Iterator yielding tuples of Blob objects and stages, tuple(stage, Blob) + + ``predicate`` + Function(t) returning True if tuple(stage, Blob) should be yielded by the + iterator + """ + for entry in self.entries.itervalues(): + mode = self._index_mode_to_tree_index_mode(entry.mode) + blob = Blob(self.repo, entry.sha, mode, entry.path) + blob.size = entry.size + output = (entry.stage, blob) + if predicate(output): + yield output + # END for each entry + + def unmerged_blobs(self): + """ + Returns + Iterator yielding dict(path : list( tuple( stage, Blob, ...))), being + a dictionary associating a path in the index with a list containing + stage/blob pairs + + Note: + Blobs that have been removed in one side simply do not exist in the + given stage. I.e. a file removed on the 'other' branch whose entries + are at stage 3 will not have a stage 3 entry. + """ + is_unmerged_blob = lambda t: t[0] != 0 + path_map = dict() + for stage, blob in self.iter_blobs(is_unmerged_blob): + path_map.setdefault(blob.path, list()).append((stage, blob)) + # END for each unmerged blob + + return path_map + + def resolve_blobs(self, iter_blobs): + """ + Resolve the blobs given in blob iterator. This will effectively remove the + index entries of the respective path at all non-null stages and add the given + blob as new stage null blob. + + For each path there may only be one blob, otherwise a ValueError will be raised + claiming the path is already at stage 0. + + Raise + ValueError if one of the blobs already existed at stage 0 + + Returns: + self + + Note + You will have to write the index manually once you are done, i.e. + index.resolve_blobs(blobs).write() + """ + for blob in iter_blobs: + stage_null_key = (blob.path, 0) + if stage_null_key in self.entries: + raise ValueError( "Blob %r already at stage 0" % blob ) + # END assert blob is not stage 0 already + + # delete all possible stages + for stage in (1, 2, 3): + try: + del( self.entries[(blob.path, stage)] ) + except KeyError: + pass + # END ignore key errors + # END for each possible stage + + self.entries[stage_null_key] = IndexEntry.from_blob(blob) + # END for each blob + + return self + + def update(self): + """ + Reread the contents of our index file, discarding all cached information + we might have. + + Note: + This is a possibly dangerious operations as it will discard your changes + to index.entries + + Returns + self + """ + del(self.entries) + # allows to lazily reread on demand + return self + + def write_tree(self): + """ + Writes the Index in self to a corresponding Tree file into the repository + object database and returns it as corresponding Tree object. + + Returns + Tree object representing this index + """ + index_path = self._index_path() + tmp_index_mover = _TemporaryFileSwap(index_path) + + self.write(index_path) + tree_sha = self.repo.git.write_tree() + + return Tree(self.repo, tree_sha, 0, '') + + def _process_diff_args(self, args): + try: + args.pop(args.index(self)) + except IndexError: + pass + # END remove self + return args + + + def _to_relative_path(self, path): + """ + Return + Version of path relative to our git directory or raise ValueError + if it is not within our git direcotory + """ + if not os.path.isabs(path): + return path + relative_path = path.replace(self.repo.git.git_dir+"/", "") + if relative_path == path: + raise ValueError("Absolute path %r is not in git repository at %r" % (path,self.repo.git.git_dir)) + return relative_path + + def _preprocess_add_items(self, items): + """ + Split the items into two lists of path strings and BaseEntries. + """ + paths = list() + entries = list() + + for item in items: + if isinstance(item, basestring): + paths.append(self._to_relative_path(item)) + elif isinstance(item, Blob): + entries.append(BaseIndexEntry.from_blob(item)) + elif isinstance(item, BaseIndexEntry): + entries.append(item) + else: + raise TypeError("Invalid Type: %r" % item) + # END for each item + return (paths, entries) + + + @clear_cache + @default_index + def add(self, items, force=True, **kwargs): + """ + Add files from the working tree, specific blobs or BaseIndexEntries + to the index. The underlying index file will be written immediately, hence + you should provide as many items as possible to minimize the amounts of writes + + ``items`` + Multiple types of items are supported, types can be mixed within one call. + Different types imply a different handling. File paths may generally be + relative or absolute. + + - path string + strings denote a relative or absolute path into the repository pointing to + an existing file, i.e. CHANGES, lib/myfile.ext, '/home/gitrepo/lib/myfile.ext'. + + Paths provided like this must exist. When added, they will be written + into the object database. + + PathStrings may contain globs, such as 'lib/__init__*' or can be directories + like 'lib', the latter ones will add all the files within the dirctory and + subdirectories. + + This equals a straight git-add. + + They are added at stage 0 + + - Blob object + Blobs are added as they are assuming a valid mode is set. + The file they refer to may or may not exist in the file system, but + must be a path relative to our repository. + + If their sha is null ( 40*0 ), their path must exist in the file system + as an object will be created from the data at the path.The handling + now very much equals the way string paths are processed, except that + the mode you have set will be kept. This allows you to create symlinks + by settings the mode respectively and writing the target of the symlink + directly into the file. This equals a default Linux-Symlink which + is not dereferenced automatically, except that it can be created on + filesystems not supporting it as well. + + Please note that globs or directories are not allowed in Blob objects. + + They are added at stage 0 + + - BaseIndexEntry or type + Handling equals the one of Blob objects, but the stage may be + explicitly set. + + ``force`` + If True, otherwise ignored or excluded files will be + added anyway. + As opposed to the git-add command, we enable this flag by default + as the API user usually wants the item to be added even though + they might be excluded. + + ``**kwargs`` + Additional keyword arguments to be passed to git-update-index, such + as index_only. + + Returns + List(BaseIndexEntries) representing the entries just actually added. + """ + # sort the entries into strings and Entries, Blobs are converted to entries + # automatically + # paths can be git-added, for everything else we use git-update-index + entries_added = list() + paths, entries = self._preprocess_add_items(items) + + if paths: + git_add_output = self.repo.git.add(paths, v=True) + # force rereading our entries + del(self.entries) + for line in git_add_output.splitlines(): + # line contains: + # add '<path>' + added_file = line[5:-1] + entries_added.append(self.entries[(added_file,0)]) + # END for each line + # END path handling + + if entries: + null_mode_entries = [ e for e in entries if e.mode == 0 ] + if null_mode_entries: + raise ValueError("At least one Entry has a null-mode - please use index.remove to remove files for clarity") + # END null mode should be remove + + # create objects if required, otherwise go with the existing shas + null_entries_indices = [ i for i,e in enumerate(entries) if e.sha == Object.NULL_HEX_SHA ] + if null_entries_indices: + hash_proc = self.repo.git.hash_object(w=True, stdin_paths=True, istream=subprocess.PIPE, as_process=True) + hash_proc.stdin.write('\n'.join(entries[i].path for i in null_entries_indices)) + obj_ids = self._flush_stdin_and_wait(hash_proc).splitlines() + assert len(obj_ids) == len(null_entries_indices), "git-hash-object did not produce all requested objects: want %i, got %i" % ( len(null_entries_indices), len(obj_ids) ) + + # update IndexEntries with new object id + for i,new_sha in zip(null_entries_indices, obj_ids): + e = entries[i] + new_entry = BaseIndexEntry((e.mode, new_sha, e.stage, e.path)) + entries[i] = new_entry + # END for each index + # END null_entry handling + + # feed all the data to stdin + update_index_proc = self.repo.git.update_index(index_info=True, istream=subprocess.PIPE, as_process=True, **kwargs) + update_index_proc.stdin.write('\n'.join(str(e) for e in entries)) + entries_added.extend(entries) + self._flush_stdin_and_wait(update_index_proc) + # END if there are base entries + + return entries_added + + @clear_cache + @default_index + def remove(self, items, working_tree=False, **kwargs): + """ + Remove the given items from the index and optionally from + the working tree as well. + + ``items`` + Multiple types of items are supported which may be be freely mixed. + + - path string + Remove the given path at all stages. If it is a directory, you must + specify the r=True keyword argument to remove all file entries + below it. If absolute paths are given, they will be converted + to a path relative to the git repository directory containing + the working tree + + The path string may include globs, such as *.c. + + - Blob object + Only the path portion is used in this case. + + - BaseIndexEntry or compatible type + The only relevant information here Yis the path. The stage is ignored. + + ``working_tree`` + If True, the entry will also be removed from the working tree, physically + removing the respective file. This may fail if there are uncommited changes + in it. + + ``**kwargs`` + Additional keyword arguments to be passed to git-rm, such + as 'r' to allow recurive removal of + + Returns + List(path_string, ...) list of paths that have been removed effectively. + This is interesting to know in case you have provided a directory or + globs. Paths are relative to the + """ + args = list() + if not working_tree: + args.append("--cached") + args.append("--") + + # preprocess paths + paths = list() + for item in items: + if isinstance(item, (BaseIndexEntry,Blob)): + paths.append(self._to_relative_path(item.path)) + elif isinstance(item, basestring): + paths.append(self._to_relative_path(item)) + else: + raise TypeError("Invalid item type: %r" % item) + # END for each item + + removed_paths = self.repo.git.rm(args, paths, **kwargs).splitlines() + + # process output to gain proper paths + # rm 'path' + return [ p[4:-1] for p in removed_paths ] + + @default_index + def commit(self, message, parent_commits=None, head=True): + """ + Commit the current index, creating a commit object. + + ``message`` + Commit message. It may be an empty string if no message is provided. + It will be converted to a string in any case. + + ``parent_commits`` + Optional Commit objects to use as parents for the new commit. + If empty list, the commit will have no parents at all and become + a root commit. + If None , the current head commit will be the parent of the + new commit object + + ``head`` + If True, the HEAD will be advanced to the new commit automatically. + Else the HEAD will remain pointing on the previous commit. This could + lead to undesired results when diffing files. + + Returns + Commit object representing the new commit + + Note: + Additional information about hte committer and Author are taken from the + environment or from the git configuration, see git-commit-tree for + more information + """ + parents = parent_commits + if parent_commits is None: + parent_commits = [ self.repo.head.commit ] + + parent_args = [ ("-p", str(commit)) for commit in parent_commits ] + + # create message stream + tmp_file_path = tempfile.mktemp() + fp = open(tmp_file_path,"w") + fp.write(str(message)) + fp.close() + fp = open(tmp_file_path,"r") + fp.seek(0) + + try: + # write the current index as tree + tree_sha = self.repo.git.write_tree() + commit_sha = self.repo.git.commit_tree(tree_sha, parent_args, istream=fp) + new_commit = Commit(self.repo, commit_sha) + + if head: + self.repo.head.commit = new_commit + # END advance head handling + + return new_commit + finally: + fp.close() + os.remove(tmp_file_path) + + @classmethod + def _flush_stdin_and_wait(cls, proc): + proc.stdin.flush() + proc.stdin.close() + stdout = proc.stdout.read() + proc.wait() + return stdout + + @default_index + def checkout(self, paths=None, force=False, **kwargs): + """ + Checkout the given paths or all files from the version in the index. + + ``paths`` + If None, all paths in the index will be checked out. Otherwise an iterable + or single path of relative or absolute paths pointing to files is expected. + The command will ignore paths that do not exist. + + ``force`` + If True, existing files will be overwritten. If False, these will + be skipped. + + ``**kwargs`` + Additional arguments to be pasesd to git-checkout-index + + Returns + self + """ + args = ["--index"] + if force: + args.append("--force") + + if paths is None: + args.append("--all") + self.repo.git.checkout_index(*args, **kwargs) + else: + if not isinstance(paths, (tuple,list)): + paths = [paths] + + args.append("--stdin") + paths = [self._to_relative_path(p) for p in paths] + co_proc = self.repo.git.checkout_index(args, as_process=True, istream=subprocess.PIPE, **kwargs) + co_proc.stdin.write('\n'.join(paths)) + self._flush_stdin_and_wait(co_proc) + # END paths handling + return self + + @clear_cache + @default_index + def reset(self, commit='HEAD', working_tree=False, paths=None, head=False, **kwargs): + """ + Reset the index to reflect the tree at the given commit. This will not + adjust our HEAD reference as opposed to HEAD.reset by default. + + ``commit`` + Revision, Reference or Commit specifying the commit we should represent. + If you want to specify a tree only, use IndexFile.from_tree and overwrite + the default index. + + ``working_tree`` + If True, the files in the working tree will reflect the changed index. + If False, the working tree will not be touched + Please note that changes to the working copy will be discarded without + warning ! + + ``head`` + If True, the head will be set to the given commit. This is False by default, + but if True, this method behaves like HEAD.reset. + + ``**kwargs`` + Additional keyword arguments passed to git-reset + + Returns + self + """ + cur_head = self.repo.head + prev_commit = cur_head.commit + + # reset to get the tree/working copy + cur_head.reset(commit, index=True, working_tree=working_tree, paths=paths, **kwargs) + + # put the head back, possibly + if not head: + cur_head.reset(prev_commit, index=False, working_tree=False) + # END reset head + + return self + + @default_index + def diff(self, other=diff.Diffable.Index, paths=None, create_patch=False, **kwargs): + """ + Diff this index against the working copy or a Tree or Commit object + + For a documentation of the parameters and return values, see + Diffable.diff + + Note + Will only work with indices that represent the default git index as + they have not been initialized with a stream. + """ + # index against index is always empty + if other is self.Index: + return diff.DiffIndex() + + # index against anything but None is a reverse diff with the respective + # item. Handle existing -R flags properly. Transform strings to the object + # so that we can call diff on it + if isinstance(other, basestring): + other = Object.new(self.repo, other) + # END object conversion + + if isinstance(other, Object): + # invert the existing R flag + cur_val = kwargs.get('R', False) + kwargs['R'] = not cur_val + return other.diff(self.Index, paths, create_patch, **kwargs) + # END diff against other item handlin + + # if other is not None here, something is wrong + if other is not None: + raise ValueError( "other must be None, Diffable.Index, a Tree or Commit, was %r" % other ) + + # diff against working copy - can be handled by superclass natively + return super(IndexFile, self).diff(other, paths, create_patch, **kwargs) + diff --git a/lib/git/objects/__init__.py b/lib/git/objects/__init__.py index 39e650b7..192750e3 100644 --- a/lib/git/objects/__init__.py +++ b/lib/git/objects/__init__.py @@ -2,6 +2,7 @@ Import all submodules main classes into the package space """ import inspect +from base import * from tag import * from blob import * from tree import * diff --git a/lib/git/objects/base.py b/lib/git/objects/base.py index 3b48e066..b0989a43 100644 --- a/lib/git/objects/base.py +++ b/lib/git/objects/base.py @@ -15,22 +15,16 @@ class Object(LazyMixin): This Object also serves as a constructor for instances of the correct type:: - inst = Object(repo,id) + inst = Object.new(repo,id) + inst.sha # objects sha in hex + inst.size # objects uncompressed data size + inst.data # byte string containing the whole data of the object """ + NULL_HEX_SHA = '0'*40 TYPES = ("blob", "tree", "commit", "tag") - __slots__ = ("repo", "id", "size", "data" ) + __slots__ = ("repo", "sha", "size", "data" ) type = None # to be set by subclass - def __new__(cls, repo, id, *args, **kwargs): - if cls is Object: - hexsha, typename, size = repo.git.get_object_header(id) - obj_type = utils.get_object_type_by_name(typename) - inst = super(Object,cls).__new__(obj_type, repo, hexsha, *args, **kwargs) - inst.size = size - return inst - else: - return super(Object,cls).__new__(cls, repo, id, *args, **kwargs) - def __init__(self, repo, id): """ Initialize an object by identifying it by its id. All keyword arguments @@ -44,8 +38,26 @@ class Object(LazyMixin): """ super(Object,self).__init__() self.repo = repo - self.id = id - + self.sha = id + + @classmethod + def new(cls, repo, id): + """ + Return + New Object instance of a type appropriate to the object type behind + id. The id of the newly created object will be a hexsha even though + the input id may have been a Reference or Rev-Spec + + Note + This cannot be a __new__ method as it would always call __init__ + with the input id which is not necessarily a hexsha. + """ + hexsha, typename, size = repo.git.get_object_header(id) + obj_type = utils.get_object_type_by_name(typename) + inst = obj_type(repo, hexsha) + inst.size = size + return inst + def _set_self_from_args_(self, args_dict): """ Initialize attributes on self from the given dict that was retrieved @@ -64,11 +76,11 @@ class Object(LazyMixin): Retrieve object information """ if attr == "size": - hexsha, typename, self.size = self.repo.git.get_object_header(self.id) - assert typename == self.type, _assertion_msg_format % (self.id, typename, self.type) + hexsha, typename, self.size = self.repo.git.get_object_header(self.sha) + assert typename == self.type, _assertion_msg_format % (self.sha, typename, self.type) elif attr == "data": - hexsha, typename, self.size, self.data = self.repo.git.get_object_data(self.id) - assert typename == self.type, _assertion_msg_format % (self.id, typename, self.type) + hexsha, typename, self.size, self.data = self.repo.git.get_object_data(self.sha) + assert typename == self.type, _assertion_msg_format % (self.sha, typename, self.type) else: super(Object,self)._set_cache_(attr) @@ -77,36 +89,57 @@ class Object(LazyMixin): Returns True if the objects have the same SHA1 """ - return self.id == other.id + return self.sha == other.sha def __ne__(self, other): """ Returns True if the objects do not have the same SHA1 """ - return self.id != other.id + return self.sha != other.sha def __hash__(self): """ Returns Hash of our id allowing objects to be used in dicts and sets """ - return hash(self.id) + return hash(self.sha) def __str__(self): """ Returns string of our SHA1 as understood by all git commands """ - return self.id + return self.sha def __repr__(self): """ Returns string with pythonic representation of our object """ - return '<git.%s "%s">' % (self.__class__.__name__, self.id) + return '<git.%s "%s">' % (self.__class__.__name__, self.sha) + + @property + def data_stream(self): + """ + Returns + File Object compatible stream to the uncompressed raw data of the object + """ + proc = self.repo.git.cat_file(self.type, self.sha, as_process=True) + return utils.ProcessStreamAdapter(proc, "stdout") + def stream_data(self, ostream): + """ + Writes our data directly to the given output stream + + ``ostream`` + File object compatible stream object. + + Returns + self + """ + self.repo.git.cat_file(self.type, self.sha, output_stream=ostream) + return self class IndexObject(Object): """ @@ -115,13 +148,13 @@ class IndexObject(Object): """ __slots__ = ("path", "mode") - def __init__(self, repo, id, mode=None, path=None): + def __init__(self, repo, sha, mode=None, path=None): """ Initialize a newly instanced IndexObject ``repo`` is the Repo we are located in - ``id`` : string + ``sha`` : string is the git object id as hex sha ``mode`` : int @@ -135,7 +168,7 @@ class IndexObject(Object): Path may not be set of the index object has been created directly as it cannot be retrieved without knowing the parent tree. """ - super(IndexObject, self).__init__(repo, id) + super(IndexObject, self).__init__(repo, sha) self._set_self_from_args_(locals()) if isinstance(mode, basestring): self.mode = self._mode_str_to_int(mode) @@ -162,5 +195,21 @@ class IndexObject(Object): mode += int(char) << iteration*3 # END for each char return mode - + + @property + def name(self): + """ + Returns + Name portion of the path, effectively being the basename + """ + return os.path.basename(self.path) + + @property + def abspath(self): + """ + Returns + Absolute path to this index object in the file system ( as opposed to the + .path field which is a path relative to the git repository ) + """ + return os.path.join(self.repo.git.git_dir, self.path) diff --git a/lib/git/objects/blob.py b/lib/git/objects/blob.py index 88ca73d6..11dee323 100644 --- a/lib/git/objects/blob.py +++ b/lib/git/objects/blob.py @@ -33,4 +33,4 @@ class Blob(base.IndexObject): def __repr__(self): - return '<git.Blob "%s">' % self.id + return '<git.Blob "%s">' % self.sha diff --git a/lib/git/objects/commit.py b/lib/git/objects/commit.py index 847f4dec..80b3ad23 100644 --- a/lib/git/objects/commit.py +++ b/lib/git/objects/commit.py @@ -11,7 +11,7 @@ from tree import Tree import base import utils -class Commit(base.Object, Iterable): +class Commit(base.Object, Iterable, diff.Diffable): """ Wraps a git Commit object. @@ -23,8 +23,9 @@ class Commit(base.Object, Iterable): type = "commit" __slots__ = ("tree", "author", "authored_date", "committer", "committed_date", "message", "parents") + _id_attribute_ = "sha" - def __init__(self, repo, id, tree=None, author=None, authored_date=None, + def __init__(self, repo, sha, tree=None, author=None, authored_date=None, committer=None, committed_date=None, message=None, parents=None): """ Instantiate a new Commit. All keyword arguments taking None as default will @@ -32,7 +33,7 @@ class Commit(base.Object, Iterable): The parameter documentation indicates the type of the argument after a colon ':'. - ``id`` + ``sha`` is the sha id of the commit or a ref ``parents`` : tuple( Commit, ... ) @@ -61,15 +62,15 @@ class Commit(base.Object, Iterable): Returns git.Commit """ - super(Commit,self).__init__(repo, id) + super(Commit,self).__init__(repo, sha) self._set_self_from_args_(locals()) if parents is not None: self.parents = tuple( self.__class__(repo, p) for p in parents ) # END for each parent to convert - if self.id and tree is not None: - self.tree = Tree(repo, id=tree, path='') + if self.sha and tree is not None: + self.tree = Tree(repo, tree, path='') # END id to tree conversion def _set_cache_(self, attr): @@ -81,8 +82,8 @@ class Commit(base.Object, Iterable): if attr in Commit.__slots__: # prepare our data lines to match rev-list data_lines = self.data.splitlines() - data_lines.insert(0, "commit %s" % self.id) - temp = self._iter_from_process_or_stream(self.repo, iter(data_lines)).next() + data_lines.insert(0, "commit %s" % self.sha) + temp = self._iter_from_process_or_stream(self.repo, iter(data_lines), False).next() self.parents = temp.parents self.tree = temp.tree self.author = temp.author @@ -101,27 +102,30 @@ class Commit(base.Object, Iterable): """ return self.message.split('\n', 1)[0] - @classmethod - def count(cls, repo, rev, paths='', **kwargs): + def count(self, paths='', **kwargs): """ - Count the number of commits reachable from this revision - - ``repo`` - is the Repo - - ``rev`` - revision specifier, see git-rev-parse for viable options + Count the number of commits reachable from this commit ``paths`` is an optinal path or a list of paths restricting the return value to commits actually containing the paths ``kwargs`` - Additional options to be passed to git-rev-list + Additional options to be passed to git-rev-list. They must not alter + the ouput style of the command, or parsing will yield incorrect results Returns int """ - return len(repo.git.rev_list(rev, '--', paths, **kwargs).strip().splitlines()) + return len(self.repo.git.rev_list(self.sha, '--', paths, **kwargs).strip().splitlines()) + + @property + def name_rev(self): + """ + Returns + String describing the commits hex sha based on the closest Reference. + Mostly useful for UI purposes + """ + return self.repo.git.name_rev(self) @classmethod def iter_items(cls, repo, rev, paths='', **kwargs): @@ -150,9 +154,8 @@ class Commit(base.Object, Iterable): options = {'pretty': 'raw', 'as_process' : True } options.update(kwargs) - # the test system might confront us with string values - proc = repo.git.rev_list(rev, '--', paths, **options) - return cls._iter_from_process_or_stream(repo, proc) + return cls._iter_from_process_or_stream(repo, proc, True) def iter_parents(self, paths='', **kwargs): """ @@ -176,60 +179,6 @@ class Commit(base.Object, Iterable): return self.iter_items( self.repo, self, paths, **kwargs ) - @classmethod - def diff(cls, repo, a, b=None, paths=None): - """ - Creates diffs between a tree and the index or between two trees: - - ``repo`` - is the Repo - - ``a`` - is a named commit - - ``b`` - is an optional named commit. Passing a list assumes you - wish to omit the second named commit and limit the diff to the - given paths. - - ``paths`` - is a list of paths to limit the diff to. - - Returns - git.Diff[]:: - - between tree and the index if only a is given - between two trees if a and b are given and are commits - """ - paths = paths or [] - - if isinstance(b, list): - paths = b - b = None - - if paths: - paths.insert(0, "--") - - if b: - paths.insert(0, b) - paths.insert(0, a) - text = repo.git.diff('-M', full_index=True, *paths) - return diff.Diff._list_from_string(repo, text) - - @property - def diffs(self): - """ - Returns - git.Diff[] - Diffs between this commit and its first parent or all changes if this - commit is the first commit and has no parent. - """ - if not self.parents: - d = self.repo.git.show(self.id, '-M', full_index=True, pretty='raw') - return diff.Diff._list_from_string(self.repo, d) - else: - return self.diff(self.repo, self.parents[0].id, self.id) - @property def stats(self): """ @@ -240,18 +189,18 @@ class Commit(base.Object, Iterable): git.Stats """ if not self.parents: - text = self.repo.git.diff_tree(self.id, '--', numstat=True, root=True) + text = self.repo.git.diff_tree(self.sha, '--', numstat=True, root=True) text2 = "" for line in text.splitlines()[1:]: (insertions, deletions, filename) = line.split("\t") text2 += "%s\t%s\t%s\n" % (insertions, deletions, filename) text = text2 else: - text = self.repo.git.diff(self.parents[0].id, self.id, '--', numstat=True) + text = self.repo.git.diff(self.parents[0].sha, self.sha, '--', numstat=True) return stats.Stats._list_from_string(self.repo, text) @classmethod - def _iter_from_process_or_stream(cls, repo, proc_or_stream): + def _iter_from_process_or_stream(cls, repo, proc_or_stream, from_rev_list): """ Parse out commit information into a list of Commit objects @@ -261,6 +210,9 @@ class Commit(base.Object, Iterable): ``proc`` git-rev-list process instance (raw format) + ``from_rev_list`` + If True, the stream was created by rev-list in which case we parse + the message differently Returns iterator returning Commit objects """ @@ -269,8 +221,9 @@ class Commit(base.Object, Iterable): stream = proc_or_stream.stdout for line in stream: - id = line.split()[1] - assert line.split()[0] == "commit" + commit_tokens = line.split() + id = commit_tokens[1] + assert commit_tokens[0] == "commit" tree = stream.next().split()[1] parents = [] @@ -290,24 +243,31 @@ class Commit(base.Object, Iterable): stream.next() message_lines = [] - next_line = None - for msg_line in stream: - if not msg_line.startswith(' '): - break - # END abort message reading - message_lines.append(msg_line.strip()) - # END while there are message lines + if from_rev_list: + for msg_line in stream: + if not msg_line.startswith(' '): + # and forget about this empty marker + break + # END abort message reading + # strip leading 4 spaces + message_lines.append(msg_line[4:]) + # END while there are message lines + else: + # a stream from our data simply gives us the plain message + for msg_line in stream: + message_lines.append(msg_line) + # END message parsing message = '\n'.join(message_lines) - yield Commit(repo, id=id, parents=tuple(parents), tree=tree, author=author, authored_date=authored_date, + yield Commit(repo, id, parents=tuple(parents), tree=tree, author=author, authored_date=authored_date, committer=committer, committed_date=committed_date, message=message) # END for each line in stream def __str__(self): """ Convert commit to string which is SHA1 """ - return self.id + return self.sha def __repr__(self): - return '<git.Commit "%s">' % self.id + return '<git.Commit "%s">' % self.sha diff --git a/lib/git/objects/tag.py b/lib/git/objects/tag.py index f54d4b64..c329edf7 100644 --- a/lib/git/objects/tag.py +++ b/lib/git/objects/tag.py @@ -17,7 +17,7 @@ class TagObject(base.Object): type = "tag" __slots__ = ( "object", "tag", "tagger", "tagged_date", "message" ) - def __init__(self, repo, id, object=None, tag=None, + def __init__(self, repo, sha, object=None, tag=None, tagger=None, tagged_date=None, message=None): """ Initialize a tag object with additional data @@ -25,7 +25,7 @@ class TagObject(base.Object): ``repo`` repository this object is located in - ``id`` + ``sha`` SHA1 or ref suitable for git-rev-parse ``object`` @@ -41,7 +41,7 @@ class TagObject(base.Object): is the DateTime of the tag creation - use time.gmtime to convert it into a different format """ - super(TagObject, self).__init__(repo, id ) + super(TagObject, self).__init__(repo, sha ) self._set_self_from_args_(locals()) def _set_cache_(self, attr): @@ -60,8 +60,13 @@ class TagObject(base.Object): tagger_info = lines[3][7:]# tagger <actor> <date> self.tagger, self.tagged_date = utils.parse_actor_and_date(tagger_info) - # line 4 empty - check git source to figure out purpose - self.message = "\n".join(lines[5:]) + # line 4 empty - it could mark the beginning of the next header + # in csse there really is no message, it would not exist. Otherwise + # a newline separates header from message + if len(lines) > 5: + self.message = "\n".join(lines[5:]) + else: + self.message = '' # END check our attributes else: super(TagObject, self)._set_cache_(attr) diff --git a/lib/git/objects/tree.py b/lib/git/objects/tree.py index abfa9622..27bd84d0 100644 --- a/lib/git/objects/tree.py +++ b/lib/git/objects/tree.py @@ -8,6 +8,7 @@ import os import blob import base import binascii +import git.diff as diff def sha_to_hex(sha): """Takes a string and returns the hex of the sha within""" @@ -15,7 +16,7 @@ def sha_to_hex(sha): assert len(hexsha) == 40, "Incorrect length of sha1 string: %d" % hexsha return hexsha -class Tree(base.IndexObject): +class Tree(base.IndexObject, diff.Diffable): """ Tress represent a ordered list of Blobs and other Trees. Hence it can be accessed like a list. @@ -37,13 +38,14 @@ class Tree(base.IndexObject): __slots__ = "_cache" # using ascii codes for comparison - ascii_commit_id = (0x31 << 4) + 0x36 - ascii_blob_id = (0x31 << 4) + 0x30 - ascii_tree_id = (0x34 << 4) + 0x30 + commit_id = 016 + blob_id = 010 + symlink_id = 012 + tree_id = 040 - def __init__(self, repo, id, mode=0, path=None): - super(Tree, self).__init__(repo, id, mode, path) + def __init__(self, repo, sha, mode=0, path=None): + super(Tree, self).__init__(repo, sha, mode, path) def _set_cache_(self, attr): if attr == "_cache": @@ -87,8 +89,8 @@ class Tree(base.IndexObject): mode = 0 mode_boundary = i + 6 - # keep it ascii - we compare against the respective values - type_id = (ord(data[i])<<4) + ord(data[i+1]) + # read type + type_id = ((ord(data[i])-ord_zero)<<3) + (ord(data[i+1])-ord_zero) i += 2 while data[i] != ' ': @@ -108,18 +110,20 @@ class Tree(base.IndexObject): i += 1 # END while not reached NULL name = data[ns:i] + path = os.path.join(self.path, name) # byte is NULL, get next 20 i += 1 sha = data[i:i+20] i = i + 20 + mode |= type_id<<12 hexsha = sha_to_hex(sha) - if type_id == self.ascii_blob_id: - yield blob.Blob(self.repo, hexsha, mode, name) - elif type_id == self.ascii_tree_id: - yield Tree(self.repo, hexsha, mode, name) - elif type_id == self.ascii_commit_id: + if type_id == self.blob_id or type_id == self.symlink_id: + yield blob.Blob(self.repo, hexsha, mode, path) + elif type_id == self.tree_id: + yield Tree(self.repo, hexsha, mode, path) + elif type_id == self.commit_id: # todo yield None else: @@ -148,29 +152,28 @@ class Tree(base.IndexObject): def __repr__(self): - return '<git.Tree "%s">' % self.id + return '<git.Tree "%s">' % self.sha @classmethod - def _iter_recursive(cls, repo, tree, cur_depth, max_depth, predicate ): + def _iter_recursive(cls, repo, tree, cur_depth, max_depth, predicate, prune ): for obj in tree: - # adjust path to be complete - obj.path = os.path.join(tree.path, obj.path) - if not predicate(obj): - continue - yield obj - if obj.type == "tree" and ( max_depth < 0 or cur_depth+1 <= max_depth ): - for recursive_obj in cls._iter_recursive( repo, obj, cur_depth+1, max_depth, predicate ): + if predicate(obj): + yield obj + if obj.type == "tree" and ( max_depth < 0 or cur_depth+1 <= max_depth ) and not prune(obj): + for recursive_obj in cls._iter_recursive( repo, obj, cur_depth+1, max_depth, predicate, prune ): yield recursive_obj # END for each recursive object # END if we may enter recursion # END for each object - def traverse(self, max_depth=-1, predicate = lambda i: True): + def traverse(self, max_depth=-1, predicate = lambda i: True, prune = lambda t: False): """ Returns + Iterator to traverse the tree recursively up to the given level. - The iterator returns Blob and Tree objects + The iterator returns Blob and Tree objects with paths relative to their + repository. ``max_depth`` @@ -181,8 +184,13 @@ class Tree(base.IndexObject): ``predicate`` If predicate(item) returns True, item will be returned by iterator + + ``prune`` + + If prune(tree) returns True, the traversal will not continue into the + given tree object. """ - return self._iter_recursive( self.repo, self, 0, max_depth, predicate ) + return self._iter_recursive( self.repo, self, 0, max_depth, predicate, prune ) @property def trees(self): @@ -218,7 +226,7 @@ class Tree(base.IndexObject): if isinstance(item, basestring): # compatability for obj in self._cache: - if obj.path == item: + if obj.name == item: return obj # END for each obj raise KeyError( "Blob or Tree named %s not found" % item ) diff --git a/lib/git/objects/utils.py b/lib/git/objects/utils.py index 367ed2b7..7bb4e8e2 100644 --- a/lib/git/objects/utils.py +++ b/lib/git/objects/utils.py @@ -52,3 +52,21 @@ def parse_actor_and_date(line): m = _re_actor_epoch.search(line) actor, epoch = m.groups() return (Actor._from_string(actor), int(epoch)) + + + +class ProcessStreamAdapter(object): + """ + Class wireing all calls to the contained Process instance. + + Use this type to hide the underlying process to provide access only to a specified + stream. The process is usually wrapped into an AutoInterrupt class to kill + it if the instance goes out of scope. + """ + __slots__ = ("_proc", "_stream") + def __init__(self, process, stream_name): + self._proc = process + self._stream = getattr(process, stream_name) + + def __getattr__(self, attr): + return getattr(self._stream, attr) diff --git a/lib/git/odict.py b/lib/git/odict.py new file mode 100644 index 00000000..2c8391d7 --- /dev/null +++ b/lib/git/odict.py @@ -0,0 +1,1399 @@ +# odict.py +# An Ordered Dictionary object +# Copyright (C) 2005 Nicola Larosa, Michael Foord +# E-mail: nico AT tekNico DOT net, fuzzyman AT voidspace DOT org DOT uk + +# This software is licensed under the terms of the BSD license. +# http://www.voidspace.org.uk/python/license.shtml +# Basically you're free to copy, modify, distribute and relicense it, +# So long as you keep a copy of the license with it. + +# Documentation at http://www.voidspace.org.uk/python/odict.html +# For information about bugfixes, updates and support, please join the +# Pythonutils mailing list: +# http://groups.google.com/group/pythonutils/ +# Comments, suggestions and bug reports welcome. + +"""A dict that keeps keys in insertion order""" +from __future__ import generators + +__author__ = ('Nicola Larosa <nico-NoSp@m-tekNico.net>,' + 'Michael Foord <fuzzyman AT voidspace DOT org DOT uk>') + +__docformat__ = "restructuredtext en" + +__revision__ = '$Id: odict.py 129 2005-09-12 18:15:28Z teknico $' + +__version__ = '0.2.2' + +__all__ = ['OrderedDict', 'SequenceOrderedDict'] + +import sys +INTP_VER = sys.version_info[:2] +if INTP_VER < (2, 2): + raise RuntimeError("Python v.2.2 or later required") + +import types, warnings + +class OrderedDict(dict): + """ + A class of dictionary that keeps the insertion order of keys. + + All appropriate methods return keys, items, or values in an ordered way. + + All normal dictionary methods are available. Update and comparison is + restricted to other OrderedDict objects. + + Various sequence methods are available, including the ability to explicitly + mutate the key ordering. + + __contains__ tests: + + >>> d = OrderedDict(((1, 3),)) + >>> 1 in d + 1 + >>> 4 in d + 0 + + __getitem__ tests: + + >>> OrderedDict(((1, 3), (3, 2), (2, 1)))[2] + 1 + >>> OrderedDict(((1, 3), (3, 2), (2, 1)))[4] + Traceback (most recent call last): + KeyError: 4 + + __len__ tests: + + >>> len(OrderedDict()) + 0 + >>> len(OrderedDict(((1, 3), (3, 2), (2, 1)))) + 3 + + get tests: + + >>> d = OrderedDict(((1, 3), (3, 2), (2, 1))) + >>> d.get(1) + 3 + >>> d.get(4) is None + 1 + >>> d.get(4, 5) + 5 + >>> d + OrderedDict([(1, 3), (3, 2), (2, 1)]) + + has_key tests: + + >>> d = OrderedDict(((1, 3), (3, 2), (2, 1))) + >>> d.has_key(1) + 1 + >>> d.has_key(4) + 0 + """ + + def __init__(self, init_val=(), strict=False): + """ + Create a new ordered dictionary. Cannot init from a normal dict, + nor from kwargs, since items order is undefined in those cases. + + If the ``strict`` keyword argument is ``True`` (``False`` is the + default) then when doing slice assignment - the ``OrderedDict`` you are + assigning from *must not* contain any keys in the remaining dict. + + >>> OrderedDict() + OrderedDict([]) + >>> OrderedDict({1: 1}) + Traceback (most recent call last): + TypeError: undefined order, cannot get items from dict + >>> OrderedDict({1: 1}.items()) + OrderedDict([(1, 1)]) + >>> d = OrderedDict(((1, 3), (3, 2), (2, 1))) + >>> d + OrderedDict([(1, 3), (3, 2), (2, 1)]) + >>> OrderedDict(d) + OrderedDict([(1, 3), (3, 2), (2, 1)]) + """ + self.strict = strict + dict.__init__(self) + if isinstance(init_val, OrderedDict): + self._sequence = init_val.keys() + dict.update(self, init_val) + elif isinstance(init_val, dict): + # we lose compatibility with other ordered dict types this way + raise TypeError('undefined order, cannot get items from dict') + else: + self._sequence = [] + self.update(init_val) + +### Special methods ### + + def __delitem__(self, key): + """ + >>> d = OrderedDict(((1, 3), (3, 2), (2, 1))) + >>> del d[3] + >>> d + OrderedDict([(1, 3), (2, 1)]) + >>> del d[3] + Traceback (most recent call last): + KeyError: 3 + >>> d[3] = 2 + >>> d + OrderedDict([(1, 3), (2, 1), (3, 2)]) + >>> del d[0:1] + >>> d + OrderedDict([(2, 1), (3, 2)]) + """ + if isinstance(key, types.SliceType): + # FIXME: efficiency? + keys = self._sequence[key] + for entry in keys: + dict.__delitem__(self, entry) + del self._sequence[key] + else: + # do the dict.__delitem__ *first* as it raises + # the more appropriate error + dict.__delitem__(self, key) + self._sequence.remove(key) + + def __eq__(self, other): + """ + >>> d = OrderedDict(((1, 3), (3, 2), (2, 1))) + >>> d == OrderedDict(d) + True + >>> d == OrderedDict(((1, 3), (2, 1), (3, 2))) + False + >>> d == OrderedDict(((1, 0), (3, 2), (2, 1))) + False + >>> d == OrderedDict(((0, 3), (3, 2), (2, 1))) + False + >>> d == dict(d) + False + >>> d == False + False + """ + if isinstance(other, OrderedDict): + # FIXME: efficiency? + # Generate both item lists for each compare + return (self.items() == other.items()) + else: + return False + + def __lt__(self, other): + """ + >>> d = OrderedDict(((1, 3), (3, 2), (2, 1))) + >>> c = OrderedDict(((0, 3), (3, 2), (2, 1))) + >>> c < d + True + >>> d < c + False + >>> d < dict(c) + Traceback (most recent call last): + TypeError: Can only compare with other OrderedDicts + """ + if not isinstance(other, OrderedDict): + raise TypeError('Can only compare with other OrderedDicts') + # FIXME: efficiency? + # Generate both item lists for each compare + return (self.items() < other.items()) + + def __le__(self, other): + """ + >>> d = OrderedDict(((1, 3), (3, 2), (2, 1))) + >>> c = OrderedDict(((0, 3), (3, 2), (2, 1))) + >>> e = OrderedDict(d) + >>> c <= d + True + >>> d <= c + False + >>> d <= dict(c) + Traceback (most recent call last): + TypeError: Can only compare with other OrderedDicts + >>> d <= e + True + """ + if not isinstance(other, OrderedDict): + raise TypeError('Can only compare with other OrderedDicts') + # FIXME: efficiency? + # Generate both item lists for each compare + return (self.items() <= other.items()) + + def __ne__(self, other): + """ + >>> d = OrderedDict(((1, 3), (3, 2), (2, 1))) + >>> d != OrderedDict(d) + False + >>> d != OrderedDict(((1, 3), (2, 1), (3, 2))) + True + >>> d != OrderedDict(((1, 0), (3, 2), (2, 1))) + True + >>> d == OrderedDict(((0, 3), (3, 2), (2, 1))) + False + >>> d != dict(d) + True + >>> d != False + True + """ + if isinstance(other, OrderedDict): + # FIXME: efficiency? + # Generate both item lists for each compare + return not (self.items() == other.items()) + else: + return True + + def __gt__(self, other): + """ + >>> d = OrderedDict(((1, 3), (3, 2), (2, 1))) + >>> c = OrderedDict(((0, 3), (3, 2), (2, 1))) + >>> d > c + True + >>> c > d + False + >>> d > dict(c) + Traceback (most recent call last): + TypeError: Can only compare with other OrderedDicts + """ + if not isinstance(other, OrderedDict): + raise TypeError('Can only compare with other OrderedDicts') + # FIXME: efficiency? + # Generate both item lists for each compare + return (self.items() > other.items()) + + def __ge__(self, other): + """ + >>> d = OrderedDict(((1, 3), (3, 2), (2, 1))) + >>> c = OrderedDict(((0, 3), (3, 2), (2, 1))) + >>> e = OrderedDict(d) + >>> c >= d + False + >>> d >= c + True + >>> d >= dict(c) + Traceback (most recent call last): + TypeError: Can only compare with other OrderedDicts + >>> e >= d + True + """ + if not isinstance(other, OrderedDict): + raise TypeError('Can only compare with other OrderedDicts') + # FIXME: efficiency? + # Generate both item lists for each compare + return (self.items() >= other.items()) + + def __repr__(self): + """ + Used for __repr__ and __str__ + + >>> r1 = repr(OrderedDict((('a', 'b'), ('c', 'd'), ('e', 'f')))) + >>> r1 + "OrderedDict([('a', 'b'), ('c', 'd'), ('e', 'f')])" + >>> r2 = repr(OrderedDict((('a', 'b'), ('e', 'f'), ('c', 'd')))) + >>> r2 + "OrderedDict([('a', 'b'), ('e', 'f'), ('c', 'd')])" + >>> r1 == str(OrderedDict((('a', 'b'), ('c', 'd'), ('e', 'f')))) + True + >>> r2 == str(OrderedDict((('a', 'b'), ('e', 'f'), ('c', 'd')))) + True + """ + return '%s([%s])' % (self.__class__.__name__, ', '.join( + ['(%r, %r)' % (key, self[key]) for key in self._sequence])) + + def __setitem__(self, key, val): + """ + Allows slice assignment, so long as the slice is an OrderedDict + >>> d = OrderedDict() + >>> d['a'] = 'b' + >>> d['b'] = 'a' + >>> d[3] = 12 + >>> d + OrderedDict([('a', 'b'), ('b', 'a'), (3, 12)]) + >>> d[:] = OrderedDict(((1, 2), (2, 3), (3, 4))) + >>> d + OrderedDict([(1, 2), (2, 3), (3, 4)]) + >>> d[::2] = OrderedDict(((7, 8), (9, 10))) + >>> d + OrderedDict([(7, 8), (2, 3), (9, 10)]) + >>> d = OrderedDict(((0, 1), (1, 2), (2, 3), (3, 4))) + >>> d[1:3] = OrderedDict(((1, 2), (5, 6), (7, 8))) + >>> d + OrderedDict([(0, 1), (1, 2), (5, 6), (7, 8), (3, 4)]) + >>> d = OrderedDict(((0, 1), (1, 2), (2, 3), (3, 4)), strict=True) + >>> d[1:3] = OrderedDict(((1, 2), (5, 6), (7, 8))) + >>> d + OrderedDict([(0, 1), (1, 2), (5, 6), (7, 8), (3, 4)]) + + >>> a = OrderedDict(((0, 1), (1, 2), (2, 3)), strict=True) + >>> a[3] = 4 + >>> a + OrderedDict([(0, 1), (1, 2), (2, 3), (3, 4)]) + >>> a[::1] = OrderedDict([(0, 1), (1, 2), (2, 3), (3, 4)]) + >>> a + OrderedDict([(0, 1), (1, 2), (2, 3), (3, 4)]) + >>> a[:2] = OrderedDict([(0, 1), (1, 2), (2, 3), (3, 4), (4, 5)]) + Traceback (most recent call last): + ValueError: slice assignment must be from unique keys + >>> a = OrderedDict(((0, 1), (1, 2), (2, 3))) + >>> a[3] = 4 + >>> a + OrderedDict([(0, 1), (1, 2), (2, 3), (3, 4)]) + >>> a[::1] = OrderedDict([(0, 1), (1, 2), (2, 3), (3, 4)]) + >>> a + OrderedDict([(0, 1), (1, 2), (2, 3), (3, 4)]) + >>> a[:2] = OrderedDict([(0, 1), (1, 2), (2, 3), (3, 4)]) + >>> a + OrderedDict([(0, 1), (1, 2), (2, 3), (3, 4)]) + >>> a[::-1] = OrderedDict([(0, 1), (1, 2), (2, 3), (3, 4)]) + >>> a + OrderedDict([(3, 4), (2, 3), (1, 2), (0, 1)]) + + >>> d = OrderedDict([(0, 1), (1, 2), (2, 3), (3, 4)]) + >>> d[:1] = 3 + Traceback (most recent call last): + TypeError: slice assignment requires an OrderedDict + + >>> d = OrderedDict([(0, 1), (1, 2), (2, 3), (3, 4)]) + >>> d[:1] = OrderedDict([(9, 8)]) + >>> d + OrderedDict([(9, 8), (1, 2), (2, 3), (3, 4)]) + """ + if isinstance(key, types.SliceType): + if not isinstance(val, OrderedDict): + # FIXME: allow a list of tuples? + raise TypeError('slice assignment requires an OrderedDict') + keys = self._sequence[key] + # NOTE: Could use ``range(*key.indices(len(self._sequence)))`` + indexes = range(len(self._sequence))[key] + if key.step is None: + # NOTE: new slice may not be the same size as the one being + # overwritten ! + # NOTE: What is the algorithm for an impossible slice? + # e.g. d[5:3] + pos = key.start or 0 + del self[key] + newkeys = val.keys() + for k in newkeys: + if k in self: + if self.strict: + raise ValueError('slice assignment must be from ' + 'unique keys') + else: + # NOTE: This removes duplicate keys *first* + # so start position might have changed? + del self[k] + self._sequence = (self._sequence[:pos] + newkeys + + self._sequence[pos:]) + dict.update(self, val) + else: + # extended slice - length of new slice must be the same + # as the one being replaced + if len(keys) != len(val): + raise ValueError('attempt to assign sequence of size %s ' + 'to extended slice of size %s' % (len(val), len(keys))) + # FIXME: efficiency? + del self[key] + item_list = zip(indexes, val.items()) + # smallest indexes first - higher indexes not guaranteed to + # exist + item_list.sort() + for pos, (newkey, newval) in item_list: + if self.strict and newkey in self: + raise ValueError('slice assignment must be from unique' + ' keys') + self.insert(pos, newkey, newval) + else: + if key not in self: + self._sequence.append(key) + dict.__setitem__(self, key, val) + + def __getitem__(self, key): + """ + Allows slicing. Returns an OrderedDict if you slice. + >>> b = OrderedDict([(7, 0), (6, 1), (5, 2), (4, 3), (3, 4), (2, 5), (1, 6)]) + >>> b[::-1] + OrderedDict([(1, 6), (2, 5), (3, 4), (4, 3), (5, 2), (6, 1), (7, 0)]) + >>> b[2:5] + OrderedDict([(5, 2), (4, 3), (3, 4)]) + >>> type(b[2:4]) + <class '__main__.OrderedDict'> + """ + if isinstance(key, types.SliceType): + # FIXME: does this raise the error we want? + keys = self._sequence[key] + # FIXME: efficiency? + return OrderedDict([(entry, self[entry]) for entry in keys]) + else: + return dict.__getitem__(self, key) + + __str__ = __repr__ + + def __setattr__(self, name, value): + """ + Implemented so that accesses to ``sequence`` raise a warning and are + diverted to the new ``setkeys`` method. + """ + if name == 'sequence': + warnings.warn('Use of the sequence attribute is deprecated.' + ' Use the keys method instead.', DeprecationWarning) + # NOTE: doesn't return anything + self.setkeys(value) + else: + # FIXME: do we want to allow arbitrary setting of attributes? + # Or do we want to manage it? + object.__setattr__(self, name, value) + + def __getattr__(self, name): + """ + Implemented so that access to ``sequence`` raises a warning. + + >>> d = OrderedDict() + >>> d.sequence + [] + """ + if name == 'sequence': + warnings.warn('Use of the sequence attribute is deprecated.' + ' Use the keys method instead.', DeprecationWarning) + # NOTE: Still (currently) returns a direct reference. Need to + # because code that uses sequence will expect to be able to + # mutate it in place. + return self._sequence + else: + # raise the appropriate error + raise AttributeError("OrderedDict has no '%s' attribute" % name) + + def __deepcopy__(self, memo): + """ + To allow deepcopy to work with OrderedDict. + + >>> from copy import deepcopy + >>> a = OrderedDict([(1, 1), (2, 2), (3, 3)]) + >>> a['test'] = {} + >>> b = deepcopy(a) + >>> b == a + True + >>> b is a + False + >>> a['test'] is b['test'] + False + """ + from copy import deepcopy + return self.__class__(deepcopy(self.items(), memo), self.strict) + + +### Read-only methods ### + + def copy(self): + """ + >>> OrderedDict(((1, 3), (3, 2), (2, 1))).copy() + OrderedDict([(1, 3), (3, 2), (2, 1)]) + """ + return OrderedDict(self) + + def items(self): + """ + ``items`` returns a list of tuples representing all the + ``(key, value)`` pairs in the dictionary. + + >>> d = OrderedDict(((1, 3), (3, 2), (2, 1))) + >>> d.items() + [(1, 3), (3, 2), (2, 1)] + >>> d.clear() + >>> d.items() + [] + """ + return zip(self._sequence, self.values()) + + def keys(self): + """ + Return a list of keys in the ``OrderedDict``. + + >>> d = OrderedDict(((1, 3), (3, 2), (2, 1))) + >>> d.keys() + [1, 3, 2] + """ + return self._sequence[:] + + def values(self, values=None): + """ + Return a list of all the values in the OrderedDict. + + Optionally you can pass in a list of values, which will replace the + current list. The value list must be the same len as the OrderedDict. + + >>> d = OrderedDict(((1, 3), (3, 2), (2, 1))) + >>> d.values() + [3, 2, 1] + """ + return [self[key] for key in self._sequence] + + def iteritems(self): + """ + >>> ii = OrderedDict(((1, 3), (3, 2), (2, 1))).iteritems() + >>> ii.next() + (1, 3) + >>> ii.next() + (3, 2) + >>> ii.next() + (2, 1) + >>> ii.next() + Traceback (most recent call last): + StopIteration + """ + def make_iter(self=self): + keys = self.iterkeys() + while True: + key = keys.next() + yield (key, self[key]) + return make_iter() + + def iterkeys(self): + """ + >>> ii = OrderedDict(((1, 3), (3, 2), (2, 1))).iterkeys() + >>> ii.next() + 1 + >>> ii.next() + 3 + >>> ii.next() + 2 + >>> ii.next() + Traceback (most recent call last): + StopIteration + """ + return iter(self._sequence) + + __iter__ = iterkeys + + def itervalues(self): + """ + >>> iv = OrderedDict(((1, 3), (3, 2), (2, 1))).itervalues() + >>> iv.next() + 3 + >>> iv.next() + 2 + >>> iv.next() + 1 + >>> iv.next() + Traceback (most recent call last): + StopIteration + """ + def make_iter(self=self): + keys = self.iterkeys() + while True: + yield self[keys.next()] + return make_iter() + +### Read-write methods ### + + def clear(self): + """ + >>> d = OrderedDict(((1, 3), (3, 2), (2, 1))) + >>> d.clear() + >>> d + OrderedDict([]) + """ + dict.clear(self) + self._sequence = [] + + def pop(self, key, *args): + """ + No dict.pop in Python 2.2, gotta reimplement it + + >>> d = OrderedDict(((1, 3), (3, 2), (2, 1))) + >>> d.pop(3) + 2 + >>> d + OrderedDict([(1, 3), (2, 1)]) + >>> d.pop(4) + Traceback (most recent call last): + KeyError: 4 + >>> d.pop(4, 0) + 0 + >>> d.pop(4, 0, 1) + Traceback (most recent call last): + TypeError: pop expected at most 2 arguments, got 3 + """ + if len(args) > 1: + raise TypeError, ('pop expected at most 2 arguments, got %s' % + (len(args) + 1)) + if key in self: + val = self[key] + del self[key] + else: + try: + val = args[0] + except IndexError: + raise KeyError(key) + return val + + def popitem(self, i=-1): + """ + Delete and return an item specified by index, not a random one as in + dict. The index is -1 by default (the last item). + + >>> d = OrderedDict(((1, 3), (3, 2), (2, 1))) + >>> d.popitem() + (2, 1) + >>> d + OrderedDict([(1, 3), (3, 2)]) + >>> d.popitem(0) + (1, 3) + >>> OrderedDict().popitem() + Traceback (most recent call last): + KeyError: 'popitem(): dictionary is empty' + >>> d.popitem(2) + Traceback (most recent call last): + IndexError: popitem(): index 2 not valid + """ + if not self._sequence: + raise KeyError('popitem(): dictionary is empty') + try: + key = self._sequence[i] + except IndexError: + raise IndexError('popitem(): index %s not valid' % i) + return (key, self.pop(key)) + + def setdefault(self, key, defval = None): + """ + >>> d = OrderedDict(((1, 3), (3, 2), (2, 1))) + >>> d.setdefault(1) + 3 + >>> d.setdefault(4) is None + True + >>> d + OrderedDict([(1, 3), (3, 2), (2, 1), (4, None)]) + >>> d.setdefault(5, 0) + 0 + >>> d + OrderedDict([(1, 3), (3, 2), (2, 1), (4, None), (5, 0)]) + """ + if key in self: + return self[key] + else: + self[key] = defval + return defval + + def update(self, from_od): + """ + Update from another OrderedDict or sequence of (key, value) pairs + + >>> d = OrderedDict(((1, 0), (0, 1))) + >>> d.update(OrderedDict(((1, 3), (3, 2), (2, 1)))) + >>> d + OrderedDict([(1, 3), (0, 1), (3, 2), (2, 1)]) + >>> d.update({4: 4}) + Traceback (most recent call last): + TypeError: undefined order, cannot get items from dict + >>> d.update((4, 4)) + Traceback (most recent call last): + TypeError: cannot convert dictionary update sequence element "4" to a 2-item sequence + """ + if isinstance(from_od, OrderedDict): + for key, val in from_od.items(): + self[key] = val + elif isinstance(from_od, dict): + # we lose compatibility with other ordered dict types this way + raise TypeError('undefined order, cannot get items from dict') + else: + # FIXME: efficiency? + # sequence of 2-item sequences, or error + for item in from_od: + try: + key, val = item + except TypeError: + raise TypeError('cannot convert dictionary update' + ' sequence element "%s" to a 2-item sequence' % item) + self[key] = val + + def rename(self, old_key, new_key): + """ + Rename the key for a given value, without modifying sequence order. + + For the case where new_key already exists this raise an exception, + since if new_key exists, it is ambiguous as to what happens to the + associated values, and the position of new_key in the sequence. + + >>> od = OrderedDict() + >>> od['a'] = 1 + >>> od['b'] = 2 + >>> od.items() + [('a', 1), ('b', 2)] + >>> od.rename('b', 'c') + >>> od.items() + [('a', 1), ('c', 2)] + >>> od.rename('c', 'a') + Traceback (most recent call last): + ValueError: New key already exists: 'a' + >>> od.rename('d', 'b') + Traceback (most recent call last): + KeyError: 'd' + """ + if new_key == old_key: + # no-op + return + if new_key in self: + raise ValueError("New key already exists: %r" % new_key) + # rename sequence entry + value = self[old_key] + old_idx = self._sequence.index(old_key) + self._sequence[old_idx] = new_key + # rename internal dict entry + dict.__delitem__(self, old_key) + dict.__setitem__(self, new_key, value) + + def setitems(self, items): + """ + This method allows you to set the items in the dict. + + It takes a list of tuples - of the same sort returned by the ``items`` + method. + + >>> d = OrderedDict() + >>> d.setitems(((3, 1), (2, 3), (1, 2))) + >>> d + OrderedDict([(3, 1), (2, 3), (1, 2)]) + """ + self.clear() + # FIXME: this allows you to pass in an OrderedDict as well :-) + self.update(items) + + def setkeys(self, keys): + """ + ``setkeys`` all ows you to pass in a new list of keys which will + replace the current set. This must contain the same set of keys, but + need not be in the same order. + + If you pass in new keys that don't match, a ``KeyError`` will be + raised. + + >>> d = OrderedDict(((1, 3), (3, 2), (2, 1))) + >>> d.keys() + [1, 3, 2] + >>> d.setkeys((1, 2, 3)) + >>> d + OrderedDict([(1, 3), (2, 1), (3, 2)]) + >>> d.setkeys(['a', 'b', 'c']) + Traceback (most recent call last): + KeyError: 'Keylist is not the same as current keylist.' + """ + # FIXME: Efficiency? (use set for Python 2.4 :-) + # NOTE: list(keys) rather than keys[:] because keys[:] returns + # a tuple, if keys is a tuple. + kcopy = list(keys) + kcopy.sort() + self._sequence.sort() + if kcopy != self._sequence: + raise KeyError('Keylist is not the same as current keylist.') + # NOTE: This makes the _sequence attribute a new object, instead + # of changing it in place. + # FIXME: efficiency? + self._sequence = list(keys) + + def setvalues(self, values): + """ + You can pass in a list of values, which will replace the + current list. The value list must be the same len as the OrderedDict. + + (Or a ``ValueError`` is raised.) + + >>> d = OrderedDict(((1, 3), (3, 2), (2, 1))) + >>> d.setvalues((1, 2, 3)) + >>> d + OrderedDict([(1, 1), (3, 2), (2, 3)]) + >>> d.setvalues([6]) + Traceback (most recent call last): + ValueError: Value list is not the same length as the OrderedDict. + """ + if len(values) != len(self): + # FIXME: correct error to raise? + raise ValueError('Value list is not the same length as the ' + 'OrderedDict.') + self.update(zip(self, values)) + +### Sequence Methods ### + + def index(self, key): + """ + Return the position of the specified key in the OrderedDict. + + >>> d = OrderedDict(((1, 3), (3, 2), (2, 1))) + >>> d.index(3) + 1 + >>> d.index(4) + Traceback (most recent call last): + ValueError: list.index(x): x not in list + """ + return self._sequence.index(key) + + def insert(self, index, key, value): + """ + Takes ``index``, ``key``, and ``value`` as arguments. + + Sets ``key`` to ``value``, so that ``key`` is at position ``index`` in + the OrderedDict. + + >>> d = OrderedDict(((1, 3), (3, 2), (2, 1))) + >>> d.insert(0, 4, 0) + >>> d + OrderedDict([(4, 0), (1, 3), (3, 2), (2, 1)]) + >>> d.insert(0, 2, 1) + >>> d + OrderedDict([(2, 1), (4, 0), (1, 3), (3, 2)]) + >>> d.insert(8, 8, 1) + >>> d + OrderedDict([(2, 1), (4, 0), (1, 3), (3, 2), (8, 1)]) + """ + if key in self: + # FIXME: efficiency? + del self[key] + self._sequence.insert(index, key) + dict.__setitem__(self, key, value) + + def reverse(self): + """ + Reverse the order of the OrderedDict. + + >>> d = OrderedDict(((1, 3), (3, 2), (2, 1))) + >>> d.reverse() + >>> d + OrderedDict([(2, 1), (3, 2), (1, 3)]) + """ + self._sequence.reverse() + + def sort(self, *args, **kwargs): + """ + Sort the key order in the OrderedDict. + + This method takes the same arguments as the ``list.sort`` method on + your version of Python. + + >>> d = OrderedDict(((4, 1), (2, 2), (3, 3), (1, 4))) + >>> d.sort() + >>> d + OrderedDict([(1, 4), (2, 2), (3, 3), (4, 1)]) + """ + self._sequence.sort(*args, **kwargs) + +class Keys(object): + # FIXME: should this object be a subclass of list? + """ + Custom object for accessing the keys of an OrderedDict. + + Can be called like the normal ``OrderedDict.keys`` method, but also + supports indexing and sequence methods. + """ + + def __init__(self, main): + self._main = main + + def __call__(self): + """Pretend to be the keys method.""" + return self._main._keys() + + def __getitem__(self, index): + """Fetch the key at position i.""" + # NOTE: this automatically supports slicing :-) + return self._main._sequence[index] + + def __setitem__(self, index, name): + """ + You cannot assign to keys, but you can do slice assignment to re-order + them. + + You can only do slice assignment if the new set of keys is a reordering + of the original set. + """ + if isinstance(index, types.SliceType): + # FIXME: efficiency? + # check length is the same + indexes = range(len(self._main._sequence))[index] + if len(indexes) != len(name): + raise ValueError('attempt to assign sequence of size %s ' + 'to slice of size %s' % (len(name), len(indexes))) + # check they are the same keys + # FIXME: Use set + old_keys = self._main._sequence[index] + new_keys = list(name) + old_keys.sort() + new_keys.sort() + if old_keys != new_keys: + raise KeyError('Keylist is not the same as current keylist.') + orig_vals = [self._main[k] for k in name] + del self._main[index] + vals = zip(indexes, name, orig_vals) + vals.sort() + for i, k, v in vals: + if self._main.strict and k in self._main: + raise ValueError('slice assignment must be from ' + 'unique keys') + self._main.insert(i, k, v) + else: + raise ValueError('Cannot assign to keys') + + ### following methods pinched from UserList and adapted ### + def __repr__(self): return repr(self._main._sequence) + + # FIXME: do we need to check if we are comparing with another ``Keys`` + # object? (like the __cast method of UserList) + def __lt__(self, other): return self._main._sequence < other + def __le__(self, other): return self._main._sequence <= other + def __eq__(self, other): return self._main._sequence == other + def __ne__(self, other): return self._main._sequence != other + def __gt__(self, other): return self._main._sequence > other + def __ge__(self, other): return self._main._sequence >= other + # FIXME: do we need __cmp__ as well as rich comparisons? + def __cmp__(self, other): return cmp(self._main._sequence, other) + + def __contains__(self, item): return item in self._main._sequence + def __len__(self): return len(self._main._sequence) + def __iter__(self): return self._main.iterkeys() + def count(self, item): return self._main._sequence.count(item) + def index(self, item, *args): return self._main._sequence.index(item, *args) + def reverse(self): self._main._sequence.reverse() + def sort(self, *args, **kwds): self._main._sequence.sort(*args, **kwds) + def __mul__(self, n): return self._main._sequence*n + __rmul__ = __mul__ + def __add__(self, other): return self._main._sequence + other + def __radd__(self, other): return other + self._main._sequence + + ## following methods not implemented for keys ## + def __delitem__(self, i): raise TypeError('Can\'t delete items from keys') + def __iadd__(self, other): raise TypeError('Can\'t add in place to keys') + def __imul__(self, n): raise TypeError('Can\'t multiply keys in place') + def append(self, item): raise TypeError('Can\'t append items to keys') + def insert(self, i, item): raise TypeError('Can\'t insert items into keys') + def pop(self, i=-1): raise TypeError('Can\'t pop items from keys') + def remove(self, item): raise TypeError('Can\'t remove items from keys') + def extend(self, other): raise TypeError('Can\'t extend keys') + +class Items(object): + """ + Custom object for accessing the items of an OrderedDict. + + Can be called like the normal ``OrderedDict.items`` method, but also + supports indexing and sequence methods. + """ + + def __init__(self, main): + self._main = main + + def __call__(self): + """Pretend to be the items method.""" + return self._main._items() + + def __getitem__(self, index): + """Fetch the item at position i.""" + if isinstance(index, types.SliceType): + # fetching a slice returns an OrderedDict + return self._main[index].items() + key = self._main._sequence[index] + return (key, self._main[key]) + + def __setitem__(self, index, item): + """Set item at position i to item.""" + if isinstance(index, types.SliceType): + # NOTE: item must be an iterable (list of tuples) + self._main[index] = OrderedDict(item) + else: + # FIXME: Does this raise a sensible error? + orig = self._main.keys[index] + key, value = item + if self._main.strict and key in self and (key != orig): + raise ValueError('slice assignment must be from ' + 'unique keys') + # delete the current one + del self._main[self._main._sequence[index]] + self._main.insert(index, key, value) + + def __delitem__(self, i): + """Delete the item at position i.""" + key = self._main._sequence[i] + if isinstance(i, types.SliceType): + for k in key: + # FIXME: efficiency? + del self._main[k] + else: + del self._main[key] + + ### following methods pinched from UserList and adapted ### + def __repr__(self): return repr(self._main.items()) + + # FIXME: do we need to check if we are comparing with another ``Items`` + # object? (like the __cast method of UserList) + def __lt__(self, other): return self._main.items() < other + def __le__(self, other): return self._main.items() <= other + def __eq__(self, other): return self._main.items() == other + def __ne__(self, other): return self._main.items() != other + def __gt__(self, other): return self._main.items() > other + def __ge__(self, other): return self._main.items() >= other + def __cmp__(self, other): return cmp(self._main.items(), other) + + def __contains__(self, item): return item in self._main.items() + def __len__(self): return len(self._main._sequence) # easier :-) + def __iter__(self): return self._main.iteritems() + def count(self, item): return self._main.items().count(item) + def index(self, item, *args): return self._main.items().index(item, *args) + def reverse(self): self._main.reverse() + def sort(self, *args, **kwds): self._main.sort(*args, **kwds) + def __mul__(self, n): return self._main.items()*n + __rmul__ = __mul__ + def __add__(self, other): return self._main.items() + other + def __radd__(self, other): return other + self._main.items() + + def append(self, item): + """Add an item to the end.""" + # FIXME: this is only append if the key isn't already present + key, value = item + self._main[key] = value + + def insert(self, i, item): + key, value = item + self._main.insert(i, key, value) + + def pop(self, i=-1): + key = self._main._sequence[i] + return (key, self._main.pop(key)) + + def remove(self, item): + key, value = item + try: + assert value == self._main[key] + except (KeyError, AssertionError): + raise ValueError('ValueError: list.remove(x): x not in list') + else: + del self._main[key] + + def extend(self, other): + # FIXME: is only a true extend if none of the keys already present + for item in other: + key, value = item + self._main[key] = value + + def __iadd__(self, other): + self.extend(other) + + ## following methods not implemented for items ## + + def __imul__(self, n): raise TypeError('Can\'t multiply items in place') + +class Values(object): + """ + Custom object for accessing the values of an OrderedDict. + + Can be called like the normal ``OrderedDict.values`` method, but also + supports indexing and sequence methods. + """ + + def __init__(self, main): + self._main = main + + def __call__(self): + """Pretend to be the values method.""" + return self._main._values() + + def __getitem__(self, index): + """Fetch the value at position i.""" + if isinstance(index, types.SliceType): + return [self._main[key] for key in self._main._sequence[index]] + else: + return self._main[self._main._sequence[index]] + + def __setitem__(self, index, value): + """ + Set the value at position i to value. + + You can only do slice assignment to values if you supply a sequence of + equal length to the slice you are replacing. + """ + if isinstance(index, types.SliceType): + keys = self._main._sequence[index] + if len(keys) != len(value): + raise ValueError('attempt to assign sequence of size %s ' + 'to slice of size %s' % (len(name), len(keys))) + # FIXME: efficiency? Would be better to calculate the indexes + # directly from the slice object + # NOTE: the new keys can collide with existing keys (or even + # contain duplicates) - these will overwrite + for key, val in zip(keys, value): + self._main[key] = val + else: + self._main[self._main._sequence[index]] = value + + ### following methods pinched from UserList and adapted ### + def __repr__(self): return repr(self._main.values()) + + # FIXME: do we need to check if we are comparing with another ``Values`` + # object? (like the __cast method of UserList) + def __lt__(self, other): return self._main.values() < other + def __le__(self, other): return self._main.values() <= other + def __eq__(self, other): return self._main.values() == other + def __ne__(self, other): return self._main.values() != other + def __gt__(self, other): return self._main.values() > other + def __ge__(self, other): return self._main.values() >= other + def __cmp__(self, other): return cmp(self._main.values(), other) + + def __contains__(self, item): return item in self._main.values() + def __len__(self): return len(self._main._sequence) # easier :-) + def __iter__(self): return self._main.itervalues() + def count(self, item): return self._main.values().count(item) + def index(self, item, *args): return self._main.values().index(item, *args) + + def reverse(self): + """Reverse the values""" + vals = self._main.values() + vals.reverse() + # FIXME: efficiency + self[:] = vals + + def sort(self, *args, **kwds): + """Sort the values.""" + vals = self._main.values() + vals.sort(*args, **kwds) + self[:] = vals + + def __mul__(self, n): return self._main.values()*n + __rmul__ = __mul__ + def __add__(self, other): return self._main.values() + other + def __radd__(self, other): return other + self._main.values() + + ## following methods not implemented for values ## + def __delitem__(self, i): raise TypeError('Can\'t delete items from values') + def __iadd__(self, other): raise TypeError('Can\'t add in place to values') + def __imul__(self, n): raise TypeError('Can\'t multiply values in place') + def append(self, item): raise TypeError('Can\'t append items to values') + def insert(self, i, item): raise TypeError('Can\'t insert items into values') + def pop(self, i=-1): raise TypeError('Can\'t pop items from values') + def remove(self, item): raise TypeError('Can\'t remove items from values') + def extend(self, other): raise TypeError('Can\'t extend values') + +class SequenceOrderedDict(OrderedDict): + """ + Experimental version of OrderedDict that has a custom object for ``keys``, + ``values``, and ``items``. + + These are callable sequence objects that work as methods, or can be + manipulated directly as sequences. + + Test for ``keys``, ``items`` and ``values``. + + >>> d = SequenceOrderedDict(((1, 2), (2, 3), (3, 4))) + >>> d + SequenceOrderedDict([(1, 2), (2, 3), (3, 4)]) + >>> d.keys + [1, 2, 3] + >>> d.keys() + [1, 2, 3] + >>> d.setkeys((3, 2, 1)) + >>> d + SequenceOrderedDict([(3, 4), (2, 3), (1, 2)]) + >>> d.setkeys((1, 2, 3)) + >>> d.keys[0] + 1 + >>> d.keys[:] + [1, 2, 3] + >>> d.keys[-1] + 3 + >>> d.keys[-2] + 2 + >>> d.keys[0:2] = [2, 1] + >>> d + SequenceOrderedDict([(2, 3), (1, 2), (3, 4)]) + >>> d.keys.reverse() + >>> d.keys + [3, 1, 2] + >>> d.keys = [1, 2, 3] + >>> d + SequenceOrderedDict([(1, 2), (2, 3), (3, 4)]) + >>> d.keys = [3, 1, 2] + >>> d + SequenceOrderedDict([(3, 4), (1, 2), (2, 3)]) + >>> a = SequenceOrderedDict() + >>> b = SequenceOrderedDict() + >>> a.keys == b.keys + 1 + >>> a['a'] = 3 + >>> a.keys == b.keys + 0 + >>> b['a'] = 3 + >>> a.keys == b.keys + 1 + >>> b['b'] = 3 + >>> a.keys == b.keys + 0 + >>> a.keys > b.keys + 0 + >>> a.keys < b.keys + 1 + >>> 'a' in a.keys + 1 + >>> len(b.keys) + 2 + >>> 'c' in d.keys + 0 + >>> 1 in d.keys + 1 + >>> [v for v in d.keys] + [3, 1, 2] + >>> d.keys.sort() + >>> d.keys + [1, 2, 3] + >>> d = SequenceOrderedDict(((1, 2), (2, 3), (3, 4)), strict=True) + >>> d.keys[::-1] = [1, 2, 3] + >>> d + SequenceOrderedDict([(3, 4), (2, 3), (1, 2)]) + >>> d.keys[:2] + [3, 2] + >>> d.keys[:2] = [1, 3] + Traceback (most recent call last): + KeyError: 'Keylist is not the same as current keylist.' + + >>> d = SequenceOrderedDict(((1, 2), (2, 3), (3, 4))) + >>> d + SequenceOrderedDict([(1, 2), (2, 3), (3, 4)]) + >>> d.values + [2, 3, 4] + >>> d.values() + [2, 3, 4] + >>> d.setvalues((4, 3, 2)) + >>> d + SequenceOrderedDict([(1, 4), (2, 3), (3, 2)]) + >>> d.values[::-1] + [2, 3, 4] + >>> d.values[0] + 4 + >>> d.values[-2] + 3 + >>> del d.values[0] + Traceback (most recent call last): + TypeError: Can't delete items from values + >>> d.values[::2] = [2, 4] + >>> d + SequenceOrderedDict([(1, 2), (2, 3), (3, 4)]) + >>> 7 in d.values + 0 + >>> len(d.values) + 3 + >>> [val for val in d.values] + [2, 3, 4] + >>> d.values[-1] = 2 + >>> d.values.count(2) + 2 + >>> d.values.index(2) + 0 + >>> d.values[-1] = 7 + >>> d.values + [2, 3, 7] + >>> d.values.reverse() + >>> d.values + [7, 3, 2] + >>> d.values.sort() + >>> d.values + [2, 3, 7] + >>> d.values.append('anything') + Traceback (most recent call last): + TypeError: Can't append items to values + >>> d.values = (1, 2, 3) + >>> d + SequenceOrderedDict([(1, 1), (2, 2), (3, 3)]) + + >>> d = SequenceOrderedDict(((1, 2), (2, 3), (3, 4))) + >>> d + SequenceOrderedDict([(1, 2), (2, 3), (3, 4)]) + >>> d.items() + [(1, 2), (2, 3), (3, 4)] + >>> d.setitems([(3, 4), (2 ,3), (1, 2)]) + >>> d + SequenceOrderedDict([(3, 4), (2, 3), (1, 2)]) + >>> d.items[0] + (3, 4) + >>> d.items[:-1] + [(3, 4), (2, 3)] + >>> d.items[1] = (6, 3) + >>> d.items + [(3, 4), (6, 3), (1, 2)] + >>> d.items[1:2] = [(9, 9)] + >>> d + SequenceOrderedDict([(3, 4), (9, 9), (1, 2)]) + >>> del d.items[1:2] + >>> d + SequenceOrderedDict([(3, 4), (1, 2)]) + >>> (3, 4) in d.items + 1 + >>> (4, 3) in d.items + 0 + >>> len(d.items) + 2 + >>> [v for v in d.items] + [(3, 4), (1, 2)] + >>> d.items.count((3, 4)) + 1 + >>> d.items.index((1, 2)) + 1 + >>> d.items.index((2, 1)) + Traceback (most recent call last): + ValueError: list.index(x): x not in list + >>> d.items.reverse() + >>> d.items + [(1, 2), (3, 4)] + >>> d.items.reverse() + >>> d.items.sort() + >>> d.items + [(1, 2), (3, 4)] + >>> d.items.append((5, 6)) + >>> d.items + [(1, 2), (3, 4), (5, 6)] + >>> d.items.insert(0, (0, 0)) + >>> d.items + [(0, 0), (1, 2), (3, 4), (5, 6)] + >>> d.items.insert(-1, (7, 8)) + >>> d.items + [(0, 0), (1, 2), (3, 4), (7, 8), (5, 6)] + >>> d.items.pop() + (5, 6) + >>> d.items + [(0, 0), (1, 2), (3, 4), (7, 8)] + >>> d.items.remove((1, 2)) + >>> d.items + [(0, 0), (3, 4), (7, 8)] + >>> d.items.extend([(1, 2), (5, 6)]) + >>> d.items + [(0, 0), (3, 4), (7, 8), (1, 2), (5, 6)] + """ + + def __init__(self, init_val=(), strict=True): + OrderedDict.__init__(self, init_val, strict=strict) + self._keys = self.keys + self._values = self.values + self._items = self.items + self.keys = Keys(self) + self.values = Values(self) + self.items = Items(self) + self._att_dict = { + 'keys': self.setkeys, + 'items': self.setitems, + 'values': self.setvalues, + } + + def __setattr__(self, name, value): + """Protect keys, items, and values.""" + if not '_att_dict' in self.__dict__: + object.__setattr__(self, name, value) + else: + try: + fun = self._att_dict[name] + except KeyError: + OrderedDict.__setattr__(self, name, value) + else: + fun(value) + +if __name__ == '__main__': + if INTP_VER < (2, 3): + raise RuntimeError("Tests require Python v.2.3 or later") + # turn off warnings for tests + warnings.filterwarnings('ignore') + # run the code tests in doctest format + import doctest + m = sys.modules.get('__main__') + globs = m.__dict__.copy() + globs.update({ + 'INTP_VER': INTP_VER, + }) + doctest.testmod(m, globs=globs) + diff --git a/lib/git/refs.py b/lib/git/refs.py index a4d7bbb1..5b94ea07 100644 --- a/lib/git/refs.py +++ b/lib/git/refs.py @@ -3,20 +3,23 @@ # # This module is part of GitPython and is released under # the BSD License: http://www.opensource.org/licenses/bsd-license.php -""" -Module containing all ref based objects -""" -from objects.base import Object +""" Module containing all ref based objects """ + +import os +from objects import Object, Commit from objects.utils import get_object_type_by_name from utils import LazyMixin, Iterable class Reference(LazyMixin, Iterable): """ - Represents a named reference to any object + Represents a named reference to any object. Subclasses may apply restrictions though, + i.e. Heads can only point to commits. """ __slots__ = ("repo", "path") + _common_path_default = "refs" + _id_attribute_ = "name" - def __init__(self, repo, path, object = None): + def __init__(self, repo, path): """ Initialize this instance ``repo`` @@ -26,13 +29,12 @@ class Reference(LazyMixin, Iterable): Path relative to the .git/ directory pointing to the ref in question, i.e. refs/heads/master - ``object`` - Object instance, will be retrieved on demand if None """ + if not path.startswith(self._common_path_default): + raise ValueError("Cannot instantiate %s from path %s" % ( self.__class__.__name__, path )) + self.repo = repo self.path = path - if object is not None: - self.object = object def __str__(self): return self.name @@ -63,8 +65,7 @@ class Reference(LazyMixin, Iterable): return '/'.join(tokens[2:]) - @property - def object(self): + def _get_object(self): """ Returns The object our ref currently refers to. Refs can be cached, they will @@ -72,10 +73,44 @@ class Reference(LazyMixin, Iterable): """ # have to be dynamic here as we may be a tag which can point to anything # Our path will be resolved to the hexsha which will be used accordingly - return Object(self.repo, self.path) + return Object.new(self.repo, self.path) + + def _set_object(self, ref): + """ + Set our reference to point to the given ref. It will be converted + to a specific hexsha. + + Note: + TypeChecking is done by the git command + """ + # do it safely by specifying the old value + self.repo.git.update_ref(self.path, ref, self._get_object().sha) + + object = property(_get_object, _set_object, doc="Return the object our ref currently refers to") + + def _set_commit(self, commit): + """ + Set ourselves to point to the given commit. + + Raise + ValueError if commit does not actually point to a commit + """ + self._set_object(commit) + + def _get_commit(self): + """ + Returns + Commit object the reference points to + """ + commit = self.object + if commit.type != "commit": + raise TypeError("Object of reference %s did not point to a commit, but to %r" % (self, commit)) + return commit + + commit = property(_get_commit, _set_commit, doc="Return Commit object the reference points to") @classmethod - def iter_items(cls, repo, common_path = "refs", **kwargs): + def iter_items(cls, repo, common_path = None, **kwargs): """ Find all refs in the repository @@ -84,59 +119,303 @@ class Reference(LazyMixin, Iterable): ``common_path`` Optional keyword argument to the path which is to be shared by all - returned Ref objects - - ``kwargs`` - Additional options given as keyword arguments, will be passed - to git-for-each-ref + returned Ref objects. + Defaults to class specific portion if None assuring that only + refs suitable for the actual class are returned. Returns - git.Ref[] + git.Reference[] - List is sorted by committerdate - The returned objects are compatible to the Ref base, but represent the - actual type, such as Head or Tag + List is lexigraphically sorted + The returned objects represent actual subclasses, such as Head or TagReference """ - - options = {'sort': "committerdate", - 'format': "%(refname)%00%(objectname)%00%(objecttype)%00%(objectsize)"} - - options.update(kwargs) - - output = repo.git.for_each_ref(common_path, **options) - return cls._iter_from_stream(repo, iter(output.splitlines())) - + if common_path is None: + common_path = cls._common_path_default + + rela_paths = set() + + # walk loose refs + # Currently we do not follow links + for root, dirs, files in os.walk(os.path.join(repo.path, common_path)): + for f in files: + abs_path = os.path.join(root, f) + rela_paths.add(abs_path.replace(repo.path + '/', "")) + # END for each file in root directory + # END for each directory to walk + + # read packed refs + packed_refs_path = os.path.join(repo.path, 'packed-refs') + if os.path.isfile(packed_refs_path): + fp = open(packed_refs_path, 'r') + try: + for line in fp.readlines(): + if line.startswith('#'): + continue + # 439689865b9c6e2a0dad61db22a0c9855bacf597 refs/heads/hello + line = line.rstrip() + first_space = line.find(' ') + if first_space == -1: + continue + + rela_path = line[first_space+1:] + if rela_path.startswith(common_path): + rela_paths.add(rela_path) + # END relative path matches common path + # END for each line in packed-refs + finally: + fp.close() + # END packed refs reading + + # return paths in sorted order + for path in sorted(rela_paths): + if path.endswith('/HEAD'): + continue + # END skip remote heads + yield cls.from_path(repo, path) + # END for each sorted relative refpath + + @classmethod - def _iter_from_stream(cls, repo, stream): - """ Parse out ref information into a list of Ref compatible objects - Returns git.Ref[] list of Ref objects """ - heads = [] - - for line in stream: - heads.append(cls._from_string(repo, line)) - - return heads + def from_path(cls, repo, path): + """ + Return + Instance of type Reference, Head, or Tag + depending on the given path + """ + if not path: + raise ValueError("Cannot create Reference from %r" % path) + + for ref_type in (Head, RemoteReference, TagReference, Reference): + try: + return ref_type(repo, path) + except ValueError: + pass + # END exception handling + # END for each type to try + raise ValueError("Could not find reference type suitable to handle path %r" % path) + +class SymbolicReference(object): + """ + Represents a special case of a reference such that this reference is symbolic. + It does not point to a specific commit, but to another Head, which itself + specifies a commit. + + A typical example for a symbolic reference is HEAD. + """ + __slots__ = ("repo", "name") + + def __init__(self, repo, name): + if '/' in name: + # NOTE: Actually they can be looking like ordinary refs. Theoretically we handle this + # case incorrectly + raise ValueError("SymbolicReferences are not located within a directory, got %s" % name) + # END error handling + self.repo = repo + self.name = name + + def __str__(self): + return self.name + + def __repr__(self): + return '<git.%s "%s">' % (self.__class__.__name__, self.name) + + def __eq__(self, other): + return self.name == other.name + + def __ne__(self, other): + return not ( self == other ) + + def __hash__(self): + return hash(self.name) + + def _get_path(self): + return os.path.join(self.repo.path, self.name) + + def _get_commit(self): + """ + Returns: + Commit object we point to, works for detached and non-detached + SymbolicReferences + """ + # we partially reimplement it to prevent unnecessary file access + fp = open(self._get_path(), 'r') + value = fp.read().rstrip() + fp.close() + tokens = value.split(" ") + + # it is a detached reference + if self.repo.re_hexsha_only.match(tokens[0]): + return Commit(self.repo, tokens[0]) + + # must be a head ! Git does not allow symbol refs to other things than heads + # Otherwise it would have detached it + if tokens[0] != "ref:": + raise ValueError("Failed to parse symbolic refernce: wanted 'ref: <hexsha>', got %r" % value) + return Head(self.repo, tokens[1]).commit + + def _set_commit(self, commit): + """ + Set our commit, possibly dereference our symbolic reference first. + """ + if self.is_detached: + return self._set_reference(commit) + + # set the commit on our reference + self._get_reference().commit = commit + + commit = property(_get_commit, _set_commit, doc="Query or set commits directly") + + def _get_reference(self): + """ + Returns + Reference Object we point to + """ + fp = open(self._get_path(), 'r') + try: + tokens = fp.readline().rstrip().split(' ') + if tokens[0] != 'ref:': + raise TypeError("%s is a detached symbolic reference as it points to %r" % (self, tokens[0])) + return Reference.from_path(self.repo, tokens[1]) + finally: + fp.close() + + def _set_reference(self, ref): + """ + Set ourselves to the given ref. It will stay a symbol if the ref is a Head. + Otherwise we try to get a commit from it using our interface. + + Strings are allowed but will be checked to be sure we have a commit + """ + write_value = None + if isinstance(ref, Head): + write_value = "ref: %s" % ref.path + elif isinstance(ref, Commit): + write_value = ref.sha + else: + try: + write_value = ref.commit.sha + except AttributeError: + sha = str(ref) + try: + obj = Object.new(self.repo, sha) + if obj.type != "commit": + raise TypeError("Invalid object type behind sha: %s" % sha) + write_value = obj.sha + except Exception: + raise ValueError("Could not extract object from %s" % ref) + # END end try string + # END try commit attribute + + # if we are writing a ref, use symbolic ref to get the reflog and more + # checking + # Otherwise we detach it and have to do it manually + if write_value.startswith('ref:'): + self.repo.git.symbolic_ref(self.name, write_value[5:]) + return + # END non-detached handling + + fp = open(self._get_path(), "w") + try: + fp.write(write_value) + finally: + fp.close() + # END writing + + reference = property(_get_reference, _set_reference, doc="Returns the Reference we point to") + + # alias + ref = reference + + @property + def is_detached(self): + """ + Returns + True if we are a detached reference, hence we point to a specific commit + instead to another reference + """ + try: + self.reference + return False + except TypeError: + return True + @classmethod - def _from_string(cls, repo, line): - """ Create a new Ref instance from the given string. - Format - name: [a-zA-Z_/]+ - <null byte> - id: [0-9A-Fa-f]{40} - Returns git.Head """ - full_path, hexsha, type_name, object_size = line.split("\x00") - - # No, we keep the object dynamic by allowing it to be retrieved by - # our path on demand - due to perstent commands it is fast. - # This reduces the risk that the object does not match - # the changed ref anymore in case it changes in the meanwhile - return cls(repo, full_path) - - # obj = get_object_type_by_name(type_name)(repo, hexsha) - # obj.size = object_size - # return cls(repo, full_path, obj) + def from_path(cls, repo, path): + """ + Return + Instance of SymbolicReference or HEAD + depending on the given path + """ + if not path: + raise ValueError("Cannot create Symbolic Reference from %r" % path) + + if path == 'HEAD': + return HEAD(repo, path) + + if '/' not in path: + return SymbolicReference(repo, path) + + raise ValueError("Could not find symbolic reference type suitable to handle path %r" % path) + +class HEAD(SymbolicReference): + """ + Special case of a Symbolic Reference as it represents the repository's + HEAD reference. + """ + _HEAD_NAME = 'HEAD' + __slots__ = tuple() + + def __init__(self, repo, name=_HEAD_NAME): + if name != self._HEAD_NAME: + raise ValueError("HEAD instance must point to %r, got %r" % (self._HEAD_NAME, name)) + super(HEAD, self).__init__(repo, name) + + + def reset(self, commit='HEAD', index=True, working_tree = False, + paths=None, **kwargs): + """ + Reset our HEAD to the given commit optionally synchronizing + the index and working tree. The reference we refer to will be set to + commit as well. + + ``commit`` + Commit object, Reference Object or string identifying a revision we + should reset HEAD to. + + ``index`` + If True, the index will be set to match the given commit. Otherwise + it will not be touched. + ``working_tree`` + If True, the working tree will be forcefully adjusted to match the given + commit, possibly overwriting uncommitted changes without warning. + If working_tree is True, index must be true as well + + ``paths`` + Single path or list of paths relative to the git root directory + that are to be reset. This allow to partially reset individual files. + + ``kwargs`` + Additional arguments passed to git-reset. + + Returns + self + """ + mode = "--soft" + if index: + mode = "--mixed" + + if working_tree: + mode = "--hard" + if not index: + raise ValueError( "Cannot reset the working tree if the index is not reset as well") + # END working tree handling + + self.repo.git.reset(mode, commit, paths, **kwargs) + + return self + class Head(Reference): """ @@ -154,49 +433,141 @@ class Head(Reference): >>> head.commit <git.Commit "1c09f116cbc2cb4100fb6935bb162daa4723f455"> - >>> head.commit.id + >>> head.commit.sha '1c09f116cbc2cb4100fb6935bb162daa4723f455' """ - - @property - def commit(self): + _common_path_default = "refs/heads" + + @classmethod + def create(cls, repo, path, commit='HEAD', force=False, **kwargs ): """ + Create a new head. + ``repo`` + Repository to create the head in + + ``path`` + The name or path of the head, i.e. 'new_branch' or + feature/feature1. The prefix refs/heads is implied. + + ``commit`` + Commit to which the new head should point, defaults to the + current HEAD + + ``force`` + if True, force creation even if branch with that name already exists. + + ``**kwargs`` + Additional keyword arguments to be passed to git-branch, i.e. + track, no-track, l + Returns - Commit object the head points to + Newly created Head + + Note + This does not alter the current HEAD, index or Working Tree """ - return self.object + if cls is not Head: + raise TypeError("Only Heads can be created explicitly, not objects of type %s" % cls.__name__) + + args = ( path, commit ) + if force: + kwargs['f'] = True + + repo.git.branch(*args, **kwargs) + return cls(repo, "%s/%s" % ( cls._common_path_default, path)) + @classmethod - def iter_items(cls, repo, common_path = "refs/heads", **kwargs): + def delete(cls, repo, *heads, **kwargs): + """ + Delete the given heads + + ``force`` + If True, the heads will be deleted even if they are not yet merged into + the main development stream. + Default False """ + force = kwargs.get("force", False) + flag = "-d" + if force: + flag = "-D" + repo.git.branch(flag, *heads) + + + def rename(self, new_path, force=False): + """ + Rename self to a new path + + ``new_path`` + Either a simple name or a path, i.e. new_name or features/new_name. + The prefix refs/heads is implied + + ``force`` + If True, the rename will succeed even if a head with the target name + already exists. + Returns - Iterator yielding Head items + self + """ + flag = "-m" + if force: + flag = "-M" - For more documentation, please refer to git.base.Ref.list_items + self.repo.git.branch(flag, self, new_path) + self.path = "%s/%s" % (self._common_path_default, new_path) + return self + + def checkout(self, force=False, **kwargs): """ - return super(Head,cls).iter_items(repo, common_path, **kwargs) - - def __repr__(self): - return '<git.Head "%s">' % self.name + Checkout this head by setting the HEAD to this reference, by updating the index + to reflect the tree we point to and by updating the working tree to reflect + the latest index. + + The command will fail if changed working tree files would be overwritten. + ``force`` + If True, changes to the index and the working tree will be discarded. + If False, GitCommandError will be raised in that situation. + + ``**kwargs`` + Additional keyword arguments to be passed to git checkout, i.e. + b='new_branch' to create a new branch at the given spot. + + Returns + The active branch after the checkout operation, usually self unless + a new branch has been created. + + Note + By default it is only allowed to checkout heads - everything else + will leave the HEAD detached which is allowed and possible, but remains + a special state that some tools might not be able to handle. + """ + args = list() + kwargs['f'] = force + if kwargs['f'] == False: + kwargs.pop('f') + + self.repo.git.checkout(self, **kwargs) + return self.repo.active_branch -class TagRef(Reference): +class TagReference(Reference): """ Class representing a lightweight tag reference which either points to a commit - or to a tag object. In the latter case additional information, like the signature - or the tag-creator, is available. + ,a tag object or any other object. In the latter case additional information, + like the signature or the tag-creator, is available. This tag object will always point to a commit object, but may carray additional information in a tag object:: - tagref = TagRef.list_items(repo)[0] + tagref = TagReference.list_items(repo)[0] print tagref.commit.message if tagref.tag is not None: print tagref.tag.message """ __slots__ = tuple() + _common_path_default = "refs/tags" @property def commit(self): @@ -222,17 +593,93 @@ class TagRef(Reference): if self.object.type == "tag": return self.object return None - + @classmethod - def iter_items(cls, repo, common_path = "refs/tags", **kwargs): + def create(cls, repo, path, ref='HEAD', message=None, force=False, **kwargs): """ - Returns - Iterator yielding commit items + Create a new tag reference. + + ``path`` + The name of the tag, i.e. 1.0 or releases/1.0. + The prefix refs/tags is implied + + ``ref`` + A reference to the object you want to tag. It can be a commit, tree or + blob. - For more documentation, please refer to git.base.Ref.list_items + ``message`` + If not None, the message will be used in your tag object. This will also + create an additional tag object that allows to obtain that information, i.e.:: + tagref.tag.message + + ``force`` + If True, to force creation of a tag even though that tag already exists. + + ``**kwargs`` + Additional keyword arguments to be passed to git-tag + + Returns + A new TagReference """ - return super(TagRef,cls).iter_items(repo, common_path, **kwargs) + args = ( path, ref ) + if message: + kwargs['m'] = message + if force: + kwargs['f'] = True + repo.git.tag(*args, **kwargs) + return TagReference(repo, "%s/%s" % (cls._common_path_default, path)) + + @classmethod + def delete(cls, repo, *tags): + """ + Delete the given existing tag or tags + """ + repo.git.tag("-d", *tags) + + + + # provide an alias -Tag = TagRef +Tag = TagReference + +class RemoteReference(Head): + """ + Represents a reference pointing to a remote head. + """ + _common_path_default = "refs/remotes" + + @property + def remote_name(self): + """ + Returns + Name of the remote we are a reference of, such as 'origin' for a reference + named 'origin/master' + """ + tokens = self.path.split('/') + # /refs/remotes/<remote name>/<branch_name> + return tokens[2] + + @property + def remote_head(self): + """ + Returns + Name of the remote head itself, i.e. master. + + NOTE: The returned name is usually not qualified enough to uniquely identify + a branch + """ + tokens = self.path.split('/') + return '/'.join(tokens[3:]) + + @classmethod + def delete(cls, repo, *remotes, **kwargs): + """ + Delete the given remote references. + + Note + kwargs are given for compatability with the base class method as we + should not narrow the signature. + """ + repo.git.branch("-d", "-r", *remotes) diff --git a/lib/git/remote.py b/lib/git/remote.py new file mode 100644 index 00000000..9141fc3b --- /dev/null +++ b/lib/git/remote.py @@ -0,0 +1,718 @@ +# remote.py +# Copyright (C) 2008, 2009 Michael Trier (mtrier@gmail.com) and contributors +# +# This module is part of GitPython and is released under +# the BSD License: http://www.opensource.org/licenses/bsd-license.php +"""Module implementing a remote object allowing easy access to git remotes""" + +from errors import GitCommandError +from git.utils import LazyMixin, Iterable, IterableList +from objects import Commit +from refs import Reference, RemoteReference, SymbolicReference, TagReference + +import re +import os + +class _SectionConstraint(object): + """ + Constrains a ConfigParser to only option commands which are constrained to + always use the section we have been initialized with. + + It supports all ConfigParser methods that operate on an option + """ + __slots__ = ("_config", "_section_name") + _valid_attrs_ = ("get", "set", "getint", "getfloat", "getboolean", "has_option") + + def __init__(self, config, section): + self._config = config + self._section_name = section + + def __getattr__(self, attr): + if attr in self._valid_attrs_: + return lambda *args: self._call_config(attr, *args) + return super(_SectionConstraint,self).__getattribute__(attr) + + def _call_config(self, method, *args): + """Call the configuration at the given method which must take a section name + as first argument""" + return getattr(self._config, method)(self._section_name, *args) + + +class PushProgress(object): + """ + Handler providing an interface to parse progress information emitted by git-push + and to dispatch callbacks allowing subclasses to react to the progress. + """ + BEGIN, END, COUNTING, COMPRESSING, WRITING = [ 1 << x for x in range(5) ] + STAGE_MASK = BEGIN|END + OP_MASK = COUNTING|COMPRESSING|WRITING + + __slots__ = ("_cur_line", "_seen_ops") + re_op_absolute = re.compile("([\w\s]+):\s+()(\d+)()(, done\.)?\s*") + re_op_relative = re.compile("([\w\s]+):\s+(\d+)% \((\d+)/(\d+)\)(,.* done\.)?$") + + def __init__(self): + self._seen_ops = list() + + def _parse_progress_line(self, line): + """ + Parse progress information from the given line as retrieved by git-push + """ + # handle + # Counting objects: 4, done. + # Compressing objects: 50% (1/2) \rCompressing objects: 100% (2/2) \rCompressing objects: 100% (2/2), done. + self._cur_line = line + sub_lines = line.split('\r') + for sline in sub_lines: + sline = sline.rstrip() + + cur_count, max_count = None, None + match = self.re_op_relative.match(sline) + if match is None: + match = self.re_op_absolute.match(sline) + + if not match: + self.line_dropped(sline) + continue + # END could not get match + + op_code = 0 + op_name, percent, cur_count, max_count, done = match.groups() + # get operation id + if op_name == "Counting objects": + op_code |= self.COUNTING + elif op_name == "Compressing objects": + op_code |= self.COMPRESSING + elif op_name == "Writing objects": + op_code |= self.WRITING + else: + raise ValueError("Operation name %r unknown" % op_name) + + # figure out stage + if op_code not in self._seen_ops: + self._seen_ops.append(op_code) + op_code |= self.BEGIN + # END begin opcode + + message = '' + if done is not None and 'done.' in done: + op_code |= self.END + message = done.replace( ", done.", "")[2:] + # END end flag handling + + self.update(op_code, cur_count, max_count, message) + + # END for each sub line + + def line_dropped(self, line): + """ + Called whenever a line could not be understood and was therefore dropped. + """ + pass + + def update(self, op_code, cur_count, max_count=None, message=''): + """ + Called whenever the progress changes + + ``op_code`` + Integer allowing to be compared against Operation IDs and stage IDs. + + Stage IDs are BEGIN and END. BEGIN will only be set once for each Operation + ID as well as END. It may be that BEGIN and END are set at once in case only + one progress message was emitted due to the speed of the operation. + Between BEGIN and END, none of these flags will be set + + Operation IDs are all held within the OP_MASK. Only one Operation ID will + be active per call. + + ``cur_count`` + Current absolute count of items + + ``max_count`` + The maximum count of items we expect. It may be None in case there is + no maximum number of items or if it is (yet) unknown. + + ``message`` + In case of the 'WRITING' operation, it contains the amount of bytes + transferred. It may possibly be used for other purposes as well. + + You may read the contents of the current line in self._cur_line + """ + pass + + +class PushInfo(object): + """ + Carries information about the result of a push operation of a single head:: + + info = remote.push()[0] + info.flags # bitflags providing more information about the result + info.local_ref # Reference pointing to the local reference that was pushed + # It is None if the ref was deleted. + info.remote_ref_string # path to the remote reference located on the remote side + info.remote_ref # Remote Reference on the local side corresponding to + # the remote_ref_string. It can be a TagReference as well. + info.old_commit # commit at which the remote_ref was standing before we pushed + # it to local_ref.commit. Will be None if an error was indicated + + """ + __slots__ = ('local_ref', 'remote_ref_string', 'flags', 'old_commit', '_remote') + + NEW_TAG, NEW_HEAD, NO_MATCH, REJECTED, REMOTE_REJECTED, REMOTE_FAILURE, DELETED, \ + FORCED_UPDATE, FAST_FORWARD, UP_TO_DATE, ERROR = [ 1 << x for x in range(11) ] + + _flag_map = { 'X' : NO_MATCH, '-' : DELETED, '*' : 0, + '+' : FORCED_UPDATE, ' ' : FAST_FORWARD, + '=' : UP_TO_DATE, '!' : ERROR } + + def __init__(self, flags, local_ref, remote_ref_string, remote, old_commit=None): + """ + Initialize a new instance + """ + self.flags = flags + self.local_ref = local_ref + self.remote_ref_string = remote_ref_string + self._remote = remote + self.old_commit = old_commit + + @property + def remote_ref(self): + """ + Returns + Remote Reference or TagReference in the local repository corresponding + to the remote_ref_string kept in this instance. + """ + # translate heads to a local remote, tags stay as they are + if self.remote_ref_string.startswith("refs/tags"): + return TagReference(self._remote.repo, self.remote_ref_string) + elif self.remote_ref_string.startswith("refs/heads"): + remote_ref = Reference(self._remote.repo, self.remote_ref_string) + return RemoteReference(self._remote.repo, "refs/remotes/%s/%s" % (str(self._remote), remote_ref.name)) + else: + raise ValueError("Could not handle remote ref: %r" % self.remote_ref_string) + # END + + @classmethod + def _from_line(cls, remote, line): + """ + Create a new PushInfo instance as parsed from line which is expected to be like + c refs/heads/master:refs/heads/master 05d2687..1d0568e + """ + control_character, from_to, summary = line.split('\t', 3) + flags = 0 + + # control character handling + try: + flags |= cls._flag_map[ control_character ] + except KeyError: + raise ValueError("Control Character %r unknown as parsed from line %r" % (control_character, line)) + # END handle control character + + # from_to handling + from_ref_string, to_ref_string = from_to.split(':') + if flags & cls.DELETED: + from_ref = None + else: + from_ref = Reference.from_path(remote.repo, from_ref_string) + + # commit handling, could be message or commit info + old_commit = None + if summary.startswith('['): + if "[rejected]" in summary: + flags |= cls.REJECTED + elif "[remote rejected]" in summary: + flags |= cls.REMOTE_REJECTED + elif "[remote failure]" in summary: + flags |= cls.REMOTE_FAILURE + elif "[no match]" in summary: + flags |= cls.ERROR + elif "[new tag]" in summary: + flags |= cls.NEW_TAG + elif "[new branch]" in summary: + flags |= cls.NEW_HEAD + # uptodate encoded in control character + else: + # fast-forward or forced update - was encoded in control character, + # but we parse the old and new commit + split_token = "..." + if control_character == " ": + split_token = ".." + old_sha, new_sha = summary.split(' ')[0].split(split_token) + old_commit = Commit(remote.repo, old_sha) + # END message handling + + return PushInfo(flags, from_ref, to_ref_string, remote, old_commit) + + +class FetchInfo(object): + """ + Carries information about the results of a fetch operation of a single head:: + + info = remote.fetch()[0] + info.ref # Symbolic Reference or RemoteReference to the changed + # remote head or FETCH_HEAD + info.flags # additional flags to be & with enumeration members, + # i.e. info.flags & info.REJECTED + # is 0 if ref is SymbolicReference + info.note # additional notes given by git-fetch intended for the user + info.commit_before_forced_update # if info.flags & info.FORCED_UPDATE, + # field is set to the previous location of ref, otherwise None + """ + __slots__ = ('ref','commit_before_forced_update', 'flags', 'note') + + NEW_TAG, NEW_HEAD, HEAD_UPTODATE, TAG_UPDATE, REJECTED, FORCED_UPDATE, \ + FAST_FORWARD, ERROR = [ 1 << x for x in range(8) ] + + # %c %-*s %-*s -> %s (%s) + re_fetch_result = re.compile("^\s*(.) (\[?[\w\s\.]+\]?)\s+(.+) -> ([/\w_\.-]+)( \(.*\)?$)?") + + _flag_map = { '!' : ERROR, '+' : FORCED_UPDATE, '-' : TAG_UPDATE, '*' : 0, + '=' : HEAD_UPTODATE, ' ' : FAST_FORWARD } + + def __init__(self, ref, flags, note = '', old_commit = None): + """ + Initialize a new instance + """ + self.ref = ref + self.flags = flags + self.note = note + self.commit_before_forced_update = old_commit + + def __str__(self): + return self.name + + @property + def name(self): + """ + Returns + Name of our remote ref + """ + return self.ref.name + + @property + def commit(self): + """ + Returns + Commit of our remote ref + """ + return self.ref.commit + + @classmethod + def _from_line(cls, repo, line, fetch_line): + """ + Parse information from the given line as returned by git-fetch -v + and return a new FetchInfo object representing this information. + + We can handle a line as follows + "%c %-*s %-*s -> %s%s" + + Where c is either ' ', !, +, -, *, or = + ! means error + + means success forcing update + - means a tag was updated + * means birth of new branch or tag + = means the head was up to date ( and not moved ) + ' ' means a fast-forward + + fetch line is the corresponding line from FETCH_HEAD, like + acb0fa8b94ef421ad60c8507b634759a472cd56c not-for-merge branch '0.1.7RC' of /tmp/tmpya0vairemote_repo + """ + match = cls.re_fetch_result.match(line) + if match is None: + raise ValueError("Failed to parse line: %r" % line) + + # parse lines + control_character, operation, local_remote_ref, remote_local_ref, note = match.groups() + try: + new_hex_sha, fetch_operation, fetch_note = fetch_line.split("\t") + ref_type_name, fetch_note = fetch_note.split(' ', 1) + except ValueError: # unpack error + raise ValueError("Failed to parse FETCH__HEAD line: %r" % fetch_line) + + # handle FETCH_HEAD and figure out ref type + # If we do not specify a target branch like master:refs/remotes/origin/master, + # the fetch result is stored in FETCH_HEAD which destroys the rule we usually + # have. In that case we use a symbolic reference which is detached + ref_type = None + if remote_local_ref == "FETCH_HEAD": + ref_type = SymbolicReference + elif ref_type_name == "branch": + ref_type = RemoteReference + elif ref_type_name == "tag": + ref_type = TagReference + else: + raise TypeError("Cannot handle reference type: %r" % ref_type_name) + + # create ref instance + if ref_type is SymbolicReference: + remote_local_ref = ref_type(repo, "FETCH_HEAD") + else: + remote_local_ref = Reference.from_path(repo, os.path.join(ref_type._common_path_default, remote_local_ref.strip())) + # END create ref instance + + note = ( note and note.strip() ) or '' + + # parse flags from control_character + flags = 0 + try: + flags |= cls._flag_map[control_character] + except KeyError: + raise ValueError("Control character %r unknown as parsed from line %r" % (control_character, line)) + # END control char exception hanlding + + # parse operation string for more info - makes no sense for symbolic refs + old_commit = None + if isinstance(remote_local_ref, Reference): + if 'rejected' in operation: + flags |= cls.REJECTED + if 'new tag' in operation: + flags |= cls.NEW_TAG + if 'new branch' in operation: + flags |= cls.NEW_HEAD + if '...' in operation: + old_commit = Commit(repo, operation.split('...')[0]) + # END handle refspec + # END reference flag handling + + return cls(remote_local_ref, flags, note, old_commit) + + +class Remote(LazyMixin, Iterable): + """ + Provides easy read and write access to a git remote. + + Everything not part of this interface is considered an option for the current + remote, allowing constructs like remote.pushurl to query the pushurl. + + NOTE: When querying configuration, the configuration accessor will be cached + to speed up subsequent accesses. + """ + + __slots__ = ( "repo", "name", "_config_reader" ) + _id_attribute_ = "name" + + def __init__(self, repo, name): + """ + Initialize a remote instance + + ``repo`` + The repository we are a remote of + + ``name`` + the name of the remote, i.e. 'origin' + """ + self.repo = repo + self.name = name + + def __getattr__(self, attr): + """ + Allows to call this instance like + remote.special( *args, **kwargs) to call git-remote special self.name + """ + if attr == "_config_reader": + return super(Remote, self).__getattr__(attr) + + return self._config_reader.get(attr) + + def _config_section_name(self): + return 'remote "%s"' % self.name + + def _set_cache_(self, attr): + if attr == "_config_reader": + self._config_reader = _SectionConstraint(self.repo.config_reader(), self._config_section_name()) + else: + super(Remote, self)._set_cache_(attr) + + + def __str__(self): + return self.name + + def __repr__(self): + return '<git.%s "%s">' % (self.__class__.__name__, self.name) + + def __eq__(self, other): + return self.name == other.name + + def __ne__(self, other): + return not ( self == other ) + + def __hash__(self): + return hash(self.name) + + @classmethod + def iter_items(cls, repo): + """ + Returns + Iterator yielding Remote objects of the given repository + """ + for section in repo.config_reader("repository").sections(): + if not section.startswith('remote'): + continue + lbound = section.find('"') + rbound = section.rfind('"') + if lbound == -1 or rbound == -1: + raise ValueError("Remote-Section has invalid format: %r" % section) + yield Remote(repo, section[lbound+1:rbound]) + # END for each configuration section + + @property + def refs(self): + """ + Returns + IterableList of RemoteReference objects. It is prefixed, allowing + you to omit the remote path portion, i.e.:: + remote.refs.master # yields RemoteReference('/refs/remotes/origin/master') + """ + out_refs = IterableList(RemoteReference._id_attribute_, "%s/" % self.name) + for ref in RemoteReference.list_items(self.repo): + if ref.remote_name == self.name: + out_refs.append(ref) + # END if names match + # END for each ref + assert out_refs, "Remote %s did not have any references" % self.name + return out_refs + + @property + def stale_refs(self): + """ + Returns + IterableList RemoteReference objects that do not have a corresponding + head in the remote reference anymore as they have been deleted on the + remote side, but are still available locally. + + The IterableList is prefixed, hence the 'origin' must be omitted. See + 'refs' property for an example. + """ + out_refs = IterableList(RemoteReference._id_attribute_, "%s/" % self.name) + for line in self.repo.git.remote("prune", "--dry-run", self).splitlines()[2:]: + # expecting + # * [would prune] origin/new_branch + token = " * [would prune] " + if not line.startswith(token): + raise ValueError("Could not parse git-remote prune result: %r" % line) + fqhn = "%s/%s" % (RemoteReference._common_path_default,line.replace(token, "")) + out_refs.append(RemoteReference(self.repo, fqhn)) + # END for each line + return out_refs + + @classmethod + def create(cls, repo, name, url, **kwargs): + """ + Create a new remote to the given repository + ``repo`` + Repository instance that is to receive the new remote + + ``name`` + Desired name of the remote + + ``url`` + URL which corresponds to the remote's name + + ``**kwargs`` + Additional arguments to be passed to the git-remote add command + + Returns + New Remote instance + + Raise + GitCommandError in case an origin with that name already exists + """ + repo.git.remote( "add", name, url, **kwargs ) + return cls(repo, name) + + # add is an alias + add = create + + @classmethod + def remove(cls, repo, name ): + """ + Remove the remote with the given name + """ + repo.git.remote("rm", name) + + # alias + rm = remove + + def rename(self, new_name): + """ + Rename self to the given new_name + + Returns + self + """ + if self.name == new_name: + return self + + self.repo.git.remote("rename", self.name, new_name) + self.name = new_name + del(self._config_reader) # it contains cached values, section names are different now + return self + + def update(self, **kwargs): + """ + Fetch all changes for this remote, including new branches which will + be forced in ( in case your local remote branch is not part the new remote branches + ancestry anymore ). + + ``kwargs`` + Additional arguments passed to git-remote update + + Returns + self + """ + self.repo.git.remote("update", self.name) + return self + + def _get_fetch_info_from_stderr(self, stderr): + # skip first line as it is some remote info we are not interested in + output = IterableList('name') + err_info = stderr.splitlines()[1:] + + # read head information + fp = open(os.path.join(self.repo.path, 'FETCH_HEAD'),'r') + fetch_head_info = fp.readlines() + fp.close() + + output.extend(FetchInfo._from_line(self.repo, err_line, fetch_line) + for err_line,fetch_line in zip(err_info, fetch_head_info)) + return output + + def _get_push_info(self, proc, progress): + # read progress information from stderr + # we hope stdout can hold all the data, it should ... + # read the lines manually as it will use carriage returns between the messages + # to override the previous one. This is why we read the bytes manually + line_so_far = '' + while True: + char = proc.stderr.read(1) + if not char: + break + + if char in ('\r', '\n'): + progress._parse_progress_line(line_so_far) + line_so_far = '' + else: + line_so_far += char + # END process parsed line + # END for each progress line + + output = IterableList('name') + for line in proc.stdout.readlines(): + try: + output.append(PushInfo._from_line(self, line)) + except ValueError: + # if an error happens, additional info is given which we cannot parse + pass + # END exception handling + # END for each line + try: + proc.wait() + except GitCommandError: + # if a push has rejected items, the command has non-zero return status + pass + # END exception handling + return output + + + def fetch(self, refspec=None, **kwargs): + """ + Fetch the latest changes for this remote + + ``refspec`` + A "refspec" is used by fetch and push to describe the mapping + between remote ref and local ref. They are combined with a colon in + the format <src>:<dst>, preceded by an optional plus sign, +. + For example: git fetch $URL refs/heads/master:refs/heads/origin means + "grab the master branch head from the $URL and store it as my origin + branch head". And git push $URL refs/heads/master:refs/heads/to-upstream + means "publish my master branch head as to-upstream branch at $URL". + See also git-push(1). + + Taken from the git manual + + ``**kwargs`` + Additional arguments to be passed to git-fetch + + Returns + IterableList(FetchInfo, ...) list of FetchInfo instances providing detailed + information about the fetch results + + Note + As fetch does not provide progress information to non-ttys, we cannot make + it available here unfortunately as in the 'push' method. + """ + status, stdout, stderr = self.repo.git.fetch(self, refspec, with_extended_output=True, v=True, **kwargs) + return self._get_fetch_info_from_stderr(stderr) + + def pull(self, refspec=None, **kwargs): + """ + Pull changes from the given branch, being the same as a fetch followed + by a merge of branch with your local branch. + + ``refspec`` + see 'fetch' method + + ``**kwargs`` + Additional arguments to be passed to git-pull + + Returns + Please see 'fetch' method + """ + status, stdout, stderr = self.repo.git.pull(self, refspec, with_extended_output=True, v=True, **kwargs) + return self._get_fetch_info_from_stderr(stderr) + + def push(self, refspec=None, progress=None, **kwargs): + """ + Push changes from source branch in refspec to target branch in refspec. + + ``refspec`` + see 'fetch' method + + ``progress`` + Instance of type PushProgress allowing the caller to receive + progress information until the method returns. + If None, progress information will be discarded + + ``**kwargs`` + Additional arguments to be passed to git-push + + Returns + IterableList(PushInfo, ...) iterable list of PushInfo instances, each + one informing about an individual head which had been updated on the remote + side. + If the push contains rejected heads, these will have the PushInfo.ERROR bit set + in their flags. + If the operation fails completely, the length of the returned IterableList will + be null. + """ + proc = self.repo.git.push(self, refspec, porcelain=True, as_process=True, **kwargs) + return self._get_push_info(proc, progress or PushProgress()) + + @property + def config_reader(self): + """ + Returns + GitConfigParser compatible object able to read options for only our remote. + Hence you may simple type config.get("pushurl") to obtain the information + """ + return self._config_reader + + @property + def config_writer(self): + """ + Return + GitConfigParser compatible object able to write options for this remote. + + Note + You can only own one writer at a time - delete it to release the + configuration file and make it useable by others. + + To assure consistent results, you should only query options through the + writer. Once you are done writing, you are free to use the config reader + once again. + """ + writer = self.repo.config_writer() + + # clear our cache to assure we re-read the possibly changed configuration + del(self._config_reader) + return _SectionConstraint(writer, self._config_section_name()) diff --git a/lib/git/repo.py b/lib/git/repo.py index 554c10cb..6d388633 100644 --- a/lib/git/repo.py +++ b/lib/git/repo.py @@ -5,17 +5,36 @@ # the BSD License: http://www.opensource.org/licenses/bsd-license.php import os +import sys import re import gzip import StringIO from errors import InvalidGitRepositoryError, NoSuchPathError -from utils import touch, is_git_dir from cmd import Git from actor import Actor from refs import * +from index import IndexFile from objects import * - +from config import GitConfigParser +from remote import Remote + +def touch(filename): + fp = open(filename, "a") + fp.close() + +def is_git_dir(d): + """ This is taken from the git setup.c:is_git_directory + function.""" + + if os.path.isdir(d) and \ + os.path.isdir(os.path.join(d, 'objects')) and \ + os.path.isdir(os.path.join(d, 'refs')): + headref = os.path.join(d, 'HEAD') + return os.path.isfile(headref) or \ + (os.path.islink(headref) and + os.readlink(headref).startswith('refs')) + return False class Repo(object): """ @@ -24,12 +43,17 @@ class Repo(object): the log. """ DAEMON_EXPORT_FILE = 'git-daemon-export-ok' + __slots__ = ( "wd", "path", "_bare", "git" ) # precompiled regex re_whitespace = re.compile(r'\s+') re_hexsha_only = re.compile('^[0-9A-Fa-f]{40}$') re_author_committer_start = re.compile(r'^(author|committer)') re_tab_full_line = re.compile(r'^\t(.*)$') + + # invariants + # represents the configuration level of a configuration file + config_level = ("system", "global", "repository") def __init__(self, path=None): """ @@ -57,22 +81,30 @@ class Repo(object): self.path = None curpath = epath + + # walk up the path to find the .git dir while curpath: if is_git_dir(curpath): - self.bare = True self.path = curpath self.wd = curpath break gitpath = os.path.join(curpath, '.git') if is_git_dir(gitpath): - self.bare = False self.path = gitpath self.wd = curpath break curpath, dummy = os.path.split(curpath) if not dummy: break - + # END while curpath + + self._bare = False + try: + self._bare = self.config_reader("repository").getboolean('core','bare') + except Exception: + # lets not assume the option exists, although it should + pass + if self.path is None: raise InvalidGitRepositoryError(epath) @@ -91,6 +123,15 @@ class Repo(object): doc="the project's description") del _get_description del _set_description + + + @property + def bare(self): + """ + Returns + True if the repository is bare + """ + return self._bare @property def heads(self): @@ -99,12 +140,58 @@ class Repo(object): this repo Returns - ``git.Head[]`` + ``git.IterableList(Head, ...)`` """ return Head.list_items(self) + + @property + def refs(self): + """ + A list of Reference objects representing tags, heads and remote references. + + Returns + IterableList(Reference, ...) + """ + return Reference.list_items(self) # alias heads branches = heads + + @property + def index(self): + """ + Returns + IndexFile representing this repository's index. + """ + return IndexFile(self) + + @property + def head(self): + """ + Return + HEAD Object pointing to the current head reference + """ + return HEAD(self,'HEAD') + + @property + def remotes(self): + """ + A list of Remote objects allowing to access and manipulate remotes + + Returns + ``git.IterableList(Remote, ...)`` + """ + return Remote.list_items(self) + + def remote(self, name='origin'): + """ + Return + Remote with the specified name + + Raise + ValueError if no remote with such a name exists + """ + return Remote(self, name) @property def tags(self): @@ -112,9 +199,129 @@ class Repo(object): A list of ``Tag`` objects that are available in this repo Returns - ``git.Tag[]`` + ``git.IterableList(TagReference, ...)`` + """ + return TagReference.list_items(self) + + def tag(self,path): + """ + Return + TagReference Object, reference pointing to a Commit or Tag + + ``path`` + path to the tag reference, i.e. 0.1.5 or tags/0.1.5 + """ + return TagReference(self, path) + + def create_head(self, path, commit='HEAD', force=False, **kwargs ): + """ + Create a new head within the repository. + + For more documentation, please see the Head.create method. + + Returns + newly created Head Reference + """ + return Head.create(self, path, commit, force, **kwargs) + + def delete_head(self, *heads, **kwargs): + """ + Delete the given heads + + ``kwargs`` + Additional keyword arguments to be passed to git-branch + """ + return Head.delete(self, *heads, **kwargs) + + def create_tag(self, path, ref='HEAD', message=None, force=False, **kwargs): + """ + Create a new tag reference. + + For more documentation, please see the TagReference.create method. + + Returns + TagReference object + """ + return TagReference.create(self, path, ref, message, force, **kwargs) + + def delete_tag(self, *tags): + """ + Delete the given tag references """ - return Tag.list_items(self) + return TagReference.delete(self, *tags) + + def create_remote(self, name, url, **kwargs): + """ + Create a new remote. + + For more information, please see the documentation of the Remote.create + methods + + Returns + Remote reference + """ + return Remote.create(self, name, url, **kwargs) + + def delete_remote(self, remote): + """ + Delete the given remote. + """ + return Remote.remove(self, remote) + + def _get_config_path(self, config_level ): + # we do not support an absolute path of the gitconfig on windows , + # use the global config instead + if sys.platform == "win32" and config_level == "system": + config_level = "global" + + if config_level == "system": + return "/etc/gitconfig" + elif config_level == "global": + return os.path.expanduser("~/.gitconfig") + elif config_level == "repository": + return "%s/config" % self.path + + raise ValueError( "Invalid configuration level: %r" % config_level ) + + def config_reader(self, config_level=None): + """ + Returns + GitConfigParser allowing to read the full git configuration, but not to write it + + The configuration will include values from the system, user and repository + configuration files. + + NOTE: On windows, system configuration cannot currently be read as the path is + unknown, instead the global path will be used. + + ``config_level`` + For possible values, see config_writer method + If None, all applicable levels will be used. Specify a level in case + you know which exact file you whish to read to prevent reading multiple files for + instance + """ + files = None + if config_level is None: + files = [ self._get_config_path(f) for f in self.config_level ] + else: + files = [ self._get_config_path(config_level) ] + return GitConfigParser(files, read_only=True) + + def config_writer(self, config_level="repository"): + """ + Returns + GitConfigParser allowing to write values of the specified configuration file level. + Config writers should be retrieved, used to change the configuration ,and written + right away as they will lock the configuration file in question and prevent other's + to write it. + + ``config_level`` + One of the following values + system = sytem wide configuration file + global = user level configuration file + repository = configuration file for this repostory only + """ + return GitConfigParser(self._get_config_path(config_level), read_only = False) def commit(self, rev=None): """ @@ -129,16 +336,25 @@ class Repo(object): if rev is None: rev = self.active_branch - c = Object(self, rev) + c = Object.new(self, rev) assert c.type == "commit", "Revision %s did not point to a commit, but to %s" % (rev, c) return c + + def iter_trees(self, *args, **kwargs): + """ + Returns + Iterator yielding Tree objects + + Note: Takes all arguments known to iter_commits method + """ + return ( c.tree for c in self.iter_commits(*args, **kwargs) ) - def tree(self, ref=None): + def tree(self, rev=None): """ - The Tree object for the given treeish reference + The Tree object for the given treeish revision - ``ref`` - is a Ref instance defaulting to the active_branch if None. + ``rev`` + is a revision pointing to a Treeish ( being a commit or tree ) Examples:: @@ -148,32 +364,19 @@ class Repo(object): ``git.Tree`` NOTE - A ref is requried here to assure you point to a commit or tag. Otherwise - it is not garantueed that you point to the root-level tree. - If you need a non-root level tree, find it by iterating the root tree. Otherwise it cannot know about its path relative to the repository root and subsequent operations might have unexpected results. """ - if ref is None: - ref = self.active_branch - if not isinstance(ref, Reference): - raise ValueError( "Reference required, got %r" % ref ) - - - # As we are directly reading object information, we must make sure - # we truly point to a tree object. We resolve the ref to a sha in all cases - # to assure the returned tree can be compared properly. Except for - # heads, ids should always be hexshas - hexsha, typename, size = self.git.get_object_header( ref ) - if typename != "tree": - # will raise if this is not a valid tree - hexsha, typename, size = self.git.get_object_header( str(ref)+'^{tree}' ) - # END tree handling - ref = hexsha + if rev is None: + rev = self.active_branch - # the root has an empty relative path and the default mode - return Tree(self, ref, 0, '') + c = Object.new(self, rev) + if c.type == "commit": + return c.tree + elif c.type == "tree": + return c + raise ValueError( "Revision %s did not point to a treeish, but to %s" % (rev, c)) def iter_commits(self, rev=None, paths='', **kwargs): """ @@ -249,42 +452,88 @@ class Repo(object): Raises NoSuchPathError + Note + The method does not check for the existance of the paths in alts + as the caller is responsible. + Returns None """ - for alt in alts: - if not os.path.exists(alt): - raise NoSuchPathError("Could not set alternates. Alternate path %s must exist" % alt) - + alternates_path = os.path.join(self.path, 'objects', 'info', 'alternates') if not alts: - os.remove(os.path.join(self.path, 'objects', 'info', 'alternates')) + if os.path.isfile(alternates_path): + os.remove(alternates_path) else: try: - f = open(os.path.join(self.path, 'objects', 'info', 'alternates'), 'w') + f = open(alternates_path, 'w') f.write("\n".join(alts)) finally: f.close() + # END file handling + # END alts handling alternates = property(_get_alternates, _set_alternates, doc="Retrieve a list of alternates paths or set a list paths to be used as alternates") - @property - def is_dirty(self): + def is_dirty(self, index=True, working_tree=True, untracked_files=False): """ - Return the status of the index. - Returns - ``True``, if the index has any uncommitted changes, - otherwise ``False`` - - NOTE - Working tree changes that have not been staged will not be detected ! + ``True``, the repository is considered dirty. By default it will react + like a git-status without untracked files, hence it is dirty if the + index or the working copy have changes. """ - if self.bare: + if self._bare: # Bare repositories with no associated working directory are # always consired to be clean. return False - - return len(self.git.diff('HEAD', '--').strip()) > 0 + + # start from the one which is fastest to evaluate + default_args = ('--abbrev=40', '--full-index', '--raw') + if index: + # diff index against HEAD + if len(self.git.diff('HEAD', '--cached', *default_args)): + return True + # END index handling + if working_tree: + # diff index against working tree + if len(self.git.diff(*default_args)): + return True + # END working tree handling + if untracked_files: + if len(self.untracked_files): + return True + # END untracked files + return False + + @property + def untracked_files(self): + """ + Returns + list(str,...) + + Files currently untracked as they have not been staged yet. Paths + are relative to the current working directory of the git command. + + Note + ignored files will not appear here, i.e. files mentioned in .gitignore + """ + # make sure we get all files, no only untracked directores + proc = self.git.status(untracked_files=True, as_process=True) + stream = iter(proc.stdout) + untracked_files = list() + for line in stream: + if not line.startswith("# Untracked files:"): + continue + # skip two lines + stream.next() + stream.next() + + for untracked_info in stream: + if not untracked_info.startswith("#\t"): + break + untracked_files.append(untracked_info.replace("#\t", "").rstrip()) + # END for each utracked info line + # END for each line + return untracked_files @property def active_branch(self): @@ -294,36 +543,8 @@ class Repo(object): Returns Head to the active branch """ - return Head( self, self.git.symbolic_ref('HEAD').strip() ) - - - def diff(self, a, b, *paths): - """ - The diff from commit ``a`` to commit ``b``, optionally restricted to the given file(s) - - ``a`` - is the base commit - ``b`` - is the other commit - - ``paths`` - is an optional list of file paths on which to restrict the diff + return self.head.reference - Returns - ``str`` - """ - return self.git.diff(a, b, '--', *paths) - - def commit_diff(self, commit): - """ - The commit diff for the given commit - ``commit`` is the commit name/id - - Returns - ``git.Diff[]`` - """ - return Commit.diff(self, commit) - def blame(self, rev, file): """ The blame information for the given file at the given revision. @@ -388,7 +609,7 @@ class Repo(object): sha = info['id'] c = commits.get(sha) if c is None: - c = Commit( self, id=sha, + c = Commit( self, sha, author=Actor._from_string(info['author'] + ' ' + info['author_email']), authored_date=info['author_date'], committer=Actor._from_string(info['committer'] + ' ' + info['committer_email']), @@ -445,7 +666,7 @@ class Repo(object): Create a clone from this repository. ``path`` - is the full path of the new repo (traditionally ends with /<name>.git) + is the full path of the new repo (traditionally ends with ./<name>.git). ``kwargs`` keyword arguments to be given to the git-clone command @@ -476,32 +697,23 @@ class Repo(object): Examples:: - >>> repo.archive(open("archive" + >>> repo.archive(open("archive")) <String containing tar.gz archive> - >>> repo.archive_tar_gz('a87ff14') - <String containing tar.gz archive for commit a87ff14> - - >>> repo.archive_tar_gz('master', 'myproject/') - <String containing tar.gz archive and prefixed with 'myproject/'> - Raise GitCommandError in case something went wrong + Returns + self """ if treeish is None: treeish = self.active_branch if prefix and 'prefix' not in kwargs: kwargs['prefix'] = prefix - kwargs['as_process'] = True kwargs['output_stream'] = ostream - proc = self.git.archive(treeish, **kwargs) - status = proc.wait() - if status != 0: - raise GitCommandError( "git-archive", status, proc.stderr.read() ) - - + self.git.archive(treeish, **kwargs) + return self def __repr__(self): return '<git.Repo "%s">' % self.path diff --git a/lib/git/utils.py b/lib/git/utils.py index f84c247d..48427ff2 100644 --- a/lib/git/utils.py +++ b/lib/git/utils.py @@ -5,27 +5,223 @@ # the BSD License: http://www.opensource.org/licenses/bsd-license.php import os +import sys +import tempfile -def dashify(string): - return string.replace('_', '-') +try: + import hashlib +except ImportError: + import sha -def touch(filename): - os.utime(filename) +def make_sha(source=''): + """ + A python2.4 workaround for the sha/hashlib module fiasco + + Note + From the dulwich project + """ + try: + return hashlib.sha1(source) + except NameError: + sha1 = sha.sha(source) + return sha1 + + +class SHA1Writer(object): + """ + Wrapper around a file-like object that remembers the SHA1 of + the data written to it. It will write a sha when the stream is closed + or if the asked for explicitly usign write_sha. + + Note: + Based on the dulwich project + """ + __slots__ = ("f", "sha1") + + def __init__(self, f): + self.f = f + self.sha1 = make_sha("") + + def write(self, data): + self.sha1.update(data) + self.f.write(data) + + def write_sha(self): + sha = self.sha1.digest() + self.f.write(sha) + return sha + + def close(self): + sha = self.write_sha() + self.f.close() + return sha + + def tell(self): + return self.f.tell() + + +class LockFile(object): + """ + Provides methods to obtain, check for, and release a file based lock which + should be used to handle concurrent access to the same file. + + As we are a utility class to be derived from, we only use protected methods. + + Locks will automatically be released on destruction + """ + __slots__ = ("_file_path", "_owns_lock") + + def __init__(self, file_path): + self._file_path = file_path + self._owns_lock = False + + def __del__(self): + self._release_lock() + + def _lock_file_path(self): + """ + Return + Path to lockfile + """ + return "%s.lock" % (self._file_path) + + def _has_lock(self): + """ + Return + True if we have a lock and if the lockfile still exists + + Raise + AssertionError if our lock-file does not exist + """ + if not self._owns_lock: + return False + + lock_file = self._lock_file_path() + try: + fp = open(lock_file, "r") + pid = int(fp.read()) + fp.close() + except IOError: + raise AssertionError("GitConfigParser has a lock but the lock file at %s could not be read" % lock_file) + + if pid != os.getpid(): + raise AssertionError("We claim to own the lock at %s, but it was not owned by our process: %i" % (lock_file, os.getpid())) + + return True + + def _obtain_lock_or_raise(self): + """ + Create a lock file as flag for other instances, mark our instance as lock-holder + + Raise + IOError if a lock was already present or a lock file could not be written + """ + if self._has_lock(): + return + + lock_file = self._lock_file_path() + if os.path.exists(lock_file): + raise IOError("Lock for file %r did already exist, delete %r in case the lock is illegal" % (self._file_path, lock_file)) + + fp = open(lock_file, "w") + fp.write(str(os.getpid())) + fp.close() + + self._owns_lock = True + + def _release_lock(self): + """ + Release our lock if we have one + """ + if not self._has_lock(): + return + + os.remove(self._lock_file_path()) + self._owns_lock = False -def is_git_dir(d): - """ This is taken from the git setup.c:is_git_directory - function.""" - if os.path.isdir(d) and \ - os.path.isdir(os.path.join(d, 'objects')) and \ - os.path.isdir(os.path.join(d, 'refs')): - headref = os.path.join(d, 'HEAD') - return os.path.isfile(headref) or \ - (os.path.islink(headref) and - os.readlink(headref).startswith('refs')) - return False +class ConcurrentWriteOperation(LockFile): + """ + This class facilitates a safe write operation to a file on disk such that we: + + - lock the original file + - write to a temporary file + - rename temporary file back to the original one on close + - unlock the original file + + This type handles error correctly in that it will assure a consistent state + on destruction + """ + __slots__ = "_temp_write_fp" + + def __init__(self, file_path): + """ + Initialize an instance able to write the given file_path + """ + super(ConcurrentWriteOperation, self).__init__(file_path) + self._temp_write_fp = None + + def __del__(self): + self._end_writing(successful=False) + + def _begin_writing(self): + """ + Begin writing our file, hence we get a lock and start writing + a temporary file in the same directory. + + Returns + File Object to write to. It is still maintained by this instance + and you do not need to manually close + """ + # already writing ? + if self._temp_write_fp is not None: + return self._temp_write_fp + + self._obtain_lock_or_raise() + dirname, basename = os.path.split(self._file_path) + self._temp_write_fp = open(tempfile.mktemp(basename, '', dirname), "w") + return self._temp_write_fp + + def _is_writing(self): + """ + Returns + True if we are currently writing a file + """ + return self._temp_write_fp is not None + def _end_writing(self, successful=True): + """ + Indicate you successfully finished writing the file to: + + - close the underlying stream + - rename the remporary file to the original one + - release our lock + """ + # did we start a write operation ? + if self._temp_write_fp is None: + return + + if not self._temp_write_fp.closed: + self._temp_write_fp.close() + + if successful: + # on windows, rename does not silently overwrite the existing one + if sys.platform == "win32": + if os.path.isfile(self._file_path): + os.remove(self._file_path) + # END remove if exists + # END win32 special handling + os.rename(self._temp_write_fp.name, self._file_path) + else: + # just delete the file so far, we failed + os.remove(self._temp_write_fp.name) + # END successful handling + # finally reset our handle + self._release_lock() + self._temp_write_fp = None + + class LazyMixin(object): """ Base class providing an interface to lazily retrieve attribute values upon @@ -56,12 +252,55 @@ class LazyMixin(object): pass +class IterableList(list): + """ + List of iterable objects allowing to query an object by id or by named index:: + + heads = repo.heads + heads.master + heads['master'] + heads[0] + + It requires an id_attribute name to be set which will be queried from its + contained items to have a means for comparison. + + A prefix can be specified which is to be used in case the id returned by the + items always contains a prefix that does not matter to the user, so it + can be left out. + """ + __slots__ = ('_id_attr', '_prefix') + + def __new__(cls, id_attr, prefix=''): + return super(IterableList,cls).__new__(cls) + + def __init__(self, id_attr, prefix=''): + self._id_attr = id_attr + self._prefix = prefix + + def __getattr__(self, attr): + attr = self._prefix + attr + for item in self: + if getattr(item, self._id_attr) == attr: + return item + # END for each item + return list.__getattribute__(self, attr) + + def __getitem__(self, index): + if isinstance(index, int): + return list.__getitem__(self,index) + + try: + return getattr(self, index) + except AttributeError: + raise IndexError( "No item found with id %r" % self._prefix + index ) + class Iterable(object): """ Defines an interface for iterable items which is to assure a uniform way to retrieve and iterate items within the git repository """ __slots__ = tuple() + _id_attribute_ = "attribute that most suitably identifies your instance" @classmethod def list_items(cls, repo, *args, **kwargs): @@ -75,7 +314,10 @@ class Iterable(object): Returns: list(Item,...) list of item instances """ - return list(cls.iter_items(repo, *args, **kwargs)) + #return list(cls.iter_items(repo, *args, **kwargs)) + out_list = IterableList( cls._id_attribute_ ) + out_list.extend(cls.iter_items(repo, *args, **kwargs)) + return out_list @classmethod |