
Shading background based on groups above/below a line

Suppose I have a scatterplot with some kind of line (least squares regression line, knn regression line, etc.) through it, like the mock plot produced by the code below. I want to shade the region above the line reddish and the region below it bluish, to give an indication of how well the line performs as a classifier for the points. The effect I am after resembles the plot in Elements of Statistical Learning (Hastie et al.), Chapter 2, page 13.


How can I achieve this effect with Matplotlib?


I know how to set rectangular regions of a plot to be different colors with axhspan and axvspan (see this answer), but have been struggling to set different plot colors based on regions above and below a line.

Code to replicate my current mock plot

import numpy as np
import matplotlib.pyplot as plt

# This style was called 'seaborn-notebook' before Matplotlib 3.6.
plt.style.use('seaborn-v0_8-notebook')

np.random.seed(17)
grp1_x = np.random.normal(1, 1, 100)
grp1_y = np.random.normal(3, 1, 100)

grp2_x = np.random.normal(1.2, 1, 100)
grp2_y = np.random.normal(1.2, 1, 100)

########################################
## least squares plot

plt.scatter(grp1_x, grp1_y,
            lw         = 1,
            facecolors = 'none',
            edgecolors = 'firebrick')
plt.scatter(grp2_x, grp2_y,
            lw         = 1,
            facecolors = 'none',
            edgecolors = 'steelblue')
# Hide all ticks and tick labels; these parameters expect booleans,
# not the strings 'on'/'off'.
plt.tick_params(
    axis        = 'both',
    which       = 'both',
    bottom      = False,
    top         = False,
    labelbottom = False,
    right       = False,
    left        = False,
    labelleft   = False)

full_x = np.concatenate([grp1_x, grp2_x])
full_y = np.concatenate([grp1_y, grp2_y])
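# Fit y = m*x + c by least squares; rcond=None avoids the FutureWarning
# that older NumPy versions emit when rcond is omitted.
m, c = np.linalg.lstsq(np.vstack([full_x,
                                  np.ones(full_x.size)]).T,
                       full_y, rcond=None)[0]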
plt.plot(full_x, m*full_x + c, color='black')
plt.show()

Asked by Eric Hansen, Mar 08 '23 07:03


1 Answer

First, I would recommend sorting the x values so that the line is drawn as a single smooth path from left to right.

x = np.sort(full_x)
plt.plot(x, m*x + c, color='black')

Then you can use fill_between twice: once to fill the region between the line and the lower plot limit, and once to fill the region between the line and the upper plot limit.

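# Read the limits after plotting the data and the line.
xlim = np.array(plt.gca().get_xlim())
ylim = np.array(plt.gca().get_ylim())
# Fill from the line down to the lower limit (bluish) ...
plt.fill_between(xlim, y1=m*xlim + c, y2=[ylim[0], ylim[0]],
                 color="#e0eaf3", zorder=0)
# ... and from the line up to the upper limit (reddish).
plt.fill_between(xlim, y1=m*xlim + c, y2=[ylim[1], ylim[1]],
                 color="#fae4e4", zorder=0)
# Remove the autoscale margins so the fills reach the edges of the axes.
plt.margins(0)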

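The two fill_between calls can also be wrapped in a small function that works on an Axes object. The following is only a sketch: the name shade_above_below and its parameters are made up for illustration and are not part of Matplotlib, the data is synthetic, and the Agg backend is selected only so the example runs without a display. Instead of calling margins(0), it restores the axis limits that were in place before the fills were added.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend, assumed here so the sketch runs headless
import matplotlib.pyplot as plt
import numpy as np


def shade_above_below(ax, m, c, below="#e0eaf3", above="#fae4e4"):
    """Shade the regions of `ax` below and above the line y = m*x + c.

    Hypothetical helper, not a Matplotlib function. Call it after
    plotting the data so the axis limits are already settled.
    """
    xlim = np.array(ax.get_xlim())
    ylim = ax.get_ylim()
    line = m * xlim + c
    # A scalar y2 is broadcast along x by fill_between.
    lower = ax.fill_between(xlim, line, ylim[0], color=below, zorder=0)
    upper = ax.fill_between(xlim, line, ylim[1], color=above, zorder=0)
    # Put the original limits back so the fills end exactly at the axes edges.
    ax.set_xlim(xlim)
    ax.set_ylim(ylim)
    return lower, upper


# Synthetic data, for illustration only.
rng = np.random.default_rng(17)
x = rng.normal(1, 1, 100)
y = 0.5 * x + rng.normal(0, 1, 100)
m, c = np.polyfit(x, y, 1)  # slope and intercept of the least squares line

fig, ax = plt.subplots()
ax.scatter(x, y, lw=1, facecolors="none", edgecolors="steelblue")
xs = np.sort(x)
ax.plot(xs, m * xs + c, color="black")
limits_before = (ax.get_xlim(), ax.get_ylim())
lower, upper = shade_above_below(ax, m, c)
```

To look at the result, fig.savefig(...) or plt.show() with an interactive backend can be added at the end.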

Or use some hatching for the background:

fb1 = plt.fill_between(xlim, y1=m*xlim + c, y2=[ylim[0], ylim[0]],
                       facecolor="w", edgecolor="#e0eaf3", zorder=0)
fb1.set_hatch("//")
fb2 = plt.fill_between(xlim, y1=m*xlim + c, y2=[ylim[1], ylim[1]],
                       facecolor="w", edgecolor="#fae4e4", zorder=0)
fb2.set_hatch("\\\\")

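The question also mentions curved fits such as a knn regression line. The same idea should carry over if the curve is evaluated on a dense grid spanning the x-limits rather than only at the two end points. The sketch below assumes a vectorised prediction callable; the name shade_around_curve is hypothetical, and a cubic polynomial fitted with np.polyfit stands in for whatever model is actually used.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend, assumed here so the sketch runs headless
import matplotlib.pyplot as plt
import numpy as np


def shade_around_curve(ax, predict, n=200, below="#e0eaf3", above="#fae4e4"):
    """Shade below and above y = predict(x) across the current x-limits.

    Hypothetical helper. `predict` is any callable that maps an array
    of x values to an array of y values (a stand-in for a fitted model).
    """
    xlim, ylim = ax.get_xlim(), ax.get_ylim()
    grid = np.linspace(xlim[0], xlim[1], n)  # dense grid so the boundary follows the curve
    curve = predict(grid)
    lower = ax.fill_between(grid, curve, ylim[0], color=below, zorder=0)
    upper = ax.fill_between(grid, curve, ylim[1], color=above, zorder=0)
    # Restore the limits; anything filled outside them is simply clipped.
    ax.set_xlim(xlim)
    ax.set_ylim(ylim)
    return lower, upper


# Synthetic data and a cubic fit, for illustration only.
rng = np.random.default_rng(17)
x = rng.uniform(-2, 2, 150)
y = x**3 - 2 * x + rng.normal(0, 0.5, 150)
predict = np.poly1d(np.polyfit(x, y, 3))

fig, ax = plt.subplots()
ax.scatter(x, y, lw=1, facecolors="none", edgecolors="firebrick")
xs = np.linspace(x.min(), x.max(), 200)
ax.plot(xs, predict(xs), color="black")
limits_before = (ax.get_xlim(), ax.get_ylim())
lower, upper = shade_around_curve(ax, predict)
```

Note that the curve is extrapolated from the data range out to the x-limits, so the shaded boundary near the edges depends on how the chosen model behaves outside the data.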

Answered by ImportanceOfBeingErnest, Mar 11 '23 06:03