diff options
Diffstat (limited to 'lib/git')
-rw-r--r-- | lib/git/objects/tree.py | 118 |
1 files changed, 63 insertions, 55 deletions
diff --git a/lib/git/objects/tree.py b/lib/git/objects/tree.py index 3c860199..67aea1cb 100644 --- a/lib/git/objects/tree.py +++ b/lib/git/objects/tree.py @@ -13,12 +13,13 @@ import git.diff as diff import utils from git.utils import join_path +join = os.path.join def sha_to_hex(sha): """Takes a string and returns the hex of the sha within""" hexsha = binascii.hexlify(sha) return hexsha - + class Tree(base.IndexObject, diff.Diffable, utils.Traversable, utils.Serializable): """ @@ -47,6 +48,13 @@ class Tree(base.IndexObject, diff.Diffable, utils.Traversable, utils.Serializabl symlink_id = 012 tree_id = 004 + _map_id_to_type = { + commit_id : Submodule, + blob_id : Blob, + symlink_id : Blob + # tree id added once Tree is defined + } + def __init__(self, repo, sha, mode=0, path=None): super(Tree, self).__init__(repo, sha, mode, path) @@ -54,31 +62,32 @@ class Tree(base.IndexObject, diff.Diffable, utils.Traversable, utils.Serializabl @classmethod def _get_intermediate_items(cls, index_object): if index_object.type == "tree": - return index_object._cache + return tuple(index_object._iter_convert_to_object(index_object._cache)) return tuple() - def _set_cache_(self, attr): if attr == "_cache": # Set the data when we need it - self._cache = self._get_tree_cache() + self._cache = self._get_tree_cache(self.data) else: super(Tree, self)._set_cache_(attr) - def _get_tree_cache(self, data=None): + def _get_tree_cache(self, data): """ :return: list(object_instance, ...) - :param data: if not None, a byte string representing the tree data - If None, self.data will be used instead""" - out = list() - if data is None: - data = self.data - for obj in self._iter_from_data(data): - if obj is not None: - out.append(obj) - # END if object was handled - # END for each line from ls-tree - return out + :param data: data string containing our serialized information""" + return list(self._iter_from_data(data)) + def _iter_convert_to_object(self, iterable): + """Iterable yields tuples of (hexsha, mode, name), which will be converted + to the respective object representation""" + for hexsha, mode, name in iterable: + path = join(self.path, name) + type_id = mode >> 12 + try: + yield self._map_id_to_type[type_id](self.repo, hexsha, mode, path) + except KeyError: + raise TypeError( "Unknown type %i found in tree data for path '%s'" % (type_id, path)) + # END for each item def _iter_from_data(self, data): """ @@ -87,8 +96,7 @@ class Tree(base.IndexObject, diff.Diffable, utils.Traversable, utils.Serializabl Note: This method was inspired by the parse_tree method in dulwich. - Returns - list(IndexObject, ...) + :yield: Tuple(hexsha, mode, tree_relative_path) """ ord_zero = ord('0') len_data = len(data) @@ -105,7 +113,6 @@ class Tree(base.IndexObject, diff.Diffable, utils.Traversable, utils.Serializabl mode = (mode << 3) + (ord(data[i]) - ord_zero) i += 1 # END while reading mode - type_id = mode >> 12 # byte is space now, skip it i += 1 @@ -117,22 +124,13 @@ class Tree(base.IndexObject, diff.Diffable, utils.Traversable, utils.Serializabl i += 1 # END while not reached NULL name = data[ns:i] - path = join_path(self.path, name) # byte is NULL, get next 20 i += 1 sha = data[i:i+20] i = i + 20 - hexsha = sha_to_hex(sha) - if type_id == self.blob_id or type_id == self.symlink_id: - yield Blob(self.repo, hexsha, mode, path) - elif type_id == self.tree_id: - yield Tree(self.repo, hexsha, mode, path) - elif type_id == self.commit_id: - yield Submodule(self.repo, hexsha, mode, path) - else: - raise TypeError( "Unknown type found in tree data %i for path '%s'" % (type_id, path)) + yield (sha_to_hex(sha), mode, name) # END for each byte in data stream @@ -165,7 +163,7 @@ class Tree(base.IndexObject, diff.Diffable, utils.Traversable, utils.Serializabl else: # safety assertion - blobs are at the end of the path if i != len(tokens)-1: - raise KeyError(msg % file) + raise KeyError(msg % file) return item # END handle item type # END for each token of split path @@ -173,9 +171,9 @@ class Tree(base.IndexObject, diff.Diffable, utils.Traversable, utils.Serializabl raise KeyError(msg % file) return item else: - for obj in self._cache: - if obj.name == file: - return obj + for info in self._cache: + if info[2] == file: # [2] == name + return self._map_id_to_type[info[1] >> 12](self.repo, info[0], info[1], join(self.path, info[2])) # END for each obj raise KeyError( msg % file ) # END handle long paths @@ -210,18 +208,19 @@ class Tree(base.IndexObject, diff.Diffable, utils.Traversable, utils.Serializabl return super(Tree, self).traverse(predicate, prune, depth, branch_first, visit_once, ignore_self) # List protocol - def __getslice__(self,i,j): - return self._cache[i:j] + def __getslice__(self, i, j): + return list(self._iter_convert_to_object(self._cache[i:j])) def __iter__(self): - return iter(self._cache) + return self._iter_convert_to_object(self._cache) def __len__(self): return len(self._cache) - def __getitem__(self,item): + def __getitem__(self, item): if isinstance(item, int): - return self._cache[item] + info = self._cache[item] + return self._map_id_to_type[info[1] >> 12](self.repo, info[0], info[1], join(self.path, info[2])) if isinstance(item, basestring): # compatability @@ -231,19 +230,26 @@ class Tree(base.IndexObject, diff.Diffable, utils.Traversable, utils.Serializabl raise TypeError( "Invalid index type: %r" % item ) - def __contains__(self,item): + def __contains__(self, item): if isinstance(item, base.IndexObject): - return item in self._cache - + for info in self._cache: + if item.sha == info[0]: + return True + # END compare sha + # END for each entry + # END handle item is index object # compatability - for obj in self._cache: - if item == obj.path: + + # treat item as repo-relative path + path = self.path + for info in self._cache: + if item == join(path, info[2]): return True # END for each item return False def __reversed__(self): - return reversed(self._cache) + return reversed(self._iter_convert_to_object(self._cache)) def _serialize(self, stream, presort=False): """Serialize this tree into the stream. Please note that we will assume @@ -256,25 +262,27 @@ class Tree(base.IndexObject, diff.Diffable, utils.Traversable, utils.Serializabl bit_mask = 7 # 3 bits set hex_to_bin = binascii.a2b_hex - for item in self._cache: - mode = '' - mb = item.mode + for hexsha, mode, name in self._cache: + mode_str = '' for i in xrange(6): - mode = chr(((mb >> (i*3)) & bit_mask) + ord_zero) + mode + mode_str = chr(((mode >> (i*3)) & bit_mask) + ord_zero) + mode_str # END for each 8 octal value + # git slices away the first octal if its zero - if mode[0] == '0': - mode = mode[1:] + if mode_str[0] == '0': + mode_str = mode_str[1:] # END save a byte - # note: the cache currently contains repo-relative paths, not - # tree-relative ones. Maybe the cache should only contain - # actual tuples, which are converted to objects later - # TODO: do it so - stream.write("%s %s\0%s" % (mode, os.path.basename(item.path), hex_to_bin(item.sha))) + stream.write("%s %s\0%s" % (mode_str, name, hex_to_bin(hexsha))) # END for each item return self def _deserialize(self, stream): self._cache = self._get_tree_cache(stream.read()) return self + + +# END tree + +# finalize map definition +Tree._map_id_to_type[Tree.tree_id] = Tree |