R has a useful function, pairs, that draws a matrix of plots showing the pairwise relationships between the variables in a data set. The resulting plot looks similar to the following figure, copied from this blog post:
Is there a ready-to-use function for this based on Python's matplotlib? I have searched its gallery but couldn't find anything that resembles what I need. Technically this should be a simple task, but properly handling all the possible cases, labels, titles, etc. is very tedious.
UPDATE: see my answer below for a quick and dirty approximation.
Pandas has a built-in function, scatter_matrix (source code), which produces something like this:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

df = pd.DataFrame(np.random.randn(1000, 4), columns=['A', 'B', 'C', 'D'])
# scatter_matrix lives in pandas.plotting in current pandas
# (older releases exposed it as pandas.tools.plotting)
axes = pd.plotting.scatter_matrix(df, alpha=0.2)
plt.tight_layout()
plt.savefig('scatter_matrix.png')
However, it is pandas-specific (but could be used as a starting point).
There are some more R-like plots in pandas; have a look at the docs.
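As a rough sketch of how scatter_matrix can serve as a starting point: it returns a 2-D array of matplotlib axes, so individual panels can be adjusted afterwards. This assumes a current pandas where the function is available as pandas.plotting.scatter_matrix; the column names, figure size, and label tweaks are illustrative choices only.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs headless
import numpy as np
import pandas as pd

rng = np.random.default_rng(0)
df = pd.DataFrame(rng.standard_normal((200, 3)), columns=["A", "B", "C"])

# diagonal="hist" draws a histogram of each variable on the diagonal
axes = pd.plotting.scatter_matrix(df, alpha=0.3, diagonal="hist", figsize=(6, 6))

# axes is an n-by-n NumPy array of matplotlib Axes; tweak panels as needed
n = len(df.columns)
for i in range(n):
    # the left column carries the y labels; make them horizontal
    axes[i, 0].yaxis.label.set_rotation(0)
    axes[i, 0].yaxis.label.set_ha("right")

grid_shape = axes.shape
```

Because the result is a plain array of axes, anything matplotlib can do to a single panel (titles, limits, extra lines) can be applied after the call.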
Quick and dirty approximation to my needs:
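import matplotlib.pyplot as plt

def pair(data, labels=None):
    """Generate something similar to R's `pairs`.

    `data` is a 2-D NumPy array with one column per variable.
    """
    nVariables = data.shape[1]
    if labels is None:
        labels = ['var%d' % i for i in range(nVariables)]
    fig = plt.figure()
    for i in range(nVariables):
        for j in range(nVariables):
            # subplot indices are 1-based and run row by row
            nSub = i * nVariables + j + 1
            ax = fig.add_subplot(nVariables, nVariables, nSub)
            if i == j:
                # diagonal: histogram of variable i
                ax.hist(data[:, i])
                ax.set_title(labels[i])
            else:
                # off-diagonal: variable i on x, variable j on y
                ax.plot(data[:, i], data[:, j], '.k')
    return fig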
The code above is hereby released into the public domain.
The subplots function in recent versions of matplotlib (at least 1.4) makes this a little bit easier:
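import matplotlib.pyplot as plt

def pairs(data, names):
    "Quick&dirty scatterplot matrix"
    # data is a sequence with one entry per variable
    d = len(data)
    fig, axes = plt.subplots(nrows=d, ncols=d, sharex='col', sharey='row')
    for i in range(d):
        for j in range(d):
            ax = axes[i, j]
            if i == j:
                # diagonal: variable name centered in the panel
                ax.text(0.5, 0.5, names[i], transform=ax.transAxes,
                        horizontalalignment='center', verticalalignment='center',
                        fontsize=16)
            else:
                # off-diagonal: variable j on x, variable i on y
                ax.scatter(data[j], data[i], s=10)
    return fig, axes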
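A usage sketch for the subplots approach above, assuming the input is a samples-by-variables NumPy array like the one the earlier pair function takes. Since this version iterates over len(data) and treats data[i] as a whole variable, such an array has to be transposed first. The helper name scatter_grid and the synthetic data are hypothetical; the body is just a compact restatement of the same idea, not a separate API.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs headless
import matplotlib.pyplot as plt
import numpy as np

def scatter_grid(samples, names):
    """Scatterplot matrix for an (n_samples, n_variables) array (hypothetical helper)."""
    variables = np.asarray(samples).T  # one row per variable
    d = len(variables)
    # squeeze=False keeps `axes` 2-D even when d == 1
    fig, axes = plt.subplots(nrows=d, ncols=d, sharex="col", sharey="row",
                             squeeze=False)
    for i in range(d):
        for j in range(d):
            ax = axes[i, j]
            if i == j:
                # diagonal: variable name centered in the panel
                ax.text(0.5, 0.5, names[i], transform=ax.transAxes,
                        ha="center", va="center", fontsize=16)
            else:
                # off-diagonal: variable j on x, variable i on y
                ax.scatter(variables[j], variables[i], s=10)
    return fig, axes

rng = np.random.default_rng(1)
samples = rng.standard_normal((100, 3))
samples[:, 2] = samples[:, 0] + 0.1 * samples[:, 2]  # make x and z correlated
fig, axes = scatter_grid(samples, ["x", "y", "z"])
```

Sharing x per column and y per row keeps the tick labels on the outer edge only, which is most of what makes the hand-rolled version tedious.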