I'm working on a project where I need to create a directed weighted graph in Python that allows parallel edges with different weights between nodes. I am using the networkx library and Matplotlib for visualization.
My goal is to create such a graph and display each edge's weight as a label. Here is what I have so far:
```python
import random
import networkx as nx
import matplotlib.pyplot as plt


def create_graph(n_nodes, alpha=0.5):
    G = nx.MultiDiGraph()
    G.add_nodes_from(range(n_nodes))
    for i in range(n_nodes):
        for j in range(i + 1, n_nodes):
            if random.random() < alpha:
                weight = random.randint(1, 10)
                G.add_edge(i, j, weight=weight)
            if random.random() < alpha:
                weight = random.randint(1, 10)
                G.add_edge(j, i, weight=weight)
    return G


def display_graph(G):
    pos = nx.spring_layout(G)
    weight_labels = nx.get_edge_attributes(G, 'weight')
    nx.draw(G, pos, with_labels=True, node_color='skyblue', edge_color='gray', node_size=700)
    nx.draw_networkx_edge_labels(G, pos, edge_labels=weight_labels)
    plt.show()


n_nodes = 5
G = create_graph(n_nodes, alpha=0.5)
display_graph(G)
```
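Note that, as written, `create_graph` adds at most one edge in each direction between a pair of nodes, so the resulting graph only ever contains antiparallel pairs (`i -> j` and `j -> i`), never two edges in the same direction. If genuine parallel edges are wanted, a variant along these lines could be used. This is only a sketch: `create_multigraph`, `max_parallel` and `seed` are hypothetical names chosen for illustration.

```python
import random

import networkx as nx


def create_multigraph(n_nodes, alpha=0.5, max_parallel=3, seed=None):
    """Hypothetical variant of create_graph: each direction between two nodes
    can receive up to `max_parallel` edges, each with its own random weight."""
    rng = random.Random(seed)  # local RNG so a given seed is reproducible
    G = nx.MultiDiGraph()
    G.add_nodes_from(range(n_nodes))
    for i in range(n_nodes):
        for j in range(i + 1, n_nodes):
            for u, v in ((i, j), (j, i)):
                # One independent trial per potential parallel edge u -> v.
                for _ in range(max_parallel):
                    if rng.random() < alpha:
                        G.add_edge(u, v, weight=rng.randint(1, 10))
    return G


if __name__ == "__main__":
    G = create_multigraph(5, seed=42)
    print(G.number_of_edges(), "edges")
    print(list(G.edges(keys=True, data="weight")))
```

With `max_parallel=1` this behaves like the original `create_graph` (apart from using its own random generator).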
However, when I try to visualize the graph with edge labels, I get this error:

```
networkx.exception.NetworkXError: draw_networkx_edge_labels does not support multiedges.
```

The error is raised when I try to display edge labels for a `MultiDiGraph`; it seems that `draw_networkx_edge_labels` doesn't support parallel edges.
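The difference is already visible in the label data itself. A minimal sketch on a tiny hand-built graph (the node numbers and weights are arbitrary): for a multigraph, `nx.get_edge_attributes` returns a dict keyed by `(u, v, key)` triples, so every parallel edge carries its own label and would need its own label position.

```python
import networkx as nx

# Two parallel edges 0 -> 1 with different weights, plus one reverse edge 1 -> 0.
G = nx.MultiDiGraph()
G.add_edge(0, 1, weight=3)
G.add_edge(0, 1, weight=7)
G.add_edge(1, 0, weight=5)

# For a multigraph the keys are (u, v, key) triples rather than (u, v) pairs.
weight_labels = nx.get_edge_attributes(G, "weight")
print(weight_labels)  # {(0, 1, 0): 3, (0, 1, 1): 7, (1, 0, 0): 5}
```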
How can I display the weight labels of parallel edges in a networkx `MultiDiGraph`? I'd appreciate any guidance or examples to help me achieve this. Thank you in advance!
A core issue with NetworkX's drawing utilities is that drawing is split across separate functions. When labeling multigraph edges, the function that determines where each edge label goes (`draw_networkx_edge_labels`) needs to know the path of the edge; however, that path is computed in `draw_networkx_edges`, and there is no cross-talk between the two.
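To illustrate why the edge path matters, here is a minimal sketch that sidesteps the split by drawing every edge itself with plain Matplotlib, so the curve and its label position are computed in the same place. It assumes each edge is drawn as a `matplotlib.patches.FancyArrowPatch` with the `arc3` connection style and that the axes use an equal aspect ratio; the control-point formula mirrors my reading of Matplotlib's `ConnectionStyle.Arc3`. `arc3_midpoint` and `display_multigraph` are hypothetical helpers, not part of NetworkX or Netgraph, and self-loops are simply skipped.

```python
from collections import defaultdict

import matplotlib.pyplot as plt
import networkx as nx
from matplotlib.patches import FancyArrowPatch


def arc3_midpoint(p0, p1, rad):
    """Return the point at t = 0.5 on the quadratic Bezier curve that the
    "arc3,rad=<rad>" connection style draws from p0 to p1."""
    (x1, y1), (x2, y2) = p0, p1
    mx, my = (x1 + x2) / 2, (y1 + y2) / 2
    dx, dy = x2 - x1, y2 - y1
    # Control point as in ConnectionStyle.Arc3. Matplotlib computes it in display
    # space, so it matches data space only when the aspect ratio is equal.
    cx, cy = mx + rad * dy, my - rad * dx
    # B(0.5) = 0.25 * P0 + 0.5 * C + 0.25 * P1
    return 0.25 * x1 + 0.5 * cx + 0.25 * x2, 0.25 * y1 + 0.5 * cy + 0.25 * y2


def display_multigraph(G, pos=None, base_rad=0.15):
    """Draw a MultiDiGraph with one curved arrow and one weight label per edge."""
    if pos is None:
        pos = nx.spring_layout(G, seed=1)
    fig, ax = plt.subplots()
    nx.draw_networkx_nodes(G, pos, ax=ax, node_color="skyblue", node_size=700)
    nx.draw_networkx_labels(G, pos, ax=ax)
    rank = defaultdict(int)  # how many u -> v edges have been drawn so far
    for u, v, weight in G.edges(data="weight"):
        if u == v:
            continue  # self-loops are not handled in this sketch
        # Parallel u -> v edges get increasingly curved arcs; a v -> u edge with the
        # same rad bends to the opposite side because its direction is reversed.
        rad = base_rad * (rank[(u, v)] + 1)
        rank[(u, v)] += 1
        ax.add_patch(FancyArrowPatch(
            pos[u], pos[v],
            arrowstyle="-|>", mutation_scale=15, color="gray",
            connectionstyle=f"arc3,rad={rad}",
            shrinkA=15, shrinkB=15,  # stop short of the node markers (in points)
        ))
        x, y = arc3_midpoint(pos[u], pos[v], rad)
        ax.text(x, y, str(weight), fontsize=8, ha="center", va="center",
                bbox=dict(facecolor="white", edgecolor="none", pad=1))
    ax.set_aspect("equal")  # needed for arc3_midpoint to land on the drawn arc
    ax.axis("off")
    return fig, ax


if __name__ == "__main__":
    G = nx.MultiDiGraph()
    G.add_edge(0, 1, weight=3)
    G.add_edge(0, 1, weight=7)
    G.add_edge(1, 0, weight=5)
    G.add_edge(1, 2, weight=2)
    display_multigraph(G)
    plt.show()
```

Because each label is placed at the midpoint of exactly the curve its arrow follows, labels of parallel and antiparallel edges should stay separated for a small number of parallel edges; with many parallel edges, increasing `base_rad` spreads the arcs further.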
To work around this and other design issues of the NetworkX drawing utilities, I wrote a drop-in replacement called Netgraph. Multigraph support is fairly new, however, so you have to install it from the dev branch (stable releases are distributed via pip and conda-forge):
```shell
pip install https://github.com/paulbrodersen/netgraph/archive/dev.zip
```
```python
import random
import networkx as nx
import matplotlib.pyplot as plt

from netgraph import MultiGraph


def create_graph(n_nodes, alpha=0.5):
    G = nx.MultiDiGraph()
    G.add_nodes_from(range(n_nodes))
    for i in range(n_nodes):
        for j in range(i + 1, n_nodes):
            if random.random() < alpha:
                weight = random.randint(1, 10)
                G.add_edge(i, j, weight=weight)
            if random.random() < alpha:
                weight = random.randint(1, 10)
                G.add_edge(j, i, weight=weight)
    return G


G = create_graph(5)

MultiGraph(
    G,
    node_labels=True,
    node_color="skyblue",
    edge_color="gray",
    edge_labels=nx.get_edge_attributes(G, 'weight'),
    edge_label_fontdict=dict(fontsize=8),
    arrows=True,
)
plt.show()
```