summaryrefslogtreecommitdiff
path: root/git/objects/tree.py
diff options
context:
space:
mode:
authorSebastian Thiel <sebastian.thiel@icloud.com>2021-06-26 10:09:53 +0800
committerGitHub <noreply@github.com>2021-06-26 10:09:53 +0800
commit2d2ff037f9f7a9ae33e5f4f6bdb75b669a1af19a (patch)
tree5f4fd00ad13fa5455dc876ab9cb9cc4f9b66bdfc /git/objects/tree.py
parent703280b8c3df6f9b1a5cbe0997b717edbcaa8979 (diff)
parent5d7b8ba9f2e9298496232e4ae66bd904a1d71001 (diff)
downloadgitpython-2d2ff037f9f7a9ae33e5f4f6bdb75b669a1af19a.tar.gz
Merge pull request #1279 from Yobmod/main
Finish typing object, improve verious other types.
Diffstat (limited to 'git/objects/tree.py')
-rw-r--r--git/objects/tree.py91
1 files changed, 50 insertions, 41 deletions
diff --git a/git/objects/tree.py b/git/objects/tree.py
index 29b2a684..191fe27c 100644
--- a/git/objects/tree.py
+++ b/git/objects/tree.py
@@ -20,21 +20,27 @@ from .fun import (
# typing -------------------------------------------------
-from typing import Iterable, Iterator, Tuple, Union, cast, TYPE_CHECKING
+from typing import Callable, Dict, Generic, Iterable, Iterator, List, Tuple, Type, TypeVar, Union, cast, TYPE_CHECKING
+
+from git.types import PathLike
if TYPE_CHECKING:
+ from git.repo import Repo
from io import BytesIO
#--------------------------------------------------------
-cmp = lambda a, b: (a > b) - (a < b)
+cmp: Callable[[str, str], int] = lambda a, b: (a > b) - (a < b)
__all__ = ("TreeModifier", "Tree")
+T_Tree_cache = TypeVar('T_Tree_cache', bound=Union[Tuple[bytes, int, str]])
+
-def git_cmp(t1, t2):
+def git_cmp(t1: T_Tree_cache, t2: T_Tree_cache) -> int:
a, b = t1[2], t2[2]
+ assert isinstance(a, str) and isinstance(b, str) # Need as mypy 9.0 cannot unpack TypeVar properly
len_a, len_b = len(a), len(b)
min_len = min(len_a, len_b)
min_cmp = cmp(a[:min_len], b[:min_len])
@@ -45,9 +51,10 @@ def git_cmp(t1, t2):
return len_a - len_b
-def merge_sort(a, cmp):
+def merge_sort(a: List[T_Tree_cache],
+ cmp: Callable[[T_Tree_cache, T_Tree_cache], int]) -> None:
if len(a) < 2:
- return
+ return None
mid = len(a) // 2
lefthalf = a[:mid]
@@ -80,7 +87,7 @@ def merge_sort(a, cmp):
k = k + 1
-class TreeModifier(object):
+class TreeModifier(Generic[T_Tree_cache], object):
"""A utility class providing methods to alter the underlying cache in a list-like fashion.
@@ -88,10 +95,10 @@ class TreeModifier(object):
the cache of a tree, will be sorted. Assuring it will be in a serializable state"""
__slots__ = '_cache'
- def __init__(self, cache):
+ def __init__(self, cache: List[T_Tree_cache]) -> None:
self._cache = cache
- def _index_by_name(self, name):
+ def _index_by_name(self, name: str) -> int:
""":return: index of an item with name, or -1 if not found"""
for i, t in enumerate(self._cache):
if t[2] == name:
@@ -101,7 +108,7 @@ class TreeModifier(object):
return -1
#{ Interface
- def set_done(self):
+ def set_done(self) -> 'TreeModifier':
"""Call this method once you are done modifying the tree information.
It may be called several times, but be aware that each call will cause
a sort operation
@@ -111,7 +118,7 @@ class TreeModifier(object):
#} END interface
#{ Mutators
- def add(self, sha, mode, name, force=False):
+ def add(self, sha: bytes, mode: int, name: str, force: bool = False) -> 'TreeModifier':
"""Add the given item to the tree. If an item with the given name already
exists, nothing will be done, but a ValueError will be raised if the
sha and mode of the existing item do not match the one you add, unless
@@ -129,7 +136,9 @@ class TreeModifier(object):
sha = to_bin_sha(sha)
index = self._index_by_name(name)
- item = (sha, mode, name)
+
+ assert isinstance(sha, bytes) and isinstance(mode, int) and isinstance(name, str)
+ item = cast(T_Tree_cache, (sha, mode, name)) # use Typeguard from typing-extensions 3.10.0
if index == -1:
self._cache.append(item)
else:
@@ -144,14 +153,17 @@ class TreeModifier(object):
# END handle name exists
return self
- def add_unchecked(self, binsha, mode, name):
+ def add_unchecked(self, binsha: bytes, mode: int, name: str) -> None:
"""Add the given item to the tree, its correctness is assumed, which
puts the caller into responsibility to assure the input is correct.
For more information on the parameters, see ``add``
:param binsha: 20 byte binary sha"""
- self._cache.append((binsha, mode, name))
+ assert isinstance(binsha, bytes) and isinstance(mode, int) and isinstance(name, str)
+ tree_cache = cast(T_Tree_cache, (binsha, mode, name))
- def __delitem__(self, name):
+ self._cache.append(tree_cache)
+
+ def __delitem__(self, name: str) -> None:
"""Deletes an item with the given name if it exists"""
index = self._index_by_name(name)
if index > -1:
@@ -182,29 +194,29 @@ class Tree(IndexObject, diff.Diffable, util.Traversable, util.Serializable):
symlink_id = 0o12
tree_id = 0o04
- _map_id_to_type = {
+ _map_id_to_type: Dict[int, Union[Type[Submodule], Type[Blob], Type['Tree']]] = {
commit_id: Submodule,
blob_id: Blob,
symlink_id: Blob
# tree id added once Tree is defined
}
- def __init__(self, repo, binsha, mode=tree_id << 12, path=None):
+ def __init__(self, repo: 'Repo', binsha: bytes, mode: int = tree_id << 12, path: Union[PathLike, None] = None):
super(Tree, self).__init__(repo, binsha, mode, path)
- @classmethod
+ @ classmethod
def _get_intermediate_items(cls, index_object: 'Tree', # type: ignore
- ) -> Tuple['Tree', ...]:
+ ) -> Union[Tuple['Tree', ...], Tuple[()]]:
if index_object.type == "tree":
index_object = cast('Tree', index_object)
return tuple(index_object._iter_convert_to_object(index_object._cache))
return ()
- def _set_cache_(self, attr):
+ def _set_cache_(self, attr: str) -> None:
if attr == "_cache":
# Set the data when we need it
ostream = self.repo.odb.stream(self.binsha)
- self._cache = tree_entries_from_data(ostream.read())
+ self._cache: List[Tuple[bytes, int, str]] = tree_entries_from_data(ostream.read())
else:
super(Tree, self)._set_cache_(attr)
# END handle attribute
@@ -221,7 +233,7 @@ class Tree(IndexObject, diff.Diffable, util.Traversable, util.Serializable):
raise TypeError("Unknown mode %o found in tree data for path '%s'" % (mode, path)) from e
# END for each item
- def join(self, file):
+ def join(self, file: str) -> Union[Blob, 'Tree', Submodule]:
"""Find the named object in this tree's contents
:return: ``git.Blob`` or ``git.Tree`` or ``git.Submodule``
@@ -254,26 +266,22 @@ class Tree(IndexObject, diff.Diffable, util.Traversable, util.Serializable):
raise KeyError(msg % file)
# END handle long paths
- def __div__(self, file):
- """For PY2 only"""
- return self.join(file)
-
- def __truediv__(self, file):
+ def __truediv__(self, file: str) -> Union['Tree', Blob, Submodule]:
"""For PY3 only"""
return self.join(file)
- @property
- def trees(self):
+ @ property
+ def trees(self) -> List['Tree']:
""":return: list(Tree, ...) list of trees directly below this tree"""
return [i for i in self if i.type == "tree"]
- @property
- def blobs(self):
+ @ property
+ def blobs(self) -> List['Blob']:
""":return: list(Blob, ...) list of blobs directly below this tree"""
return [i for i in self if i.type == "blob"]
- @property
- def cache(self):
+ @ property
+ def cache(self) -> TreeModifier:
"""
:return: An object allowing to modify the internal cache. This can be used
to change the tree's contents. When done, make sure you call ``set_done``
@@ -289,16 +297,16 @@ class Tree(IndexObject, diff.Diffable, util.Traversable, util.Serializable):
return super(Tree, self).traverse(predicate, prune, depth, branch_first, visit_once, ignore_self)
# List protocol
- def __getslice__(self, i, j):
+ def __getslice__(self, i: int, j: int) -> List[Union[Blob, 'Tree', Submodule]]:
return list(self._iter_convert_to_object(self._cache[i:j]))
- def __iter__(self):
+ def __iter__(self) -> Iterator[Union[Blob, 'Tree', Submodule]]:
return self._iter_convert_to_object(self._cache)
- def __len__(self):
+ def __len__(self) -> int:
return len(self._cache)
- def __getitem__(self, item):
+ def __getitem__(self, item: Union[str, int, slice]) -> Union[Blob, 'Tree', Submodule]:
if isinstance(item, int):
info = self._cache[item]
return self._map_id_to_type[info[1] >> 12](self.repo, info[0], info[1], join_path(self.path, info[2]))
@@ -310,7 +318,7 @@ class Tree(IndexObject, diff.Diffable, util.Traversable, util.Serializable):
raise TypeError("Invalid index type: %r" % item)
- def __contains__(self, item):
+ def __contains__(self, item: Union[IndexObject, PathLike]) -> bool:
if isinstance(item, IndexObject):
for info in self._cache:
if item.binsha == info[0]:
@@ -321,10 +329,11 @@ class Tree(IndexObject, diff.Diffable, util.Traversable, util.Serializable):
# compatibility
# treat item as repo-relative path
- path = self.path
- for info in self._cache:
- if item == join_path(path, info[2]):
- return True
+ else:
+ path = self.path
+ for info in self._cache:
+ if item == join_path(path, info[2]):
+ return True
# END for each item
return False