Logo Questions Linux Laravel Mysql Ubuntu Git Menu
 

Add legend to networkx plot to explain colouring of edges

I have a plot of a networkx graph in which the edge colour depends on the weight assigned to each edge, using the following code (where a_netw is the nx.Graph):

a_netw_edges = a_netw.edges()
a_netw_weights = [a_netw[source][dest]['weight'] for source, dest in a_netw_edges]
# NOTE: Colormap.__call__ treats a float as a fraction in [0, 1] but an int as an
# index into the colormap's lookup table, so weight*15 behaves differently for
# int and float weights
a_netw_colors = [plt.cm.Blues(weight*15) for weight in a_netw_weights]
# 'edgelist' is the draw_networkx keyword for the edges to draw ('edges' is not a
# recognised keyword); it keeps the edge order aligned with a_netw_colors
nx.draw_networkx(a_netw, edgelist=a_netw_edges, width=1, edge_color=a_netw_colors)

To this graph I would like to add a legend that makes the connection between the weights and the colours explicit, like the colorbar of a heatmap drawn with pcolor.

While I have a rough idea of how to start:

fig, axes = plt.subplots(nrows=2)
# top axes: the graph itself ('edgelist' is the draw_networkx keyword for the
# edges to draw; 'edges' is not recognised)
nx.draw_networkx(a_netw, edgelist=a_netw_edges, width=1, edge_color=a_netw_colors, ax=axes[0])
axes[0].get_xaxis().set_visible(False)
axes[0].get_yaxis().set_visible(False)
# bottom axes: a 2 x 256 strip of values 0..1 rendered with the same colormap
gradient = np.linspace(0, 1, 256)
gradient = np.vstack((gradient, gradient))
axes[1].imshow(gradient, aspect=3, cmap=plt.cm.Blues)
axes[1].get_yaxis().set_visible(False)
plt.tight_layout()

I have no idea how to do the following steps:

  1. Add the correct ticks on the relevant axis to get the connection with the weights.
  2. Draw it vertically instead of horizontally.
like image 687
Patrick Allo Avatar asked Jul 22 '16 20:07

Patrick Allo


2 Answers

I suggest using the colorbar() command as shown below. Here is an example graph; see if it makes sense.

enter image description here

import networkx as nx
import matplotlib.pyplot as plt

#generate a graph with weights
a_netw=nx.Graph()
a_netw.add_edge('a','b',weight=6)
a_netw.add_edge('a','c',weight=2)
a_netw.add_edge('c','d',weight=1)
a_netw.add_edge('c','e',weight=7)
a_netw.add_edge('c','f',weight=9)
a_netw.add_edge('a','d',weight=3)

#creating a color list for each edge based on weight

a_netw_edges = a_netw.edges()
a_netw_weights = [a_netw[source][dest]['weight'] for source, dest in a_netw_edges]

#scale weights in range 0-1 before assigning color 
maxWeight=float(max(a_netw_weights))
a_netw_colors = [plt.cm.Blues(weight/maxWeight) for weight in a_netw_weights]


#generate a 'dummy' heatmap whose only job is to give the colorbar a colormap
#and a value range. Its data spans exactly 0..maxWeight, the same range that
#weight/maxWeight maps onto the colormap above, so the colorbar ticks line up
#with the edge colours. It is drawn on a throwaway figure that is closed right
#away, so that figure is never displayed.
dummy_fig = plt.figure()
heatmap = plt.pcolor([[0, maxWeight]], cmap=plt.cm.Blues)
plt.close(dummy_fig)

fig,axes = plt.subplots()
#'edgelist' (not 'edges') is the draw_networkx keyword for the edges to draw;
#it keeps the edge order aligned with a_netw_colors
nx.draw_networkx(a_netw, edgelist=a_netw_edges, width=10, edge_color=a_netw_colors, ax=axes)
axes.get_xaxis().set_visible(False)
axes.get_yaxis().set_visible(False)

#add colorbar; passing ax explicitly makes it steal space from the graph axes
#instead of from the axes the dummy heatmap was drawn on
cbar = fig.colorbar(heatmap, ax=axes)
cbar.ax.set_ylabel('edge weight',labelpad=15,rotation=270)
like image 127
jlarsch Avatar answered Sep 24 '22 16:09

jlarsch


This answer avoids drawing a dummy heatmap for the colorbar. Instead, it passes the return value of nx.draw_networkx_edges to plt.colorbar to generate a colorbar.

import networkx as nx
import matplotlib
import matplotlib.pyplot as plt
from matplotlib import cm

%matplotlib inline

print('networkx version: {}'.format(nx.__version__))
print('matplotlib version: {}'.format(matplotlib.__version__))
# create a graph using code from @jlarsch's answer
graph = nx.Graph()
graph.add_edge('a','b',weight=6)
graph.add_edge('a','c',weight=2)
graph.add_edge('c','d',weight=1)
graph.add_edge('c','e',weight=7)
graph.add_edge('c','f',weight=9)
graph.add_edge('a','d',weight=3)
# compute spring layout positions for the graph; `seed` makes the layout
# reproducible (older networkx releases called this keyword `random_state`,
# which current releases no longer accept)
pos = nx.spring_layout(graph, seed=123)
# nx.draw_networkx_edges returns an instance of
# `matplotlib.collections.LineCollection`, which
# could be passed to plt.colorbar to generate
# color bar
# the list comprehension to obtain edge weights is from @Aric's answer at 
# https://stackoverflow.com/a/25651827/4638182
# the raw weights are passed as edge_color and mapped through edge_cmap
mcl = nx.draw_networkx_edges(
    graph, pos, edge_cmap=cm.Blues, width=10,
    edge_color=[graph[u][v]['weight'] for u, v in graph.edges])
nx.draw_networkx_nodes(graph, pos)
nx.draw_networkx_labels(graph, pos)
plt.colorbar(mcl)
plt.show()

enter image description here

A similar procedure also works for generating a colorbar for nodes.

import networkx as nx
import matplotlib
import matplotlib.pyplot as plt

%matplotlib inline

print("networkx version: {}".format(nx.__version__))
print("matplotlib version: {}".format(matplotlib.__version__))
# create a graph using code from @jlarsch's answer
graph = nx.Graph()
graph.add_edge('a','b',weight=6)
graph.add_edge('a','c',weight=2)
graph.add_edge('c','d',weight=1)
graph.add_edge('c','e',weight=7)
graph.add_edge('c','f',weight=9)
graph.add_edge('a','d',weight=3)
n_nodes = graph.number_of_nodes()
# `seed` makes the layout reproducible (older networkx releases called this
# keyword `random_state`, which current releases no longer accept)
pos = nx.spring_layout(graph, seed=123)
nx.draw_networkx_edges(graph, pos)
nx.draw_networkx_labels(graph, pos)
# colour each node by its position in node order (0 .. n_nodes-1) purely for
# demonstration; the returned collection is what plt.colorbar needs
mcp = nx.draw_networkx_nodes(graph, pos,
                             node_color=list(range(n_nodes)),
                             cmap='Blues')
plt.colorbar(mcp)
plt.show()

enter image description here

like image 30
Logstar Avatar answered Sep 21 '22 16:09

Logstar