Colorize the background of a seaborn plot using a column in dataframe

Question

How can I shade or colorize the background of a seaborn plot based on the values in a DataFrame column?

Code snippet

import numpy as np
import seaborn as sns; sns.set()
import matplotlib.pyplot as plt
fmri = sns.load_dataset("fmri")
fmri.sort_values('timepoint', inplace=True)

# Background category per (sorted) row: 0 for the first 300 rows,
# 2 from row 600 onward, 1 in between
arr = np.ones(len(fmri))
arr[:300] = 0
arr[600:] = 2
fmri['background'] = arr

ax = sns.lineplot(x="timepoint", y="signal", hue="event", data=fmri)

Which produced this graph:
[Image: actual output]

Desired output

What I'd like is to shade the background according to the value in the new column 'background', using any palette or user-defined colors, so that it looks something like this:

[Image: desired output]

s.k asked Mar 27 '20 at 14:03


People also ask

How do I change the background color of my plot?

The Axes.set_facecolor() method changes the inner background color of the plot (the area inside the axes), while passing facecolor='color' to plt.figure() changes the outer background color (the figure area around the axes).
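A minimal sketch showing both settings on a plain matplotlib figure; the colors and data points are arbitrary placeholders, and the Agg backend line is only there so the snippet also runs without a display.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs headless
import matplotlib.pyplot as plt

# Outer background: the figure area around the axes
fig = plt.figure(facecolor="lightgray")
ax = fig.add_subplot()
# Inner background: the plotting area inside the axes
ax.set_facecolor("lavender")
ax.plot([0, 1, 2], [0, 1, 4])
```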

Can Seaborn use pandas DataFrame?

Yes. Seaborn provides an API on top of Matplotlib that offers sane choices for plot style and color defaults, defines simple high-level functions for common statistical plot types, and integrates with pandas DataFrames: plotting functions such as lineplot() accept a DataFrame through the data parameter and refer to its columns by name.

What is despine() in seaborn?

despine() is a function that removes the top and right spines (the axis border lines) of the plot by default. sns.despine(left=True) additionally removes the left spine.
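A minimal sketch of the call on a throwaway plot (the data is arbitrary); passing ax= limits the change to that one axes:

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs headless
import matplotlib.pyplot as plt
import seaborn as sns

fig, ax = plt.subplots()
ax.plot([0, 1, 2], [2, 0, 1])
# Removes the top and right spines (the defaults) plus the left spine
sns.despine(ax=ax, left=True)
```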


1 Answer

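ax.axvspan() could work for you, assuming the background groups don't overlap in timepoint: compute the first and last timepoint of each group, then shade that x range in the group's color.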

import numpy as np
import seaborn as sns; sns.set()
import matplotlib.pyplot as plt
fmri = sns.load_dataset("fmri")
fmri.sort_values('timepoint', inplace=True)
arr = np.ones(len(fmri))
arr[:300] = 0
arr[600:] = 2
fmri['background'] = arr
# Map each background value to a matplotlib color-cycle name: 0 -> 'C0', 1 -> 'C1', 2 -> 'C2'
fmri['background'] = 'C' + fmri['background'].astype(int).astype(str)

ax = sns.lineplot(x="timepoint", y="signal", hue="event", data=fmri)
# First and last timepoint covered by each background group
ranges = fmri.groupby('background')['timepoint'].agg(['min', 'max'])
# Shade each range; the group key (the index) is the color name
for color, row in ranges.iterrows():
    ax.axvspan(xmin=row['min'], xmax=row['max'], facecolor=color, alpha=0.3)
plt.show()


jjsantoso answered Oct 24 '22 at 17:10