summaryrefslogtreecommitdiff
path: root/networkx/classes/tests/dispatch_interface.py
blob: ded79b363898cce6b6bb6c6af28bad5754519b56 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
# This file contains utilities for testing the dispatching feature

# A full test of all dispatchable algorithms is performed by setting an
# environment variable in the pytest invocation:
# NETWORKX_GRAPH_CONVERT=nx-loopback pytest
# This is comprehensive, but only tests the `test_override_dispatch`
# function in networkx.classes.backends.

# To test the `_dispatch` function directly, several tests scattered throughout
# NetworkX have been augmented to test normal and dispatch mode.
# Searching for `dispatch_interface` should locate the specific tests.

import networkx as nx
from networkx import DiGraph, Graph, MultiDiGraph, MultiGraph, PlanarEmbedding


class LoopbackGraph(Graph):
    """A plain ``Graph`` tagged as belonging to the "nx-loopback" backend."""

    __networkx_plugin__ = "nx-loopback"


class LoopbackDiGraph(DiGraph):
    """A plain ``DiGraph`` tagged as belonging to the "nx-loopback" backend."""

    __networkx_plugin__ = "nx-loopback"


class LoopbackMultiGraph(MultiGraph):
    """A plain ``MultiGraph`` tagged as belonging to the "nx-loopback" backend."""

    __networkx_plugin__ = "nx-loopback"


class LoopbackMultiDiGraph(MultiDiGraph):
    """A plain ``MultiDiGraph`` tagged as belonging to the "nx-loopback" backend."""

    __networkx_plugin__ = "nx-loopback"


class LoopbackPlanarEmbedding(PlanarEmbedding):
    """A ``PlanarEmbedding`` tagged as belonging to the "nx-loopback" backend."""

    __networkx_plugin__ = "nx-loopback"


def convert(graph):
    """Copy *graph* into the matching nx-loopback graph class.

    Parameters
    ----------
    graph : Graph, DiGraph, MultiGraph, MultiDiGraph or PlanarEmbedding
        The NetworkX graph to wrap.

    Returns
    -------
    The corresponding ``Loopback*`` instance, built by passing *graph* to
    the loopback class constructor.

    Raises
    ------
    TypeError
        If *graph* is not an instance of any of the supported classes.
    """
    # ``isinstance`` also matches subclasses, so the table is scanned in
    # order and the first hit wins. The order presumably mirrors NetworkX's
    # class hierarchy (more specific classes first); do not reorder it.
    conversion_table = (
        (PlanarEmbedding, LoopbackPlanarEmbedding),
        (MultiDiGraph, LoopbackMultiDiGraph),
        (MultiGraph, LoopbackMultiGraph),
        (DiGraph, LoopbackDiGraph),
        (Graph, LoopbackGraph),
    )
    for nx_class, loopback_class in conversion_table:
        if isinstance(graph, nx_class):
            return loopback_class(graph)
    raise TypeError(f"Unsupported type of graph: {type(graph)}")


class LoopbackDispatcher:
    """The "nx-loopback" test backend, which runs NetworkX's own code.

    Algorithm lookups resolve to the undecorated function stored as
    ``_orig_func`` on the NetworkX dispatchable wrapper. The conversion
    hooks pass graphs through unchanged.
    """

    # Dispatchable functions that are not looked up from the top-level
    # ``networkx`` namespace, keyed by the name requested via attribute access.
    non_toplevel = {
        "inter_community_edges": nx.community.quality.inter_community_edges,
        "is_tournament": nx.algorithms.tournament.is_tournament,
        "mutual_weight": nx.algorithms.structuralholes.mutual_weight,
        "score_sequence": nx.algorithms.tournament.score_sequence,
        "tournament_matrix": nx.algorithms.tournament.tournament_matrix,
    }

    def __getattr__(self, item):
        """Return the original, undecorated NetworkX function named *item*.

        The top-level ``networkx`` namespace is tried first, then
        ``non_toplevel``. A candidate is accepted only if it carries an
        ``_orig_func`` attribute.

        Raises
        ------
        AttributeError
            If no candidate provides ``_orig_func``. The message is *item*.
        """
        # A top-level name that exists but is not dispatchable (e.g. the
        # ``nx.Graph`` class) must neither short-circuit the lookup nor leak
        # an error about ``_orig_func`` on an unrelated object.
        candidates = (getattr(nx, item, None), self.non_toplevel.get(item))
        for func in candidates:
            orig = getattr(func, "_orig_func", None)
            if orig is not None:
                return orig
        raise AttributeError(item)

    @staticmethod
    def convert_from_nx(graph, weight=None, *, name=None):
        """Return *graph* unchanged.

        ``weight`` and ``name`` are accepted but ignored, presumably to match
        the signature the dispatcher calls this hook with (confirm).
        """
        return graph

    @staticmethod
    def convert_to_nx(obj, *, name=None):
        """Return *obj* unchanged; ``name`` is accepted but ignored."""
        return obj

    @staticmethod
    def on_start_tests(items):
        """Check that every collected test item supports ``add_marker``.

        This verifies that the items can later be marked xfail.
        """
        for item in items:
            assert hasattr(item, "add_marker")


# Module-level LoopbackDispatcher instance; presumably the object that the
# "nx-loopback" backend entry point refers to (confirm against the
# package's entry-point configuration).
dispatcher = LoopbackDispatcher()