Change marker style by a dataframe column (categorical) in seaborn stripplot

I want to map a categorical variable to marker style in a seaborn stripplot, but there does not seem to be an easy way to do it. Is there one? For example, I'm trying to run this code:

import seaborn as sns
tips = sns.load_dataset("tips")
sns.stripplot(x="day", y="total_bill", hue="time", style="sex", jitter=True, data=tips)

which fails, since stripplot does not accept a style argument. An alternative is relplot, which does support style but has no way to add jitter, which makes the visualisation harder to read.

sns.relplot(x="day", y="total_bill", hue="time", data=tips, style="sex")

works, but draws the points without jitter.

Is there any way of doing this using stripplot/catplot/swarmplot?

EDIT: This question is related. However, the solution there does not seem to allow generating a legend for size (and is quite dated).

asked Sep 13 '25 22:09 by Devil


2 Answers

sns.relplot is a figure-level function that relies on the axes-level function sns.scatterplot. sns.scatterplot has a parameter x_jitter, which unfortunately has no effect as of seaborn 0.11.2.

You can mimic the functionality by grabbing the positions of the points, adding some random jitter to the x-coordinates, and assigning the new positions back.

Here is an example:

from matplotlib import pyplot as plt
import seaborn as sns
import numpy as np

tips = sns.load_dataset("tips")
ax = sns.scatterplot(x="day", y="total_bill", hue="time", data=tips, style="sex")
for points in ax.collections:
    vertices = points.get_offsets().data
    if len(vertices) > 0:
        vertices[:, 0] += np.random.uniform(-0.3, 0.3, vertices.shape[0])
        points.set_offsets(vertices)
xticks = ax.get_xticks()
ax.set_xlim(xticks[0] - 0.5, xticks[-1] + 0.5) # the limits need to be moved to show all the jittered dots
sns.move_legend(ax, bbox_to_anchor=(1.01, 1.02), loc='upper left')  # needs seaborn >= 0.11.2
sns.despine()
plt.tight_layout()
plt.show()

[Figure: sns.scatterplot with jitter]

With sns.relplot you could iterate through all the subplots:

g = sns.relplot(x="day", y="total_bill", hue="time", data=tips, style="sex")
for ax in g.axes.flat:
    for points in ax.collections:
        vertices = points.get_offsets().data
        if len(vertices) > 0:
            vertices[:, 0] += np.random.uniform(-0.3, 0.3, vertices.shape[0])
            points.set_offsets(vertices)
    xticks = ax.get_xticks()
    ax.set_xlim(xticks[0] - 0.5, xticks[-1] + 0.5) # the limits need to be moved to show all the jittered dots
plt.show()
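
The jitter loop used in both snippets above can be factored into a small helper. This is only a sketch under a few assumptions: the names jitter_offsets and jitter_axes are hypothetical (they are not part of seaborn or matplotlib), the noise is uniform as in the code above, and a seed parameter is added so the jitter is reproducible between runs. Unlike the snippets above, it works on a copy of the offsets rather than modifying the array in place.

```python
import numpy as np


def jitter_offsets(offsets, width=0.3, rng=None):
    """Return a copy of an (n, 2) offsets array with uniform noise added to x.

    Hypothetical helper, not a seaborn or matplotlib function.
    """
    rng = np.random.default_rng() if rng is None else rng
    # get_offsets() may return a masked array; getdata() gives the plain values,
    # and np.array(...) makes a copy so the input is left untouched.
    jittered = np.array(np.ma.getdata(offsets), dtype=float)
    if jittered.size:
        jittered[:, 0] += rng.uniform(-width, width, jittered.shape[0])
    return jittered


def jitter_axes(ax, width=0.3, seed=None):
    """Jitter the x positions of every point collection on a matplotlib Axes."""
    rng = np.random.default_rng(seed)
    for points in ax.collections:
        points.set_offsets(jitter_offsets(points.get_offsets(), width, rng))
```

With this in place, the loops above would reduce to jitter_axes(ax, 0.3, seed=0) for the scatterplot, or one such call per axes in g.axes.flat for the relplot. The x-limits still need to be widened afterwards, as in the original code.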

answered Sep 16 '25 12:09 by JohanC


Using seaborn's objects interface (available since version 0.12), you can now do this as follows:

import seaborn as sns
import seaborn.objects as so

tips = sns.load_dataset("tips")
(
    so.Plot(tips, x="day", y="total_bill", color="time", marker="sex")
    .add(so.Dot(pointsize=7, edgecolor="white"), so.Jitter(.3))
).theme(
    {**sns.axes_style("ticks"), 'axes.spines.right': False, 'axes.spines.top': False}
)  # in a script, append .show() to display the figure

[Figure: strip plot produced by the code snippet]

answered Sep 16 '25 10:09 by astoeriko