Logo Questions Linux Laravel Mysql Ubuntu Git Menu
 

Python Matplotlib Multi-color Legend Entry

I would like to make a legend entry in a matplotlib look something like this:

enter image description here

It has multiple colors for a single legend item. The code below outputs a plain red rectangle. What do I need to do to overlay one color on top of another? Or is there a better solution?

import matplotlib.patches as mpatches
import matplotlib.pyplot as plt

# One solid-red swatch stands in for the "Foo" legend entry.
red_patch = mpatches.Patch(label='Foo', color='red')
plt.legend(handles=[red_patch])

plt.show()
like image 467
Matt Stokes Avatar asked Aug 09 '15 21:08

Matt Stokes


2 Answers

Here is another hack that handles more than two patches. Make sure you order the handles and labels according to the number of columns, because the legend fills each column from top to bottom before moving on to the next one:

from matplotlib.patches import Patch
import matplotlib.pyplot as plt

fig, ax = plt.subplots()

# Swatches for the first legend entry (pa*) and for the second one (pb*).
pa1, pa2, pa3 = (Patch(facecolor=color, edgecolor='black')
                 for color in ('red', 'blue', 'green'))
pb1, pb2, pb3 = (Patch(facecolor=color, edgecolor='black')
                 for color in ('pink', 'orange', 'purple'))

# The legend fills each column top to bottom, so the handles are interleaved:
# every column holds one "a" swatch above one "b" swatch. Only the last column
# carries text, and the negative column spacing pulls the swatches together.
ax.legend(handles=[pa1, pb1, pa2, pb2, pa3, pb3],
          labels=[''] * 4 + ['First', 'Second'],
          ncol=3, handletextpad=0.5, handlelength=1.0, columnspacing=-0.5,
          loc='center', fontsize=16)

plt.show()

which results in:

like image 177
rionbr Avatar answered Sep 23 '22 01:09

rionbr


I absolutely loved @raphael's answer. Here is a version with circles. Furthermore, I've refactored and trimmed the code a bit to make it more modular.

import matplotlib.pyplot as plt
import matplotlib.colors as mcolors

class MulticolorCircles:
    """
    Legend handle drawn as a row of circles, one per face color.

    Instances are not artists: calling one with the handle box's width and
    height returns a ``PatchCollection``, which ``MulticolorHandler`` adds to
    the legend. For different shapes, override the ``get_patch`` method, and
    add the new class to the handler map, e.g. via

    ax_r.legend(ax_r_handles, ax_r_labels, handlelength=CONF.LEGEND_ICON_SIZE,
            borderpad=1.2, labelspacing=1.2,
            handler_map={MulticolorCircles: MulticolorHandler})
    """

    def __init__(self, face_colors, edge_colors=None, face_alpha=1,
                 radius_factor=1):
        """
        :param face_colors: Iterable of matplotlib color specs, one circle
          each. A single color string is treated as a one-element list.
        :param edge_colors: Optional edge colors, one per face color (extra
          entries are ignored), or a single color string shared by every
          circle. ``None`` draws no edges.
        :param face_alpha: Opacity in ``[0, 1]`` applied to every face color.
        :param radius_factor: Positive multiplier applied to each radius.
        :raises ValueError: If ``face_alpha`` is outside ``[0, 1]``,
          ``radius_factor`` is not positive, or fewer edge colors than face
          colors are given.
        """
        # Explicit raises instead of ``assert``: asserts vanish under ``-O``.
        if not 0 <= face_alpha <= 1:
            raise ValueError(f"Invalid face_alpha: {face_alpha}")
        if not radius_factor > 0:
            raise ValueError("radius_factor must be positive")
        self.rad_factor = radius_factor
        # A bare string would otherwise be iterated character by character.
        if isinstance(face_colors, str):
            face_colors = [face_colors]
        self.fc = [mcolors.to_rgba(fc, alpha=face_alpha)
                   for fc in face_colors]
        self.N = len(self.fc)
        if edge_colors is None:
            self.ec = ["none"] * self.N
        elif isinstance(edge_colors, str):
            self.ec = [edge_colors] * self.N
        else:
            # Materialize so a generator is not exhausted after one call.
            self.ec = list(edge_colors)
            # zip() in __call__ would silently drop the unmatched circles
            # while get_patch still lays out N chunks, leaving gaps.
            if len(self.ec) < self.N:
                raise ValueError(
                    f"Expected at least {self.N} edge colors, "
                    f"got {len(self.ec)}")

    def get_patch(self, width, height, idx, fc, ec):
        """
        Return the circle for the ``idx``-th color.

        The box is split into ``self.N`` equal-width chunks. The radius is the
        smaller of half a chunk's width and the full box ``height``, scaled by
        ``radius_factor``. The circle touches the left edge of its chunk and
        the bottom of the box (y = 0). Because the cap uses the full height,
        a circle can extend above the box (NOTE(review): confirm intended).
        """
        w_chunk = width / self.N
        radius = min(w_chunk / 2, height) * self.rad_factor
        xy = (w_chunk * idx + radius, radius)
        patch = plt.Circle(xy, radius, facecolor=fc, edgecolor=ec)
        return patch

    def __call__(self, width, height):
        """
        Build one ``PatchCollection`` holding every circle for a
        ``width`` x ``height`` handle box. ``match_original=True`` keeps each
        patch's own face and edge colors.
        """
        # Local import: the module-level imports only cover pyplot and
        # colors, so PatchCollection is otherwise undefined here.
        from matplotlib.collections import PatchCollection

        patches = [self.get_patch(width, height, i, fc, ec)
                   for i, (fc, ec) in enumerate(zip(self.fc, self.ec))]
        return PatchCollection(patches, match_original=True)


class MulticolorHandler:
    """
    Legend handler for handles that are callables taking a box size.

    Register it via ``handler_map={MulticolorCircles: MulticolorHandler}``.
    """
    @staticmethod
    def legend_artist(legend, orig_handle, fontsize, handlebox):
        """
        Call ``orig_handle`` with the handle box's width and height, add the
        returned artist to ``handlebox`` and return that artist.
        ``legend`` and ``fontsize`` are accepted for the handler protocol
        but not used.
        """
        artist = orig_handle(handlebox.width, handlebox.height)
        handlebox.add_artist(artist)
        return artist

Sample usage and the resulting image are shown below. Note that some of the legend handles use radius_factor=0.5, because the true size would be too small.

# Start from the handles/labels matplotlib already collected for ``ax`` and
# append custom MulticolorCircles entries after them. The UPPER_CASE names
# (labels, colors, alpha, ratio, icon size) are not defined in this snippet;
# presumably they live elsewhere in the author's project.
ax_handles, ax_labels = ax.get_legend_handles_labels()
# Single-color entries, drawn semi-transparent via LEGEND_SHADOW_ALPHA.
ax_labels.append(AUDIOSET_LABEL)
ax_handles.append(MulticolorCircles([AUDIOSET_COLOR],
                                    face_alpha=LEGEND_SHADOW_ALPHA))
ax_labels.append(FRAUNHOFER_LABEL)
ax_handles.append(MulticolorCircles([FRAUNHOFER_COLOR],
                                    face_alpha=LEGEND_SHADOW_ALPHA))
# One circle per color in SHADOW_COLORS[...]; assumes each value is a list
# of colors (TODO confirm against its definition).
ax_labels.append(TRAIN_SOURCE_NORMAL_LABEL)
ax_handles.append(MulticolorCircles(SHADOW_COLORS["source"],
                                    face_alpha=LEGEND_SHADOW_ALPHA))
ax_labels.append(TRAIN_TARGET_NORMAL_LABEL)
ax_handles.append(MulticolorCircles(SHADOW_COLORS["target"],
                                    face_alpha=LEGEND_SHADOW_ALPHA))
# Dot entries keep the default opacity but scale the circles by
# LEGEND_DOT_RATIO.
ax_labels.append(TEST_SOURCE_ANOMALY_LABEL)
ax_handles.append(MulticolorCircles(DOT_COLORS["anomaly_source"],
                                    radius_factor=LEGEND_DOT_RATIO))
ax_labels.append(TEST_TARGET_ANOMALY_LABEL)
ax_handles.append(MulticolorCircles(DOT_COLORS["anomaly_target"],
                                    radius_factor=LEGEND_DOT_RATIO))
# Each label is appended right before its handle, so both lists stay
# aligned. handler_map makes the legend render MulticolorCircles handles
# through MulticolorHandler.
ax.legend(ax_handles, ax_labels, handlelength=LEGEND_ICON_SIZE,
            borderpad=1.1, labelspacing=1.1,
            handler_map={MulticolorCircles: MulticolorHandler})

enter image description here

like image 29
fr_andres Avatar answered Sep 23 '22 01:09

fr_andres