summaryrefslogtreecommitdiff
path: root/git/objects/tree.py
diff options
context:
space:
mode:
Diffstat (limited to 'git/objects/tree.py')
-rw-r--r--git/objects/tree.py56
1 files changed, 28 insertions, 28 deletions
diff --git a/git/objects/tree.py b/git/objects/tree.py
index 29b2a684..ec7d8e88 100644
--- a/git/objects/tree.py
+++ b/git/objects/tree.py
@@ -20,20 +20,23 @@ from .fun import (
# typing -------------------------------------------------
-from typing import Iterable, Iterator, Tuple, Union, cast, TYPE_CHECKING
+from typing import Callable, Dict, Iterable, Iterator, List, Tuple, Type, Union, cast, TYPE_CHECKING
+
+from git.types import PathLike
if TYPE_CHECKING:
+ from git.repo import Repo
from io import BytesIO
#--------------------------------------------------------
-cmp = lambda a, b: (a > b) - (a < b)
+cmp: Callable[[int, int], int] = lambda a, b: (a > b) - (a < b)
__all__ = ("TreeModifier", "Tree")
-def git_cmp(t1, t2):
+def git_cmp(t1: 'Tree', t2: 'Tree') -> int:
a, b = t1[2], t2[2]
len_a, len_b = len(a), len(b)
min_len = min(len_a, len_b)
@@ -45,9 +48,9 @@ def git_cmp(t1, t2):
return len_a - len_b
-def merge_sort(a, cmp):
+def merge_sort(a: List[int], cmp: Callable[[int, int], int]) -> None:
if len(a) < 2:
- return
+ return None
mid = len(a) // 2
lefthalf = a[:mid]
@@ -182,29 +185,29 @@ class Tree(IndexObject, diff.Diffable, util.Traversable, util.Serializable):
symlink_id = 0o12
tree_id = 0o04
- _map_id_to_type = {
+ _map_id_to_type: Dict[int, Union[Type[Submodule], Type[Blob], Type['Tree']]] = {
commit_id: Submodule,
blob_id: Blob,
symlink_id: Blob
# tree id added once Tree is defined
}
- def __init__(self, repo, binsha, mode=tree_id << 12, path=None):
+ def __init__(self, repo: 'Repo', binsha: bytes, mode: int = tree_id << 12, path: Union[PathLike, None] = None):
super(Tree, self).__init__(repo, binsha, mode, path)
@classmethod
def _get_intermediate_items(cls, index_object: 'Tree', # type: ignore
- ) -> Tuple['Tree', ...]:
+ ) -> Union[Tuple['Tree', ...], Tuple[()]]:
if index_object.type == "tree":
index_object = cast('Tree', index_object)
return tuple(index_object._iter_convert_to_object(index_object._cache))
return ()
- def _set_cache_(self, attr):
+ def _set_cache_(self, attr: str) -> None:
if attr == "_cache":
# Set the data when we need it
ostream = self.repo.odb.stream(self.binsha)
- self._cache = tree_entries_from_data(ostream.read())
+ self._cache: List[Tuple[bytes, int, str]] = tree_entries_from_data(ostream.read())
else:
super(Tree, self)._set_cache_(attr)
# END handle attribute
@@ -221,7 +224,7 @@ class Tree(IndexObject, diff.Diffable, util.Traversable, util.Serializable):
raise TypeError("Unknown mode %o found in tree data for path '%s'" % (mode, path)) from e
# END for each item
- def join(self, file):
+ def join(self, file: str) -> Union[Blob, 'Tree', Submodule]:
"""Find the named object in this tree's contents
:return: ``git.Blob`` or ``git.Tree`` or ``git.Submodule``
@@ -254,26 +257,22 @@ class Tree(IndexObject, diff.Diffable, util.Traversable, util.Serializable):
raise KeyError(msg % file)
# END handle long paths
- def __div__(self, file):
- """For PY2 only"""
- return self.join(file)
-
- def __truediv__(self, file):
+ def __truediv__(self, file: str) -> Union['Tree', Blob, Submodule]:
"""For PY3 only"""
return self.join(file)
@property
- def trees(self):
+ def trees(self) -> List['Tree']:
""":return: list(Tree, ...) list of trees directly below this tree"""
return [i for i in self if i.type == "tree"]
@property
- def blobs(self):
+ def blobs(self) -> List['Blob']:
""":return: list(Blob, ...) list of blobs directly below this tree"""
return [i for i in self if i.type == "blob"]
@property
- def cache(self):
+ def cache(self) -> TreeModifier:
"""
:return: An object allowing to modify the internal cache. This can be used
to change the tree's contents. When done, make sure you call ``set_done``
@@ -289,16 +288,16 @@ class Tree(IndexObject, diff.Diffable, util.Traversable, util.Serializable):
return super(Tree, self).traverse(predicate, prune, depth, branch_first, visit_once, ignore_self)
# List protocol
- def __getslice__(self, i, j):
+ def __getslice__(self, i: int, j: int) -> List[Union[Blob, 'Tree', Submodule]]:
return list(self._iter_convert_to_object(self._cache[i:j]))
- def __iter__(self):
+ def __iter__(self) -> Iterator[Union[Blob, 'Tree', Submodule]]:
return self._iter_convert_to_object(self._cache)
- def __len__(self):
+ def __len__(self) -> int:
return len(self._cache)
- def __getitem__(self, item):
+ def __getitem__(self, item: Union[str, int, slice]) -> Union[Blob, 'Tree', Submodule]:
if isinstance(item, int):
info = self._cache[item]
return self._map_id_to_type[info[1] >> 12](self.repo, info[0], info[1], join_path(self.path, info[2]))
@@ -310,7 +309,7 @@ class Tree(IndexObject, diff.Diffable, util.Traversable, util.Serializable):
raise TypeError("Invalid index type: %r" % item)
- def __contains__(self, item):
+ def __contains__(self, item: Union[IndexObject, PathLike]) -> bool:
if isinstance(item, IndexObject):
for info in self._cache:
if item.binsha == info[0]:
@@ -321,10 +320,11 @@ class Tree(IndexObject, diff.Diffable, util.Traversable, util.Serializable):
# compatibility
# treat item as repo-relative path
- path = self.path
- for info in self._cache:
- if item == join_path(path, info[2]):
- return True
+ else:
+ path = self.path
+ for info in self._cache:
+ if item == join_path(path, info[2]):
+ return True
# END for each item
return False