
Automated legend creation for 3D plot

I'm trying to update the function below so that it reports cluster info via a legend:

color_names = ["red", "blue", "yellow", "black", "pink", "purple", "orange"]

def plot_3d_transformed_data(df, title, colors="red"):
 
  ax = plt.figure(figsize=(12,10)).gca(projection='3d')
  #fig = plt.figure(figsize=(8, 8))
  #ax = fig.add_subplot(111, projection='3d')
  

  if type(colors) is np.ndarray:
    for cname, class_label in zip(color_names, np.unique(colors)):
      X_color = df[colors == class_label]
      ax.scatter(X_color[:, 0], X_color[:, 1], X_color[:, 2], marker="x", c=cname, label=f"Cluster {class_label}" if type(colors) is np.ndarray else None)
  else:
      ax.scatter(df.Type, df.Length, df.Freq, alpha=0.6, c=colors, marker="x", label=str(clusterSizes)  )

  ax.set_xlabel("PC1: Type")
  ax.set_ylabel("PC2: Length")
  ax.set_zlabel("PC3: Frequency")
  ax.set_title(title)
  
  if type(colors) is np.ndarray:
    #ax.legend()
    plt.gca().legend()
    
  
  plt.legend(bbox_to_anchor=(1.04,1), loc="upper left")
  plt.show()

I then call the function to visualize the cluster patterns:

plot_3d_transformed_data(pdf_km_pred,
                         f'Clustering rare URL parameters for data of date: {DATE_FROM}  \nMethod: KMeans over PCA \nn_clusters={n_clusters} , Distance_Measure={DistanceMeasure}',
                         colors=pdf_km_pred.prediction_km)

print(clusterSizes)

Sadly, the legend does not show up, so I have to print the cluster members manually under the 3D plot. The plot is drawn without a legend, and Matplotlib reports: No handles with labels found to put in legend.

I checked this post, but I couldn't figure out what mistake in the function keeps the cluster label list from being passed properly. I want to update the function so that it shows the cluster labels via clusterSizes.index and their scale via clusterSizes.size

Expected output: As suggested here, it would be better to use legend_elements(), which determines a useful number of legend entries to show and returns a tuple of handles and labels automatically.

Update: As I mentioned, the expected output should contain one legend for the cluster labels and another legend for the cluster sizes (the number of instances in each cluster). Reporting this info in a single legend would also be fine. Please see the 2D example below: (image)

Mario asked Nov 07 '22 00:11


1 Answer

In the function that visualizes the clusters, you need to call ax.legend instead of plt.legend:

from matplotlib import pyplot as plt
from mpl_toolkits.mplot3d.axes3d import Axes3D
import numpy as np
import pandas as pd

color_names = ["red", "blue", "yellow", "black", "pink", "purple", "orange"]

def plot_3d_transformed_data(df, title, colors="red"):
 
  # Figure.gca() no longer accepts keyword arguments in recent Matplotlib releases,
  # so create the 3D axes with add_subplot instead.
  fig = plt.figure(figsize=(12, 10))
  ax = fig.add_subplot(111, projection='3d')
  

  if type(colors) is np.ndarray:
    for cname, class_label in zip(color_names, np.unique(colors)):
      X_color = df[colors == class_label]
      ax.scatter(X_color[:, 0], X_color[:, 1], X_color[:, 2], marker="x", c=cname, label=f"Cluster {class_label}" if type(colors) is np.ndarray else None)
  else:
      ax.scatter(df.Type, df.Length, df.Freq, alpha=0.6, c=colors, marker="x", label=str(clusterSizes)  )

  ax.set_xlabel("PC1: Type")
  ax.set_ylabel("PC2: Length")
  ax.set_zlabel("PC3: Frequency")
  ax.set_title(title)
  
  if type(colors) is np.ndarray:
    #ax.legend()
    plt.gca().legend()
    
  
  ax.legend(bbox_to_anchor=(.9,1), loc="upper left")
  plt.show()

clusterSizes = 10

test_df = pd.DataFrame({'Type':np.random.randint(0,5,10),
                        'Length':np.random.randint(0,20,10),
                        'Freq':np.random.randint(0,10,10),
                        'Colors':np.random.choice(color_names,10)})

plot_3d_transformed_data(test_df,
                         'Clustering rare URL parameters for data of date:haha\nMethod: KMeans over PCA \nn_clusters={n_clusters} , Distance_Measure={DistanceMeasure}',
                         colors=test_df.Colors)

Running this example code produces the legend handle as expected.
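For the update in the question (cluster labels plus cluster sizes in the legend), one possible approach is to scatter each cluster separately and put the size into each label. The following is only a minimal sketch under some assumptions: the function name plot_clusters_with_sizes and the demo data are made up for illustration, and the column names Type, Length and Freq are taken from the question. It converts the labels with np.asarray first, because a pandas Series fails the `type(colors) is np.ndarray` check and would otherwise never reach the per-cluster branch.

```python
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd


def plot_clusters_with_sizes(df, labels, title):
    """Hypothetical helper: one scatter call (and one legend entry) per cluster."""
    labels = np.asarray(labels)  # accept a pandas Series or a NumPy array
    # Number of rows per cluster, ordered by cluster label
    cluster_sizes = pd.Series(labels).value_counts().sort_index()

    fig = plt.figure(figsize=(12, 10))
    ax = fig.add_subplot(111, projection="3d")

    for cluster, size in cluster_sizes.items():
        part = df[labels == cluster]  # rows that belong to this cluster
        ax.scatter(part["Type"], part["Length"], part["Freq"], marker="x",
                   label=f"Cluster {cluster} (n={size})")

    ax.set_xlabel("PC1: Type")
    ax.set_ylabel("PC2: Length")
    ax.set_zlabel("PC3: Frequency")
    ax.set_title(title)

    # Call legend on the Axes that owns the labelled scatter artists
    legend = ax.legend(bbox_to_anchor=(1.04, 1), loc="upper left", title="Clusters")
    return fig, ax, legend


# Made-up demo data: 10 points in 3 clusters of sizes 5, 3 and 2
rng = np.random.default_rng(0)
demo_df = pd.DataFrame({"Type": rng.integers(0, 5, 10),
                        "Length": rng.integers(0, 20, 10),
                        "Freq": rng.integers(0, 10, 10)})
demo_labels = pd.Series(np.repeat([0, 1, 2], [5, 3, 2]))

fig, ax, legend = plot_clusters_with_sizes(demo_df, demo_labels, "Demo clusters")
legend_texts = [text.get_text() for text in legend.get_texts()]
print(legend_texts)
```

Call plt.show() afterwards to display the figure. Because each cluster is drawn by its own scatter call, the colors come from Matplotlib's default color cycle and every cluster gets exactly one legend entry.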

meTchaikovsky answered Nov 11 '22 16:11