
Add seaborn.palplot axes to existing figure for visualisation of different color palettes

Adding seaborn plots to subplots is usually done by passing 'ax' to the plotting function. For instance:

sns.kdeplot(x, y, cmap=cmap, shade=True, cut=5, ax=ax)

This method, however, doesn't apply to seaborn.palplot, which visualizes seaborn color palettes. My goal is to create a figure of different color palettes for scalable color comparison and presentation. This image roughly shows the figure I'm trying to create [source].

A possibly related answer describes a method of creating a seaborn figure and copying the axes to another figure. I haven't been able to apply this method to the palplot figures, and would like to know if there is a quick way to force them into the existing figure.

Here's my minimal working example, which currently still generates separate figures.

import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt

fig1 = plt.figure()
length, n_colors = 12, 50  # amount of subplots and colors per subplot
start_colors = np.linspace(0, 3, length)
for i, start_color in enumerate(start_colors):
    ax = fig1.add_subplot(length, 1, i + 1)
    colors = sns.cubehelix_palette(n_colors=n_colors, start=start_color,
                                   rot=0, light=0.4, dark=0.8)
    sns.palplot(colors)
plt.show()  # plt.show() does not take a figure argument

Ultimately, to make the plot more informative, it would be great to print the RGB values stored in colors (list-like) evenly spaced over the palplots, but I don't know if this is easily implemented due to the unusual way of plotting in palplot.

Any help would be greatly appreciated!

asked Nov 07 '22 06:11 by ddelange


1 Answer

As you've probably already found, there's little documentation for the palplot function, so here is its source, copied directly from the seaborn GitHub repo:

def palplot(pal, size=1):
    """Plot the values in a color palette as a horizontal array.
    Parameters
    ----------
    pal : sequence of matplotlib colors
        colors, i.e. as returned by seaborn.color_palette()
    size :
        scaling factor for size of plot
    """
    n = len(pal)
    f, ax = plt.subplots(1, 1, figsize=(n * size, size))
    ax.imshow(np.arange(n).reshape(1, n),
              cmap=mpl.colors.ListedColormap(list(pal)),
              interpolation="nearest", aspect="auto")
    ax.set_xticks(np.arange(n) - .5)
    ax.set_yticks([-.5, .5])
    # Ensure nice border between colors
    ax.set_xticklabels(["" for _ in range(n)])
    # The proper way to set no ticks
    ax.yaxis.set_major_locator(ticker.NullLocator())

So it neither returns an axes or figure object nor lets you specify an axes object to draw into. You could write your own version, as follows, by adding an 'ax' argument and creating a new figure only when it isn't provided. Depending on context, you may need the included imports as well.

def my_palplot(pal, size=1, ax=None):
    """Plot the values in a color palette as a horizontal array.
    Parameters
    ----------
    pal : sequence of matplotlib colors
        colors, i.e. as returned by seaborn.color_palette()
    size : float
        scaling factor for size of plot (only used when ax is None)
    ax : matplotlib Axes, optional
        an existing axes to use; a new figure is created if omitted
    """

    import numpy as np
    import matplotlib as mpl
    import matplotlib.pyplot as plt
    import matplotlib.ticker as ticker

    n = len(pal)
    if ax is None:
        f, ax = plt.subplots(1, 1, figsize=(n * size, size))
    ax.imshow(np.arange(n).reshape(1, n),
              cmap=mpl.colors.ListedColormap(list(pal)),
              interpolation="nearest", aspect="auto")
    ax.set_xticks(np.arange(n) - .5)
    ax.set_yticks([-.5, .5])
    # Ensure nice border between colors
    ax.set_xticklabels(["" for _ in range(n)])
    # The proper way to set no ticks
    ax.yaxis.set_major_locator(ticker.NullLocator())

This function should work as you expect when you pass the 'ax' argument. To implement this in your example:

import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt

fig1 = plt.figure()
length, n_colors = 12, 50  # amount of subplots and colors per subplot
start_colors = np.linspace(0, 3, length)
for i, start_color in enumerate(start_colors):
    ax = fig1.add_subplot(length, 1, i + 1)
    colors = sns.cubehelix_palette(
        n_colors=n_colors, start=start_color, rot=0, light=0.4, dark=0.8
    )
    my_palplot(colors, ax=ax)
plt.show()  # plt.show() does not take a figure argument

example results
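
As for printing the RGB values over the palettes: because the palette is drawn with imshow as a single row of n cells, color i is centred at x = i, y = 0 in data coordinates, so ax.text can place labels there. Below is a minimal sketch of that idea, not a seaborn feature. The helper name annotated_palplot and its n_labels and fmt parameters are made up for illustration, and the palettes are sampled from matplotlib colormaps so that the snippet runs without seaborn; any list of RGB tuples, such as the output of sns.cubehelix_palette, should work the same way.

```python
import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt


def annotated_palplot(pal, ax, n_labels=5, fmt="{:.2f}"):
    """Draw a palette into `ax` and label evenly spaced colors with their RGB values.

    Hypothetical helper for illustration; it is not part of seaborn.
    """
    pal = [mpl.colors.to_rgb(c) for c in pal]  # drop alpha, normalise to tuples
    n = len(pal)

    # Same trick as palplot: one row of n cells, cell i is centred at x = i.
    ax.imshow(np.arange(n).reshape(1, n),
              cmap=mpl.colors.ListedColormap(pal),
              interpolation="nearest", aspect="auto")
    ax.set_xticks([])
    ax.set_yticks([])

    # Split the palette into k equal chunks and label the color in the
    # middle of each chunk, which keeps the labels away from the edges.
    k = min(n_labels, n)
    idx = ((np.arange(k) + 0.5) * n / k).astype(int)

    texts = []
    for i in idx:
        r, g, b = pal[i]
        # Rough brightness estimate to choose a readable text color.
        brightness = 0.2126 * r + 0.7152 * g + 0.0722 * b
        label = ", ".join(fmt.format(v) for v in (r, g, b))
        texts.append(ax.text(i, 0, label, ha="center", va="center",
                             fontsize=6,
                             color="black" if brightness > 0.5 else "white"))
    return texts


n_colors = 50
cmap_names = ["viridis", "magma", "cividis"]  # stand-ins for seaborn palettes
fig, axes = plt.subplots(len(cmap_names), 1, figsize=(10, 3))
for ax, name in zip(axes, cmap_names):
    colors = plt.get_cmap(name)(np.linspace(0, 1, n_colors))  # RGBA rows
    labels = annotated_palplot(colors, ax, n_labels=5)
```

Call plt.show() or fig.savefig(...) afterwards to view the result. The number of labels and the font size will likely need tuning to the figure width, since long labels overlap when the cells are narrow.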

answered Nov 11 '22 07:11 by mightypile