diff options
-rw-r--r-- | lib/git/cmd.py | 5 | ||||
-rw-r--r-- | lib/git/objects/tree.py | 2 | ||||
-rw-r--r-- | lib/git/repo.py | 39 | ||||
-rw-r--r-- | test/git/test_repo.py | 8 | ||||
-rw-r--r-- | test/git/test_tree.py | 2 |
5 files changed, 26 insertions, 30 deletions
diff --git a/lib/git/cmd.py b/lib/git/cmd.py index 4b4b84af..fb6f2998 100644 --- a/lib/git/cmd.py +++ b/lib/git/cmd.py @@ -336,8 +336,9 @@ class Git(object): """ tokens = header_line.split() if len(tokens) != 3: - raise ValueError( "SHA named %s could not be resolved" % tokens[0] ) - + raise ValueError("SHA named %s could not be resolved" % tokens[0] ) + if len(tokens[0]) != 40: + raise ValueError("Failed to parse header: %r" % header_line) return (tokens[0], tokens[1], int(tokens[2])) def __prepare_ref(self, ref): diff --git a/lib/git/objects/tree.py b/lib/git/objects/tree.py index bcb805af..27bd84d0 100644 --- a/lib/git/objects/tree.py +++ b/lib/git/objects/tree.py @@ -226,7 +226,7 @@ class Tree(base.IndexObject, diff.Diffable): if isinstance(item, basestring): # compatability for obj in self._cache: - if obj.path == item: + if obj.name == item: return obj # END for each obj raise KeyError( "Blob or Tree named %s not found" % item ) diff --git a/lib/git/repo.py b/lib/git/repo.py index 41484aa0..6d388633 100644 --- a/lib/git/repo.py +++ b/lib/git/repo.py @@ -349,12 +349,12 @@ class Repo(object): """ return ( c.tree for c in self.iter_commits(*args, **kwargs) ) - def tree(self, ref=None): + def tree(self, rev=None): """ - The Tree object for the given treeish reference + The Tree object for the given treeish revision - ``ref`` - is a Ref instance defaulting to the active_branch if None. + ``rev`` + is a revision pointing to a Treeish ( being a commit or tree ) Examples:: @@ -364,32 +364,19 @@ class Repo(object): ``git.Tree`` NOTE - A ref is requried here to assure you point to a commit or tag. Otherwise - it is not garantueed that you point to the root-level tree. - If you need a non-root level tree, find it by iterating the root tree. Otherwise it cannot know about its path relative to the repository root and subsequent operations might have unexpected results. """ - if ref is None: - ref = self.active_branch - if not isinstance(ref, Reference): - raise ValueError( "Reference required, got %r" % ref ) - - - # As we are directly reading object information, we must make sure - # we truly point to a tree object. We resolve the ref to a sha in all cases - # to assure the returned tree can be compared properly. Except for - # heads, ids should always be hexshas - hexsha, typename, size = self.git.get_object_header( ref ) - if typename != "tree": - # will raise if this is not a valid tree - hexsha, typename, size = self.git.get_object_header( str(ref)+'^{tree}' ) - # END tree handling - ref = hexsha - - # the root has an empty relative path and the default mode - return Tree(self, ref, 0, '') + if rev is None: + rev = self.active_branch + + c = Object.new(self, rev) + if c.type == "commit": + return c.tree + elif c.type == "tree": + return c + raise ValueError( "Revision %s did not point to a treeish, but to %s" % (rev, c)) def iter_commits(self, rev=None, paths='', **kwargs): """ diff --git a/test/git/test_repo.py b/test/git/test_repo.py index df495e71..0b196a1f 100644 --- a/test/git/test_repo.py +++ b/test/git/test_repo.py @@ -38,6 +38,14 @@ class TestRepo(TestBase): assert isinstance(self.rorepo.heads.master, Head) assert isinstance(self.rorepo.heads['master'], Head) + + def test_tree_from_revision(self): + tree = self.rorepo.tree('0.1.6') + assert tree.type == "tree" + assert self.rorepo.tree(tree) == tree + + # try from invalid revision that does not exist + self.failUnlessRaises(ValueError, self.rorepo.tree, 'hello world') @patch_object(Git, '_call_process') def test_commits(self, git): diff --git a/test/git/test_tree.py b/test/git/test_tree.py index 7b66743f..e0c1f134 100644 --- a/test/git/test_tree.py +++ b/test/git/test_tree.py @@ -16,7 +16,7 @@ class TestTree(TestCase): def test_traverse(self): - root = self.repo.tree() + root = self.repo.tree('0.1.6') num_recursive = 0 all_items = list() for obj in root.traverse(): |