diff options
author | Yobmod <yobmod@gmail.com> | 2021-05-20 20:44:53 +0100 |
---|---|---|
committer | Yobmod <yobmod@gmail.com> | 2021-05-20 20:44:53 +0100 |
commit | 5402a166a4971512f9d513bf36159dead9672ae9 (patch) | |
tree | ced858bf72b6c9f0f594d54f44db04e91abc0da6 | |
parent | c242b55d7c64ee43405f8b335c762bcf92189d38 (diff) | |
download | gitpython-5402a166a4971512f9d513bf36159dead9672ae9.tar.gz |
Add types to objects _get_intermediate_items()
-rw-r--r-- | git/objects/commit.py | 12 | ||||
-rw-r--r-- | git/objects/submodule/base.py | 6 | ||||
-rw-r--r-- | git/objects/tree.py | 8 | ||||
-rw-r--r-- | git/objects/util.py | 64 |
4 files changed, 68 insertions, 22 deletions
diff --git a/git/objects/commit.py b/git/objects/commit.py index 45e6d772..6d3f0bac 100644 --- a/git/objects/commit.py +++ b/git/objects/commit.py @@ -4,6 +4,7 @@ # This module is part of GitPython and is released under # the BSD License: http://www.opensource.org/licenses/bsd-license.php +from typing import Tuple, Union from gitdb import IStream from git.util import ( hex_to_bin, @@ -70,7 +71,8 @@ class Commit(base.Object, Iterable, Diffable, Traversable, Serializable): def __init__(self, repo, binsha, tree=None, author=None, authored_date=None, author_tz_offset=None, committer=None, committed_date=None, committer_tz_offset=None, - message=None, parents=None, encoding=None, gpgsig=None): + message=None, parents: Union[Tuple['Commit', ...], None] = None, + encoding=None, gpgsig=None): """Instantiate a new Commit. All keyword arguments taking None as default will be implicitly set on first query. @@ -133,7 +135,7 @@ class Commit(base.Object, Iterable, Diffable, Traversable, Serializable): self.gpgsig = gpgsig @classmethod - def _get_intermediate_items(cls, commit): + def _get_intermediate_items(cls, commit: 'Commit') -> Tuple['Commit', ...]: # type: ignore return commit.parents @classmethod @@ -477,7 +479,7 @@ class Commit(base.Object, Iterable, Diffable, Traversable, Serializable): readline = stream.readline self.tree = Tree(self.repo, hex_to_bin(readline().split()[1]), Tree.tree_id << 12, '') - self.parents = [] + self.parents_list = [] # List['Commit'] next_line = None while True: parent_line = readline() @@ -485,9 +487,9 @@ class Commit(base.Object, Iterable, Diffable, Traversable, Serializable): next_line = parent_line break # END abort reading parents - self.parents.append(type(self)(self.repo, hex_to_bin(parent_line.split()[-1].decode('ascii')))) + self.parents_list.append(type(self)(self.repo, hex_to_bin(parent_line.split()[-1].decode('ascii')))) # END for each parent line - self.parents = tuple(self.parents) + self.parents = tuple(self.parents_list) # type: Tuple['Commit', ...] # we don't know actual author encoding before we have parsed it, so keep the lines around author_line = next_line diff --git a/git/objects/submodule/base.py b/git/objects/submodule/base.py index e3be1a72..b03fa22a 100644 --- a/git/objects/submodule/base.py +++ b/git/objects/submodule/base.py @@ -3,6 +3,7 @@ from io import BytesIO import logging import os import stat +from typing import List from unittest import SkipTest import uuid @@ -134,10 +135,11 @@ class Submodule(IndexObject, Iterable, Traversable): super(Submodule, self)._set_cache_(attr) # END handle attribute name - def _get_intermediate_items(self, item): + @classmethod + def _get_intermediate_items(cls, item: 'Submodule') -> List['Submodule']: # type: ignore """:return: all the submodules of our module repository""" try: - return type(self).list_items(item.module()) + return cls.list_items(item.module()) except InvalidGitRepositoryError: return [] # END handle intermediate items diff --git a/git/objects/tree.py b/git/objects/tree.py index 68e98329..65c9be4c 100644 --- a/git/objects/tree.py +++ b/git/objects/tree.py @@ -3,6 +3,7 @@ # # This module is part of GitPython and is released under # the BSD License: http://www.opensource.org/licenses/bsd-license.php +from typing import Iterable, Iterator, Tuple, Union, cast from git.util import join_path import git.diff as diff from git.util import to_bin_sha @@ -182,8 +183,10 @@ class Tree(IndexObject, diff.Diffable, util.Traversable, util.Serializable): super(Tree, self).__init__(repo, binsha, mode, path) @classmethod - def _get_intermediate_items(cls, index_object): + def _get_intermediate_items(cls, index_object: 'Tree', # type: ignore + ) -> Tuple['Tree', ...]: if index_object.type == "tree": + index_object = cast('Tree', index_object) return tuple(index_object._iter_convert_to_object(index_object._cache)) return () @@ -196,7 +199,8 @@ class Tree(IndexObject, diff.Diffable, util.Traversable, util.Serializable): super(Tree, self)._set_cache_(attr) # END handle attribute - def _iter_convert_to_object(self, iterable): + def _iter_convert_to_object(self, iterable: Iterable[Tuple[bytes, int, str]] + ) -> Iterator[Union[Blob, 'Tree', Submodule]]: """Iterable yields tuples of (binsha, mode, name), which will be converted to the respective object representation""" for binsha, mode, name in iterable: diff --git a/git/objects/util.py b/git/objects/util.py index fdc1406b..88183567 100644 --- a/git/objects/util.py +++ b/git/objects/util.py @@ -4,6 +4,8 @@ # This module is part of GitPython and is released under # the BSD License: http://www.opensource.org/licenses/bsd-license.php """Module for general utility functions""" + + from git.util import ( IterableList, Actor @@ -18,9 +20,10 @@ import calendar from datetime import datetime, timedelta, tzinfo # typing ------------------------------------------------------------ -from typing import Any, IO, TYPE_CHECKING, Tuple, Type, Union, cast +from typing import Any, Callable, IO, Iterator, Sequence, TYPE_CHECKING, Tuple, Type, Union, cast, overload if TYPE_CHECKING: + from .submodule.base import Submodule from .commit import Commit from .blob import Blob from .tag import TagObject @@ -115,7 +118,7 @@ def verify_utctz(offset: str) -> str: class tzoffset(tzinfo): - + def __init__(self, secs_west_of_utc: float, name: Union[None, str] = None) -> None: self._offset = timedelta(seconds=-secs_west_of_utc) self._name = name or 'fixed' @@ -275,29 +278,61 @@ class Traversable(object): """Simple interface to perform depth-first or breadth-first traversals into one direction. Subclasses only need to implement one function. - Instances of the Subclass must be hashable""" + Instances of the Subclass must be hashable + + Defined subclasses = [Commit, Tree, SubModule] + """ __slots__ = () + @overload + @classmethod + def _get_intermediate_items(cls, item: 'Commit') -> Tuple['Commit', ...]: + ... + + @overload @classmethod - def _get_intermediate_items(cls, item): + def _get_intermediate_items(cls, item: 'Submodule') -> Tuple['Submodule', ...]: + ... + + @overload + @classmethod + def _get_intermediate_items(cls, item: 'Tree') -> Tuple['Tree', ...]: + ... + + @overload + @classmethod + def _get_intermediate_items(cls, item: 'Traversable') -> Tuple['Traversable', ...]: + ... + + @classmethod + def _get_intermediate_items(cls, item: 'Traversable' + ) -> Sequence['Traversable']: """ Returns: - List of items connected to the given item. + Tuple of items connected to the given item. Must be implemented in subclass + + class Commit:: (cls, Commit) -> Tuple[Commit, ...] + class Submodule:: (cls, Submodule) -> Iterablelist[Submodule] + class Tree:: (cls, Tree) -> Tuple[Tree, ...] """ raise NotImplementedError("To be implemented in subclass") - def list_traverse(self, *args, **kwargs): + def list_traverse(self, *args: Any, **kwargs: Any) -> IterableList: """ :return: IterableList with the results of the traversal as produced by traverse()""" - out = IterableList(self._id_attribute_) + out = IterableList(self._id_attribute_) # type: ignore[attr-defined] # defined in sublcasses out.extend(self.traverse(*args, **kwargs)) return out - def traverse(self, predicate=lambda i, d: True, - prune=lambda i, d: False, depth=-1, branch_first=True, - visit_once=True, ignore_self=1, as_edge=False): + def traverse(self, + predicate: Callable[[object, int], bool] = lambda i, d: True, + prune: Callable[[object, int], bool] = lambda i, d: False, + depth: int = -1, + branch_first: bool = True, + visit_once: bool = True, ignore_self: int = 1, as_edge: bool = False + ) -> Union[Iterator['Traversable'], Iterator[Tuple['Traversable', 'Traversable']]]: """:return: iterator yielding of items found when traversing self :param predicate: f(i,d) returns False if item i at depth d should not be included in the result @@ -329,13 +364,16 @@ class Traversable(object): destination, i.e. tuple(src, dest) with the edge spanning from source to destination""" visited = set() - stack = Deque() + stack = Deque() # type: Deque[Tuple[int, Traversable, Union[Traversable, None]]] stack.append((0, self, None)) # self is always depth level 0 - def addToStack(stack, item, branch_first, depth): + def addToStack(stack: Deque[Tuple[int, 'Traversable', Union['Traversable', None]]], + item: 'Traversable', + branch_first: bool, + depth) -> None: lst = self._get_intermediate_items(item) if not lst: - return + return None if branch_first: stack.extendleft((depth, i, item) for i in lst) else: |