
How to draw a matrix sparsity pattern with color code in python?

I am using spy from matplotlib.pyplot to draw the sparsity pattern of a csc_matrix from scipy.sparse, like this:

>>> import scipy.sparse as sprs
>>> import matplotlib.pyplot as plt
>>> Matrix=sprs.rand(10,10, density=0.1, format='csc')
>>> plt.spy(Matrix)
>>> plt.show()

I want to do the same but give colors to the matrix elements according to their magnitude. Is there a simple way to make spy do this? If not, is there another way to do it?

Fitzgerald Creen asked Jun 03 '14 11:06


2 Answers

You could use imshow:

d = Matrix.toarray()  # dense NumPy array copy of the sparse matrix
plt.imshow(d, interpolation='none', cmap='binary')
plt.colorbar()
plt.show()

Gives:

[Figure: the imshow output with a colorbar]
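If you want the imshow result to look more like spy, one option is to mask the zeros so only the non-zero entries get a color. The sketch below assumes a hypothetical helper name, `imshow_nonzeros`, and relies on masked cells being drawn in the colormap's "bad" color, which is fully transparent by default. It still densifies the matrix, so it is only practical when the dense array fits in memory.

```python
import numpy as np
import scipy.sparse as sprs
import matplotlib.pyplot as plt


def imshow_nonzeros(matrix, ax, cmap="viridis"):
    """Hypothetical helper: draw a sparse matrix with imshow, leaving zeros blank.

    The matrix is densified, zeros are masked, and masked cells are rendered
    in the colormap's 'bad' color (transparent unless changed with set_bad).
    """
    dense = matrix.toarray()                # plain ndarray, not np.matrix
    masked = np.ma.masked_equal(dense, 0)   # zero entries become masked cells
    image = ax.imshow(masked, interpolation="none", cmap=cmap)
    ax.figure.colorbar(image, ax=ax)        # colorbar spans the non-zero values
    return image


# Same kind of matrix as in the question, seeded for reproducibility.
Matrix = sprs.rand(10, 10, density=0.1, format="csc", random_state=0)
fig, ax = plt.subplots()
image = imshow_nonzeros(Matrix, ax)
```

Call plt.show() (or fig.savefig(...)) afterwards to display the figure.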

atomh33ls answered Oct 10 '22 02:10


I had a similar problem. My solution: use a scatter plot with a color bar.

Basically, I had a 100 by 100 sparse matrix, and I wanted to visualize all the non-zero points together with their values.

In my experience imshow is not a good fit for sparse matrices, because it might not show all the points (isolated non-zero entries can get lost when the image is resampled to the on-screen resolution). For me this was a serious issue.

spy reliably shows every non-zero entry of a sparse matrix, but it draws them all in the same color, so you can't add a colorbar.
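This is easy to check. When spy gets a scipy.sparse matrix and no marker or markersize, it plots every non-zero entry as the same square marker and returns a single Line2D artist, which carries no colormap or value array for plt.colorbar to use. A minimal check, assuming the same kind of random matrix as in the question (seeded with random_state=0 for reproducibility):

```python
import scipy.sparse as sprs
import matplotlib.pyplot as plt

Matrix = sprs.rand(10, 10, density=0.1, format="csc", random_state=0)
fig, ax = plt.subplots()

# For sparse input, spy draws the non-zeros as markers and returns one Line2D.
marks = ax.spy(Matrix)

print(type(marks).__name__)                     # Line2D
print(len(marks.get_xdata()) == Matrix.nnz)     # True: one marker per non-zero
```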

So I extracted the non-zero values, plotted them in a scatter plot, and added a colorbar based on the values of the non-zero elements.

Example below:

import numpy as np
import matplotlib.pyplot as plt

# example sparse matrix with different values inside
mat = np.zeros((20, 20))
mat[[1, 5, 5, 5, 10, 15], [1, 4, 5, 6, 10, 15]] = [1, 5, 5, 5, 10, 15]

fig, ax = plt.subplots(figsize=(8, 4), dpi=80, facecolor='w', edgecolor='k')

# row/column indices of the non-zero entries and their values
rows, cols = np.nonzero(mat)
values = mat[rows, cols]

# scatter plot with color bar: columns on the x axis, rows on the y axis
sc = ax.scatter(cols, rows, c=values, s=50)
cb = fig.colorbar(sc, ax=ax)

# full range for x and y axes, with half-cell margins (as in imshow)
# so entries in the first/last row or column are not clipped
ax.set_xlim(-0.5, mat.shape[1] - 0.5)
ax.set_ylim(-0.5, mat.shape[0] - 0.5)
# invert y axis so row 0 is at the top, similar to imshow
ax.invert_yaxis()

plt.show()

[Figure: resulting scatter plot with colorbar, y axis inverted]
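The same scatter idea can be applied directly to a scipy.sparse matrix like the one in the question, without converting it to a dense array, by reading its COO triplets (row, col, data). A minimal sketch, assuming a hypothetical helper `scatter_sparse` and the viridis colormap; note that explicitly stored zeros, if the matrix has any, would also be plotted:

```python
import scipy.sparse as sprs
import matplotlib.pyplot as plt


def scatter_sparse(matrix, ax, marker_size=50, cmap="viridis"):
    """Hypothetical helper: scatter the stored entries of a sparse matrix, colored by value.

    Uses the COO row/col/data arrays directly, so no dense copy is made.
    """
    coo = matrix.tocoo()
    points = ax.scatter(coo.col, coo.row, c=coo.data, s=marker_size, cmap=cmap)
    n_rows, n_cols = matrix.shape
    # Half-cell margins so entries in the first/last row or column are not clipped.
    ax.set_xlim(-0.5, n_cols - 0.5)
    ax.set_ylim(n_rows - 0.5, -0.5)   # reversed limits: row 0 at the top, as in spy/imshow
    ax.set_aspect("equal")
    ax.figure.colorbar(points, ax=ax)
    return points


Matrix = sprs.rand(10, 10, density=0.1, format="csc", random_state=0)
fig, ax = plt.subplots()
points = scatter_sparse(Matrix, ax)
```

This avoids allocating a dense array, although scatter itself gets slow with very many points. Call plt.show() to display the result.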

rcojocaru answered Oct 10 '22 02:10