diff options
Diffstat (limited to 'git/objects/tree.py')
-rw-r--r-- | git/objects/tree.py | 91 |
1 files changed, 50 insertions, 41 deletions
diff --git a/git/objects/tree.py b/git/objects/tree.py index 29b2a684..191fe27c 100644 --- a/git/objects/tree.py +++ b/git/objects/tree.py @@ -20,21 +20,27 @@ from .fun import ( # typing ------------------------------------------------- -from typing import Iterable, Iterator, Tuple, Union, cast, TYPE_CHECKING +from typing import Callable, Dict, Generic, Iterable, Iterator, List, Tuple, Type, TypeVar, Union, cast, TYPE_CHECKING + +from git.types import PathLike if TYPE_CHECKING: + from git.repo import Repo from io import BytesIO #-------------------------------------------------------- -cmp = lambda a, b: (a > b) - (a < b) +cmp: Callable[[str, str], int] = lambda a, b: (a > b) - (a < b) __all__ = ("TreeModifier", "Tree") +T_Tree_cache = TypeVar('T_Tree_cache', bound=Union[Tuple[bytes, int, str]]) + -def git_cmp(t1, t2): +def git_cmp(t1: T_Tree_cache, t2: T_Tree_cache) -> int: a, b = t1[2], t2[2] + assert isinstance(a, str) and isinstance(b, str) # Need as mypy 9.0 cannot unpack TypeVar properly len_a, len_b = len(a), len(b) min_len = min(len_a, len_b) min_cmp = cmp(a[:min_len], b[:min_len]) @@ -45,9 +51,10 @@ def git_cmp(t1, t2): return len_a - len_b -def merge_sort(a, cmp): +def merge_sort(a: List[T_Tree_cache], + cmp: Callable[[T_Tree_cache, T_Tree_cache], int]) -> None: if len(a) < 2: - return + return None mid = len(a) // 2 lefthalf = a[:mid] @@ -80,7 +87,7 @@ def merge_sort(a, cmp): k = k + 1 -class TreeModifier(object): +class TreeModifier(Generic[T_Tree_cache], object): """A utility class providing methods to alter the underlying cache in a list-like fashion. @@ -88,10 +95,10 @@ class TreeModifier(object): the cache of a tree, will be sorted. Assuring it will be in a serializable state""" __slots__ = '_cache' - def __init__(self, cache): + def __init__(self, cache: List[T_Tree_cache]) -> None: self._cache = cache - def _index_by_name(self, name): + def _index_by_name(self, name: str) -> int: """:return: index of an item with name, or -1 if not found""" for i, t in enumerate(self._cache): if t[2] == name: @@ -101,7 +108,7 @@ class TreeModifier(object): return -1 #{ Interface - def set_done(self): + def set_done(self) -> 'TreeModifier': """Call this method once you are done modifying the tree information. It may be called several times, but be aware that each call will cause a sort operation @@ -111,7 +118,7 @@ class TreeModifier(object): #} END interface #{ Mutators - def add(self, sha, mode, name, force=False): + def add(self, sha: bytes, mode: int, name: str, force: bool = False) -> 'TreeModifier': """Add the given item to the tree. If an item with the given name already exists, nothing will be done, but a ValueError will be raised if the sha and mode of the existing item do not match the one you add, unless @@ -129,7 +136,9 @@ class TreeModifier(object): sha = to_bin_sha(sha) index = self._index_by_name(name) - item = (sha, mode, name) + + assert isinstance(sha, bytes) and isinstance(mode, int) and isinstance(name, str) + item = cast(T_Tree_cache, (sha, mode, name)) # use Typeguard from typing-extensions 3.10.0 if index == -1: self._cache.append(item) else: @@ -144,14 +153,17 @@ class TreeModifier(object): # END handle name exists return self - def add_unchecked(self, binsha, mode, name): + def add_unchecked(self, binsha: bytes, mode: int, name: str) -> None: """Add the given item to the tree, its correctness is assumed, which puts the caller into responsibility to assure the input is correct. For more information on the parameters, see ``add`` :param binsha: 20 byte binary sha""" - self._cache.append((binsha, mode, name)) + assert isinstance(binsha, bytes) and isinstance(mode, int) and isinstance(name, str) + tree_cache = cast(T_Tree_cache, (binsha, mode, name)) - def __delitem__(self, name): + self._cache.append(tree_cache) + + def __delitem__(self, name: str) -> None: """Deletes an item with the given name if it exists""" index = self._index_by_name(name) if index > -1: @@ -182,29 +194,29 @@ class Tree(IndexObject, diff.Diffable, util.Traversable, util.Serializable): symlink_id = 0o12 tree_id = 0o04 - _map_id_to_type = { + _map_id_to_type: Dict[int, Union[Type[Submodule], Type[Blob], Type['Tree']]] = { commit_id: Submodule, blob_id: Blob, symlink_id: Blob # tree id added once Tree is defined } - def __init__(self, repo, binsha, mode=tree_id << 12, path=None): + def __init__(self, repo: 'Repo', binsha: bytes, mode: int = tree_id << 12, path: Union[PathLike, None] = None): super(Tree, self).__init__(repo, binsha, mode, path) - @classmethod + @ classmethod def _get_intermediate_items(cls, index_object: 'Tree', # type: ignore - ) -> Tuple['Tree', ...]: + ) -> Union[Tuple['Tree', ...], Tuple[()]]: if index_object.type == "tree": index_object = cast('Tree', index_object) return tuple(index_object._iter_convert_to_object(index_object._cache)) return () - def _set_cache_(self, attr): + def _set_cache_(self, attr: str) -> None: if attr == "_cache": # Set the data when we need it ostream = self.repo.odb.stream(self.binsha) - self._cache = tree_entries_from_data(ostream.read()) + self._cache: List[Tuple[bytes, int, str]] = tree_entries_from_data(ostream.read()) else: super(Tree, self)._set_cache_(attr) # END handle attribute @@ -221,7 +233,7 @@ class Tree(IndexObject, diff.Diffable, util.Traversable, util.Serializable): raise TypeError("Unknown mode %o found in tree data for path '%s'" % (mode, path)) from e # END for each item - def join(self, file): + def join(self, file: str) -> Union[Blob, 'Tree', Submodule]: """Find the named object in this tree's contents :return: ``git.Blob`` or ``git.Tree`` or ``git.Submodule`` @@ -254,26 +266,22 @@ class Tree(IndexObject, diff.Diffable, util.Traversable, util.Serializable): raise KeyError(msg % file) # END handle long paths - def __div__(self, file): - """For PY2 only""" - return self.join(file) - - def __truediv__(self, file): + def __truediv__(self, file: str) -> Union['Tree', Blob, Submodule]: """For PY3 only""" return self.join(file) - @property - def trees(self): + @ property + def trees(self) -> List['Tree']: """:return: list(Tree, ...) list of trees directly below this tree""" return [i for i in self if i.type == "tree"] - @property - def blobs(self): + @ property + def blobs(self) -> List['Blob']: """:return: list(Blob, ...) list of blobs directly below this tree""" return [i for i in self if i.type == "blob"] - @property - def cache(self): + @ property + def cache(self) -> TreeModifier: """ :return: An object allowing to modify the internal cache. This can be used to change the tree's contents. When done, make sure you call ``set_done`` @@ -289,16 +297,16 @@ class Tree(IndexObject, diff.Diffable, util.Traversable, util.Serializable): return super(Tree, self).traverse(predicate, prune, depth, branch_first, visit_once, ignore_self) # List protocol - def __getslice__(self, i, j): + def __getslice__(self, i: int, j: int) -> List[Union[Blob, 'Tree', Submodule]]: return list(self._iter_convert_to_object(self._cache[i:j])) - def __iter__(self): + def __iter__(self) -> Iterator[Union[Blob, 'Tree', Submodule]]: return self._iter_convert_to_object(self._cache) - def __len__(self): + def __len__(self) -> int: return len(self._cache) - def __getitem__(self, item): + def __getitem__(self, item: Union[str, int, slice]) -> Union[Blob, 'Tree', Submodule]: if isinstance(item, int): info = self._cache[item] return self._map_id_to_type[info[1] >> 12](self.repo, info[0], info[1], join_path(self.path, info[2])) @@ -310,7 +318,7 @@ class Tree(IndexObject, diff.Diffable, util.Traversable, util.Serializable): raise TypeError("Invalid index type: %r" % item) - def __contains__(self, item): + def __contains__(self, item: Union[IndexObject, PathLike]) -> bool: if isinstance(item, IndexObject): for info in self._cache: if item.binsha == info[0]: @@ -321,10 +329,11 @@ class Tree(IndexObject, diff.Diffable, util.Traversable, util.Serializable): # compatibility # treat item as repo-relative path - path = self.path - for info in self._cache: - if item == join_path(path, info[2]): - return True + else: + path = self.path + for info in self._cache: + if item == join_path(path, info[2]): + return True # END for each item return False |