summaryrefslogtreecommitdiff
path: root/networkx/algorithms/tree/mst.py
diff options
context:
space:
mode:
Diffstat (limited to 'networkx/algorithms/tree/mst.py')
-rw-r--r--networkx/algorithms/tree/mst.py148
1 files changed, 113 insertions, 35 deletions
diff --git a/networkx/algorithms/tree/mst.py b/networkx/algorithms/tree/mst.py
index 09cd7651..5345db99 100644
--- a/networkx/algorithms/tree/mst.py
+++ b/networkx/algorithms/tree/mst.py
@@ -23,21 +23,42 @@ import networkx as nx
from networkx.utils import UnionFind, not_implemented_for
-def kruskal_mst_edges(G, minimum, weight='weight', data=True):
+def kruskal_mst_edges(G, minimum, weight='weight', keys=True, data=True):
subtrees = UnionFind()
- edges = sorted(G.edges(data=True), key=lambda t: t[2].get(weight, 1),
- reverse=not minimum)
-
- for u, v, d in edges:
- if subtrees[u] != subtrees[v]:
- if data:
- yield (u, v, d)
- else:
- yield (u, v)
- subtrees.union(u, v)
-
-
-def prim_mst_edges(G, minimum, weight='weight', data=True):
+ if G.is_multigraph():
+ edges = G.edges(keys=True, data=True)
+ else:
+ edges = G.edges(data=True)
+ getweight = lambda t: t[-1].get(weight, 1)
+ edges = sorted(edges, key=getweight, reverse=not minimum)
+ is_multigraph = G.is_multigraph()
+ # Multigraphs need to handle edge keys in addition to edge data.
+ if is_multigraph:
+ for u, v, k, d in edges:
+ if subtrees[u] != subtrees[v]:
+ if keys:
+ if data:
+ yield (u, v, k, d)
+ else:
+ yield (u, v, k)
+ else:
+ if data:
+ yield (u, v, d)
+ else:
+ yield (u, v)
+ subtrees.union(u, v)
+ else:
+ for u, v, d in edges:
+ if subtrees[u] != subtrees[v]:
+ if data:
+ yield (u, v, d)
+ else:
+ yield (u, v)
+ subtrees.union(u, v)
+
+
+def prim_mst_edges(G, minimum, weight='weight', keys=True, data=True):
+ is_multigraph = G.is_multigraph()
push = heappush
pop = heappop
@@ -52,24 +73,44 @@ def prim_mst_edges(G, minimum, weight='weight', data=True):
u = nodes.pop(0)
frontier = []
visited = [u]
- for u, v in G.edges(u):
- push(frontier, (G[u][v].get(weight, 1) * sign, next(c), u, v))
-
+ if is_multigraph:
+ for u, v, k, d in G.edges(u, keys=True, data=True):
+ push(frontier, (d.get(weight, 1) * sign, next(c), u, v, k))
+ else:
+ for u, v, d in G.edges(u, data=True):
+ push(frontier, (d.get(weight, 1) * sign, next(c), u, v))
while frontier:
- W, _, u, v = pop(frontier)
+ if is_multigraph:
+ W, _, u, v, k = pop(frontier)
+ else:
+ W, _, u, v = pop(frontier)
if v in visited:
continue
visited.append(v)
nodes.remove(v)
- for v, w in G.edges(v):
- if w in visited:
- continue
- push(frontier, (G[v][w].get(weight, 1) * sign, next(c), v, w))
-
- if data:
- yield u, v, G[u][v]
+ if is_multigraph:
+ for _, w, k2, d2 in G.edges(v, keys=True, data=True):
+ if w in visited:
+ continue
+ new_weight = d2.get(weight, 1) * sign
+ push(frontier, (new_weight, next(c), v, w, k2))
+ else:
+ for _, w, d2 in G.edges(v, data=True):
+ if w in visited:
+ continue
+ new_weight = d2.get(weight, 1) * sign
+ push(frontier, (new_weight, next(c), v, w))
+ # Multigraphs need to handle edge keys in addition to edge data.
+ if is_multigraph and keys:
+ if data:
+ yield u, v, k, G[u][v]
+ else:
+ yield u, v, k
else:
- yield u, v
+ if data:
+ yield u, v, G[u][v]
+ else:
+ yield u, v
ALGORITHMS = {
'kruskal': kruskal_mst_edges,
@@ -78,17 +119,19 @@ ALGORITHMS = {
@not_implemented_for('directed')
-def _spanning_edges(G, minimum, algorithm='kruskal', weight='weight', data=True):
+def _spanning_edges(G, minimum, algorithm='kruskal', weight='weight',
+ keys=True, data=True):
try:
algo = ALGORITHMS[algorithm]
except KeyError:
msg = '{} is not a valid choice for an algorithm.'.format(algorithm)
raise ValueError(msg)
- return algo(G, minimum=minimum, weight=weight, data=data)
+ return algo(G, minimum=minimum, weight=weight, keys=keys, data=data)
-def minimum_spanning_edges(G, algorithm='kruskal', weight='weight', data=True):
+def minimum_spanning_edges(G, algorithm='kruskal', weight='weight', keys=True,
+ data=True):
"""Generate edges in a minimum spanning forest of an undirected
weighted graph.
@@ -109,14 +152,30 @@ def minimum_spanning_edges(G, algorithm='kruskal', weight='weight', data=True):
weight : string
Edge data key to use for weight (default 'weight').
+ keys : bool
+ Whether to yield edge key in multigraphs in addition to the
+ edge. If ``G`` is not a multigraph, this is ignored.
+
data : bool, optional
If True yield the edge data along with the edge.
Returns
-------
edges : iterator
- A generator that produces edges in the minimum spanning tree.
- The edges are three-tuples (u,v,w) where w is the weight.
+ An iterator over tuples representing edges in a minimum spanning
+ tree of ``G``.
+
+ If ``G`` is a multigraph and both ``keys`` and ``data`` are
+ ``True``, then the tuples are four-tuples of the form ``(u, v, k,
+ w)``, where ``(u, v)`` is an edge, ``k`` is the edge key
+ identifying the particular edge joining ``u`` with ``v``, and
+ ``w`` is the weight of the edge. If ``keys`` is ``True`` but
+ ``data`` is ``False``, the tuples are three-tuples of the form
+ ``(u, v, k)``.
+
+ If ``G`` is not a multigraph, the tuples are of the form ``(u, v,
+ w)`` if ``data`` is ``True`` or ``(u, v)`` if ``data`` is
+ ``False``.
Examples
--------
@@ -150,7 +209,7 @@ def minimum_spanning_edges(G, algorithm='kruskal', weight='weight', data=True):
http://www.ics.uci.edu/~eppstein/PADS/
"""
return _spanning_edges(G, minimum=True, algorithm=algorithm,
- weight=weight, data=data)
+ weight=weight, keys=keys, data=data)
def maximum_spanning_edges(G, algorithm='kruskal', weight='weight', data=True):
@@ -174,14 +233,30 @@ def maximum_spanning_edges(G, algorithm='kruskal', weight='weight', data=True):
weight : string
Edge data key to use for weight (default 'weight').
+ keys : bool
+ Whether to yield edge key in multigraphs in addition to the
+ edge. If ``G`` is not a multigraph, this is ignored.
+
data : bool, optional
If True yield the edge data along with the edge.
Returns
-------
edges : iterator
- A generator that produces edges in the maximum spanning tree.
- The edges are three-tuples (u,v,w) where w is the weight.
+ An iterator over tuples representing edges in a maximum spanning
+ tree of ``G``.
+
+ If ``G`` is a multigraph and both ``keys`` and ``data`` are
+ ``True``, then the tuples are four-tuples of the form ``(u, v, k,
+ w)``, where ``(u, v)`` is an edge, ``k`` is the edge key
+ identifying the particular edge joining ``u`` with ``v``, and
+ ``w`` is the weight of the edge. If ``keys`` is ``True`` but
+ ``data`` is ``False``, the tuples are three-tuples of the form
+ ``(u, v, k)``.
+
+ If ``G`` is not a multigraph, the tuples are of the form ``(u, v,
+ w)`` if ``data`` is ``True`` or ``(u, v)`` if ``data`` is
+ ``False``.
Examples
--------
@@ -224,7 +299,10 @@ def _optimum_spanning_tree(G, algorithm, minimum, weight='weight'):
msg = '{} is not a valid choice for an algorithm.'.format(algorithm)
raise ValueError(msg)
- edges = algo(G, minimum=minimum, weight=weight, data=True)
+ # When creating the spanning tree, we can ignore the key used to
+ # identify multigraph edges, since a tree is guaranteed to have no
+ # multiedges. This is why we use `keys=False`.
+ edges = algo(G, minimum=minimum, weight=weight, keys=False, data=True)
T = nx.Graph(edges)
# Add isolated nodes