 

Discrete Color Bar with Tick labels in between colors

I am trying to plot some data with a discrete color bar. I was following this example (https://gist.github.com/jakevdp/91077b0cae40f8f8244a), but it does not carry over one-to-one when the spacing is different: the example in the link assumes values that increase by 1, whereas my data increases by 0.5. The output of the code below, run on some random data, shows the problem. Any help with this would be appreciated. I know I am missing something key here but can't figure it out.

import matplotlib.pylab as plt
import numpy as np

def discrete_cmap(N, base_cmap=None):
    """Create an N-bin discrete colormap from the specified input map"""

    # Note that if base_cmap is a string or None, you can simply do
    #    return plt.get_cmap(base_cmap, N)
    # The following works for string, None, or a colormap instance
    # (plt.get_cmap is used because plt.cm.get_cmap is deprecated and
    # no longer available in recent Matplotlib releases):

    base = plt.get_cmap(base_cmap)
    color_list = base(np.linspace(0, 1, N))
    cmap_name = base.name + str(N)
    return base.from_list(cmap_name, color_list, N)

num=11

x = np.random.randn(40)
y = np.random.randn(40)
c = np.random.randint(num, size=40)


plt.figure(figsize=(10,7.5))

plt.scatter(x, y, c=c, s=50, cmap=discrete_cmap(num, 'jet'))
plt.colorbar(ticks=np.arange(0,5.5,0.5))
plt.clim(-0.5, num - 0.5)
plt.show()
Charanjit Pabla asked Feb 05 '23 00:02


2 Answers

I'm not sure which version of Matplotlib introduced this, but plt.get_cmap now accepts an integer argument specifying how many colors you want, which gives you a discrete colormap. The colorbar then comes out discrete automatically. As an aside, pandas handles the colorbar even better.

import numpy as np
from matplotlib import pyplot as plt
plt.style.use('ggplot')
# Jupyter/IPython only: uncomment the next line when running in a notebook
# %matplotlib inline

# choose number of clusters and number of points in each cluster
n_clusters = 5
n_samples = 20

# there are fancier ways to do this
clusters = np.array([k for k in range(n_clusters) for i in range(n_samples)])

# generate the coordinates of the center 
# of each cluster by shuffling a range of values
clusters_x = np.arange(n_clusters)
clusters_y = np.arange(n_clusters)

np.random.shuffle(clusters_x)
np.random.shuffle(clusters_y)

# get dicts like cluster -> center coordinate
x_dict = dict(enumerate(clusters_x))
y_dict = dict(enumerate(clusters_y))

# get coordinates of cluster center for each point
x = np.array(list(x_dict[k] for k in clusters)).astype(float)
y = np.array(list(y_dict[k] for k in clusters)).astype(float)

# add noise
x += np.random.normal(scale=0.5, size=n_clusters*n_samples)
y += np.random.normal(scale=0.5, size=n_clusters*n_samples)

### Finally, plot
fig, ax = plt.subplots(figsize=(12,8))

# get discrete colormap
cmap = plt.get_cmap('viridis', n_clusters)

# scatter points
scatter = ax.scatter(x, y, c=clusters, cmap=cmap)

# scatter cluster centers
ax.scatter(clusters_x, clusters_y, c='red')

# add colorbar
cbar = plt.colorbar(scatter)

# set ticks locations (not very elegant, but it works):
# - shift by 0.5
# - scale so that the last value is at the center of the last color
tick_locs = (np.arange(n_clusters) + 0.5)*(n_clusters-1)/n_clusters
cbar.set_ticks(tick_locs)

# set tick labels (as before)
cbar.set_ticklabels(np.arange(n_clusters))

Sample output plot
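
The tick-location formula in the code above can be written a bit more generally. The following is only a sketch: `bin_centers` is a hypothetical helper name, not part of Matplotlib, and it assumes the color limits (`vmin`, `vmax`) are split into `n` equally wide color bins, which is what a colormap with `n` colors does under the default linear normalization.

```python
import numpy as np

def bin_centers(vmin, vmax, n):
    """Return the centers of n equal-width color bins spanning [vmin, vmax].

    Hypothetical helper for illustration; not a Matplotlib function.
    """
    edges = np.linspace(vmin, vmax, n + 1)   # n + 1 bin edges
    return (edges[:-1] + edges[1:]) / 2      # midpoint of each bin

# With color limits 0 .. n_clusters - 1 this reproduces the formula above.
n_clusters = 5
tick_locs = bin_centers(0, n_clusters - 1, n_clusters)
```

The result could be passed to `cbar.set_ticks(tick_locs)` in the same way as before.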

Giorgio Balestrieri answered Feb 07 '23 15:02


OK, so this is the hack I found for my own question. I am sure there is a cleaner approach, but this works for what I am doing. Feel free to suggest a better way.

import numpy as np
import matplotlib.pylab as plt

def discrete_cmap(N, base_cmap=None):
    """Create an N-bin discrete colormap from the specified input map"""

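    # Note that if base_cmap is a string or None, you can simply do
    #    return plt.get_cmap(base_cmap, N)
    # The following works for string, None, or a colormap instance
    # (plt.get_cmap is used because plt.cm.get_cmap is deprecated and
    # no longer available in recent Matplotlib releases):

    base = plt.get_cmap(base_cmap)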
    color_list = base(np.linspace(0, 1, N))
    cmap_name = base.name + str(N)
    return base.from_list(cmap_name, color_list, N)

num=11

plt.figure(figsize=(10,7.5))

x = np.random.randn(40)
y = np.random.randn(40)
c = np.random.randint(num, size=40)


plt.scatter(x, y, c=c, s=50, cmap=discrete_cmap(num, 'jet'))
cbar=plt.colorbar(ticks=range(num))
plt.clim(-0.5, num - 0.5)
cbar.ax.set_yticklabels(np.arange(0.0,5.5,0.5))
plt.show()

For some reason I cannot upload the image associated with the code above (I get an error when uploading), so I am not sure how to show the final result. In short, I kept the ticks at the integer bin centers, then passed the labels I want to the color bar's axes with cbar.ax.set_yticklabels (the y axis, because the color bar is vertical), and that produced the correct output.
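
One possible alternative to relabeling, sketched under an assumption the thread does not confirm: the color values are themselves the 0.5-spaced data values (0, 0.5, ..., 5) rather than integer indices. In that case the color limits can be padded by half a step on each side so that every value sits at the center of its own color, and the ticks can be placed at the values directly. The names `step` and `levels` are illustrative only.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np

num = 11
step = 0.5
levels = np.arange(num) * step            # 0.0, 0.5, ..., 5.0

rng = np.random.default_rng(0)
x = rng.standard_normal(40)
y = rng.standard_normal(40)
c = rng.choice(levels, size=40)           # assumed: colors are the 0.5-step values

cmap = plt.get_cmap("jet", num)           # discrete colormap with num colors

fig, ax = plt.subplots(figsize=(10, 7.5))
# Pad the limits by half a step so each level is centered in its color bin.
sc = ax.scatter(x, y, c=c, s=50, cmap=cmap,
                vmin=levels[0] - step / 2, vmax=levels[-1] + step / 2)
cbar = fig.colorbar(sc, ticks=levels)     # ticks at the data values themselves
fig.savefig("discrete_colorbar.png")
```

With this layout the tick labels come from the tick positions, so no separate call to set_yticklabels is needed.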

Charanjit Pabla answered Feb 07 '23 15:02