How can I create a seaborn regression plot with multiindex dataframe?

I have time series data which are multi-indexed on (Year, Month) as seen here:

print(df.index)
print(df)
MultiIndex(levels=[[2016, 2017], [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]],
           labels=[[0, 0, 0, 0, 0, 0, 0, 0], [2, 3, 4, 5, 6, 7, 8, 9]],
           names=['Year', 'Month'])
            Value
Year Month            
2016 3       65.018150
     4       63.130035
     5       71.071254
     6       72.127967
     7       67.357795
     8       66.639228
     9       64.815232
     10      68.387698

I want to do very basic linear regression on these time series data. Because pandas.DataFrame.plot does not do any regression, I intend to use Seaborn to do my plotting.

I attempted to do this by using lmplot:

sns.lmplot(x=("Year", "Month"), y="Value", data=df, fit_reg=True) 

but I get an error:

TypeError: '>' not supported between instances of 'str' and 'tuple'

This is particularly interesting to me because all elements in df.index.levels[:] are of type numpy.int64, all elements in df.index.labels[:] are of type numpy.int8.

Why am I receiving this error? How can I resolve it?

asked Sep 20 '25 12:09 by erip


1 Answer

You can use reset_index to turn the dataframe's index levels into ordinary columns. Plotting DataFrame columns is then straightforward with seaborn.
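As a minimal sketch of what reset_index changes, the frame below is rebuilt from the values printed in the question (the construction via from_product is an assumption; only the printed output is known). Before the call, "Year" and "Month" live in the index and are not available as column names; afterwards they are.

```python
import pandas as pd

# Rebuild the question's frame: year 2016, months 3..10, one "Value" column.
# (Assumed construction; the question only shows the printed result.)
index = pd.MultiIndex.from_product([[2016], range(3, 11)], names=["Year", "Month"])
values = [65.018150, 63.130035, 71.071254, 72.127967,
          67.357795, 66.639228, 64.815232, 68.387698]
df = pd.DataFrame({"Value": values}, index=index)

# The index levels are not columns, so they cannot be passed as x= to seaborn.
print(list(df.columns))    # ['Value']

# reset_index moves every index level into a regular column.
flat = df.reset_index()
print(list(flat.columns))  # ['Year', 'Month', 'Value']
```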

I assume the reason to use lmplot is to show separate regressions for different years (otherwise a regplot may be better suited), so the "Year" column can be used as hue.

import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

iterables = [[2016, 2017], [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]]
index = pd.MultiIndex.from_product(iterables, names=['Year', 'Month'])
df = pd.DataFrame({"values":np.random.rand(24)}, index=index)

df2 = df.reset_index()  # or, df.reset_index(inplace=True) if df is not required otherwise 

g = sns.lmplot(x="Month", y="values", data=df2, hue="Year")

plt.show()
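If a single regression over the whole series is wanted instead of one per year, the regplot alternative mentioned above might look like the sketch below. The column "t" (months elapsed since the first January) is a hypothetical helper introduced here for illustration, not something from the question; it simply gives regplot one continuous numeric x axis.

```python
import numpy as np
import pandas as pd
import seaborn as sns

# Same synthetic layout as above: two years of monthly random values.
index = pd.MultiIndex.from_product([[2016, 2017], range(1, 13)],
                                   names=["Year", "Month"])
df = pd.DataFrame({"values": np.random.default_rng(0).random(24)}, index=index)

flat = df.reset_index()

# Hypothetical helper column: months elapsed since January of the first year,
# so that Year and Month collapse into one numeric time axis.
flat["t"] = (flat["Year"] - flat["Year"].min()) * 12 + (flat["Month"] - 1)

# regplot draws the scatter plus one linear fit on a single Axes.
ax = sns.regplot(x="t", y="values", data=flat)
ax.set_xlabel("Months since start")
```

Call plt.show() afterwards, as in the snippet above, to display the figure. Unlike lmplot, regplot returns a matplotlib Axes rather than a FacetGrid, which makes it easier to combine with other plots.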


answered Sep 22 '25 07:09 by ImportanceOfBeingErnest