Seaborn FacetGrid user-defined plot function

Tags: python, seaborn

In Seaborn, you can use FacetGrid to set up data-aware grids on which to plot. You can then use the map or map_dataframe methods to plot to those grids.

I am having trouble correctly specifying a user-defined plot function that works with map or map_dataframe. In this example I use matplotlib's errorbar function, to which I want to pass the error values as a 2xN array-like. In my example (taken from @mwaskom's answer here) the errors are symmetrical, but imagine a situation where they are not.
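
For reference, this is how a 2xN `yerr` behaves in plain matplotlib, outside any FacetGrid. It is a sketch with hypothetical values: the first row holds the distance below each point and the second row the distance above it.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend
import matplotlib.pyplot as plt
import numpy as np

x = np.array([1, 2, 3])
y = np.array([10.0, 15.0, 21.0])
lower = np.array([1.0, 0.5, 2.0])  # hypothetical distance below each point
upper = np.array([3.0, 1.0, 4.0])  # hypothetical distance above each point

# Row 0 = lower errors, row 1 = upper errors, giving shape (2, N)
errors = np.vstack([lower, upper])

fig, ax = plt.subplots()
container = ax.errorbar(x, y, yerr=errors, marker="o")

# Each vertical bar should run from y - lower to y + upper
extents = [(seg[0][1], seg[1][1]) for seg in container.lines[2][0].get_segments()]
print(extents)
```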

In [255]:

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from scipy import stats
tips_all = sns.load_dataset("tips")
tips_grouped = tips_all.groupby(["smoker", "size"])
tips = tips_grouped.mean()
tips["error_min"] = tips_grouped.total_bill.apply(stats.sem) * 1.96
tips["error_max"] = tips_grouped.total_bill.apply(stats.sem) * 1.96
tips.reset_index(inplace=True)
tips

Out[255]:
    smoker  size    total_bill  tip     error_min   error_max
0   No  1   8.660000    1.415000    2.763600    2.763600
1   No  2   15.342333   2.489000    0.919042    0.919042
2   No  3   21.009615   3.069231    2.680447    2.680447
3   No  4   27.769231   4.195769    3.303131    3.303131
4   No  5   30.576667   5.046667    11.620808   11.620808
5   No  6   34.830000   5.225000    9.194360    9.194360
6   Yes     1   5.825000    1.460000    5.399800    5.399800
7   Yes     2   17.955758   2.709545    1.805528    1.805528
8   Yes     3   28.191667   4.095000    6.898186    6.898186
9   Yes     4   30.609091   3.992727    5.150063    5.150063
10  Yes     5   29.305000   2.500000    2.263800    2.263800
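
The two error columns above are identical. Here is one hypothetical way to get genuinely asymmetric columns in the same layout, assuming you are happy to plot the median with the interquartile range instead of the mean with a 1.96 * SEM interval. It uses pandas named aggregation (pandas 0.25 or later) on synthetic data. `q25`, `q75` and `raw` are illustrative names.

```python
import numpy as np
import pandas as pd

# Hypothetical raw data standing in for the full tips table
rng = np.random.default_rng(0)
raw = pd.DataFrame({
    "smoker": np.repeat(["No", "Yes"], 20),
    "size": np.tile([1, 2], 20),
    "total_bill": rng.gamma(shape=4.0, scale=5.0, size=40),
})

def q25(s):
    return s.quantile(0.25)

def q75(s):
    return s.quantile(0.75)

# Median as the point estimate, with the quartiles as the interval
summary = raw.groupby(["smoker", "size"]).agg(
    total_bill=("total_bill", "median"),
    q25=("total_bill", q25),
    q75=("total_bill", q75),
)

# errorbar wants distances from the point, not absolute bar ends,
# so convert the quartiles into (generally unequal) lower/upper offsets
summary["error_min"] = summary["total_bill"] - summary["q25"]
summary["error_max"] = summary["q75"] - summary["total_bill"]
summary = summary.drop(columns=["q25", "q75"]).reset_index()
print(summary)
```

Because the median always lies between the two quartiles, both offsets are non-negative, which matplotlib's `yerr` requires.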

Define my error bar function, which takes the data and indexes the error columns to produce the 2xN array:

In [256]:
def my_errorbar(*args, **kwargs):
    data = kwargs['data']
    errors = np.vstack([data['error_min'], 
                        data['error_max']])
    print(errors)
    plt.errorbar(data[args[0]], 
                 data[args[1]], 
                 yerr=errors,
                 **kwargs);  

Call using map_dataframe (because my function gets the data as a kwarg):

In [257]:

g = sns.FacetGrid(tips, col="smoker", size=5)
g.map_dataframe(my_errorbar, "size", "total_bill", marker="o")

[[  2.7636       0.9190424    2.68044722   3.30313068  11.62080751
    9.19436049]
 [  2.7636       0.9190424    2.68044722   3.30313068  11.62080751
    9.19436049]]

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-257-dc8b35ec70ec> in <module>()
      1 g = sns.FacetGrid(tips, col="smoker", size=5)
----> 2 g.map_dataframe(my_errorbar, "size", "total_bill", marker="o")

/Users/x/miniconda3/envs/default/lib/python3.4/site-packages/seaborn/axisgrid.py in map_dataframe(self, func, *args, **kwargs)
    509 
    510             # Draw the plot
--> 511             self._facet_plot(func, ax, args, kwargs)
    512 
    513         # Finalize the annotations and layout

/Users/x/miniconda3/envs/default/lib/python3.4/site-packages/seaborn/axisgrid.py in _facet_plot(self, func, ax, plot_args, plot_kwargs)
    527 
    528         # Draw the plot
--> 529         func(*plot_args, **plot_kwargs)
    530 
    531         # Sort out the supporting information

<ipython-input-256-62202c841233> in my_errorbar(*args, **kwargs)
      9                  data[args[1]],
     10                  yerr=errors,
---> 11                  **kwargs);    
     12 
     13 

/Users/x/miniconda3/envs/default/lib/python3.4/site-packages/matplotlib/pyplot.py in errorbar(x, y, yerr, xerr, fmt, ecolor, elinewidth, capsize, barsabove, lolims, uplims, xlolims, xuplims, errorevery, capthick, hold, **kwargs)
   2764                           barsabove=barsabove, lolims=lolims, uplims=uplims,
   2765                           xlolims=xlolims, xuplims=xuplims,
-> 2766                           errorevery=errorevery, capthick=capthick, **kwargs)
   2767         draw_if_interactive()
   2768     finally:

/Users/x/miniconda3/envs/default/lib/python3.4/site-packages/matplotlib/axes/_axes.py in errorbar(self, x, y, yerr, xerr, fmt, ecolor, elinewidth, capsize, barsabove, lolims, uplims, xlolims, xuplims, errorevery, capthick, **kwargs)
   2859 
   2860         if not barsabove and plot_line:
-> 2861             l0, = self.plot(x, y, fmt, **kwargs)
   2862 
   2863         if ecolor is None:

/Users/x/miniconda3/envs/default/lib/python3.4/site-packages/matplotlib/axes/_axes.py in plot(self, *args, **kwargs)
   1371         lines = []
   1372 
-> 1373         for line in self._get_lines(*args, **kwargs):
   1374             self.add_line(line)
   1375             lines.append(line)

/Users/x/miniconda3/envs/default/lib/python3.4/site-packages/matplotlib/axes/_base.py in _grab_next_args(self, *args, **kwargs)
    302                 return
    303             if len(remaining) <= 3:
--> 304                 for seg in self._plot_args(remaining, kwargs):
    305                     yield seg
    306                 return

/Users/x/miniconda3/envs/default/lib/python3.4/site-packages/matplotlib/axes/_base.py in _plot_args(self, tup, kwargs)
    290         ncx, ncy = x.shape[1], y.shape[1]
    291         for j in xrange(max(ncx, ncy)):
--> 292             seg = func(x[:, j % ncx], y[:, j % ncy], kw, kwargs)
    293             ret.append(seg)
    294         return ret

/Users/x/miniconda3/envs/default/lib/python3.4/site-packages/matplotlib/axes/_base.py in _makeline(self, x, y, kw, kwargs)
    242                             **kw
    243                             )
--> 244         self.set_lineprops(seg, **kwargs)
    245         return seg
    246 

/Users/x/miniconda3/envs/default/lib/python3.4/site-packages/matplotlib/axes/_base.py in set_lineprops(self, line, **kwargs)
    184                 raise TypeError('There is no line property "%s"' % key)
    185             func = getattr(line, funcName)
--> 186             func(val)
    187 
    188     def set_patchprops(self, fill_poly, **kwargs):

/Users/x/miniconda3/envs/default/lib/python3.4/site-packages/matplotlib/lines.py in set_data(self, *args)
    557         """
    558         if len(args) == 1:
--> 559             x, y = args[0]
    560         else:
    561             x, y = args

ValueError: too many values to unpack (expected 2)

I don't understand the reason for the failure here. Note that the plot function does receive something, because the error bars for the first facet are drawn. I assume I'm not passing the **kwargs dictionary on correctly.

In general, I would find it really helpful if the tutorial for Seaborn contained one or two examples of user-defined plot functions passed to map or map_dataframe.

Asked Apr 30 '15 at 12:04 by tsawallis

1 Answer

This is @mwaskom's answer, and works a treat (see comments):

Just change the my_errorbar function so that it pops the data out of the keyword dict:

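import numpy as np
import matplotlib.pyplot as plt

def my_errorbar(*args, **kwargs):
    # map_dataframe passes the facet's rows as data=...; pop it so it
    # is not forwarded to plt.errorbar along with the plotting keywords
    data = kwargs.pop('data')
    # Row 0: distance below each point, row 1: distance above
    errors = np.vstack([data['error_min'],
                        data['error_max']])

    print(errors)

    plt.errorbar(data[args[0]],
                 data[args[1]],
                 yerr=errors,
                 **kwargs)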
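
Why the original version failed, judging from the traceback: matplotlib's `set_lineprops` turns each leftover keyword into a `set_<name>` call, so the forwarded `data=` keyword became `Line2D.set_data(<DataFrame>)`, whose last frame runs `x, y = args[0]`. The pandas-only sketch below uses a placeholder row with the same six columns as `tips` to show why that unpacking fails: iterating over a DataFrame yields its column labels.

```python
import pandas as pd

# Placeholder row with the same six columns as the question's `tips`
facet = pd.DataFrame(
    [["No", 1, 8.66, 1.415, 2.7636, 2.7636]],
    columns=["smoker", "size", "total_bill", "tip", "error_min", "error_max"],
)

# Iterating over a DataFrame yields its column labels, not rows
print(list(facet))

# So `x, y = facet` sees six items where two were expected
try:
    x, y = facet
    message = ""
except ValueError as exc:
    message = str(exc)
print(message)
```

Popping `data` before the call, as in the answer, keeps the DataFrame out of the keywords that reach matplotlib.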

Answered Sep 19 '22 at 10:09 by tsawallis