summaryrefslogtreecommitdiff
path: root/git/objects/util.py
diff options
context:
space:
mode:
Diffstat (limited to 'git/objects/util.py')
-rw-r--r--git/objects/util.py24
1 files changed, 10 insertions, 14 deletions
diff --git a/git/objects/util.py b/git/objects/util.py
index fce62af2..bc6cdf8f 100644
--- a/git/objects/util.py
+++ b/git/objects/util.py
@@ -23,7 +23,7 @@ from datetime import datetime, timedelta, tzinfo
from typing import (Any, Callable, Deque, Iterator, NamedTuple, overload, Sequence,
TYPE_CHECKING, Tuple, Type, TypeVar, Union, cast)
-from git.types import Literal
+from git.types import Literal, TypeGuard
if TYPE_CHECKING:
from io import BytesIO, StringIO
@@ -306,24 +306,20 @@ class Traversable(object):
"""
raise NotImplementedError("To be implemented in subclass")
- def list_traverse(self, *args: Any, **kwargs: Any
- ) -> Union[IterableList['TraversableIterableObj'],
- IterableList[Tuple[Union[None, 'TraversableIterableObj'], 'TraversableIterableObj']]]:
+ def list_traverse(self, *args: Any, **kwargs: Any) -> IterableList['TraversableIterableObj']:
"""
:return: IterableList with the results of the traversal as produced by
traverse()
List objects must be IterableObj and Traversable e.g. Commit, Submodule"""
- out: Union[IterableList['TraversableIterableObj'],
- IterableList[Tuple[Union[None, 'TraversableIterableObj'], 'TraversableIterableObj']]]
-
- # def is_TraversableIterableObj(inp: Union['Traversable', IterableObj]) -> TypeGuard['TraversableIterableObj']:
- # return isinstance(self, TraversableIterableObj)
- # assert is_TraversableIterableObj(self), f"{type(self)}"
-
- self = cast('TraversableIterableObj', self)
- out = IterableList(self._id_attribute_)
- out.extend(self.traverse(*args, **kwargs)) # type: ignore
+ def is_TraversableIterableObj(inp: 'Traversable') -> TypeGuard['TraversableIterableObj']:
+ # return isinstance(self, TraversableIterableObj)
+ # Can it be anythin else?
+ return isinstance(self, Traversable)
+
+ assert is_TraversableIterableObj(self), f"{type(self)}"
+ out: IterableList['TraversableIterableObj'] = IterableList(self._id_attribute_)
+ out.extend(self.traverse(*args, **kwargs))
return out
def traverse(self,