diff options
-rw-r--r-- | git/objects/tree.py | 7 | ||||
-rw-r--r-- | git/objects/util.py | 7 |
2 files changed, 11 insertions, 3 deletions
diff --git a/git/objects/tree.py b/git/objects/tree.py index d3681e23..804554d8 100644 --- a/git/objects/tree.py +++ b/git/objects/tree.py @@ -4,7 +4,7 @@ # This module is part of GitPython and is released under # the BSD License: http://www.opensource.org/licenses/bsd-license.php -from git.util import join_path +from git.util import IterableList, join_path import git.diff as diff from git.util import to_bin_sha @@ -21,7 +21,7 @@ from .fun import ( # typing ------------------------------------------------- -from typing import (Callable, Dict, Iterable, Iterator, List, +from typing import (Any, Callable, Dict, Iterable, Iterator, List, Tuple, Type, Union, cast, TYPE_CHECKING) from git.types import PathLike, TypeGuard @@ -323,6 +323,9 @@ class Tree(IndexObject, diff.Diffable, util.Traversable, util.Serializable): super(Tree, self).traverse(predicate, prune, depth, # type: ignore branch_first, visit_once, ignore_self)) + def list_traverse(self, *args: Any, **kwargs: Any) -> IterableList[Union['Tree', 'Submodule', 'Blob']]: + return super(Tree, self).list_traverse(* args, **kwargs) + # List protocol def __getslice__(self, i: int, j: int) -> List[IndexObjUnion]: diff --git a/git/objects/util.py b/git/objects/util.py index 982e7ac7..4dce0aee 100644 --- a/git/objects/util.py +++ b/git/objects/util.py @@ -19,6 +19,8 @@ import time import calendar from datetime import datetime, timedelta, tzinfo +from git.objects.base import IndexObject # just for an isinstance check + # typing ------------------------------------------------------------ from typing import (Any, Callable, Deque, Iterator, NamedTuple, overload, Sequence, TYPE_CHECKING, Tuple, Type, TypeVar, Union, cast) @@ -317,7 +319,7 @@ class Traversable(object): """ # Commit and Submodule have id.__attribute__ as IterableObj # Tree has id.__attribute__ inherited from IndexObject - if isinstance(self, (TraversableIterableObj, Tree)): + if isinstance(self, (TraversableIterableObj, IndexObject)): id = self._id_attribute_ else: id = "" # shouldn't reach here, unless Traversable subclass created with no _id_attribute_ @@ -456,6 +458,9 @@ class TraversableIterableObj(Traversable, IterableObj): TIobj_tuple = Tuple[Union[T_TIobj, None], T_TIobj] + def list_traverse(self: T_TIobj, *args: Any, **kwargs: Any) -> IterableList[T_TIobj]: # type: ignore[override] + return super(TraversableIterableObj, self).list_traverse(* args, **kwargs) + @ overload # type: ignore def traverse(self: T_TIobj, predicate: Callable[[Union[T_TIobj, Tuple[Union[T_TIobj, None], T_TIobj]], int], bool], |