How can I preserve the order of an axis in a scatter plot when using categorical values?

I want to create a scatter plot that summarises my data in n-tiles. Since scatter can't take the Interval type as axis values, I convert the values to strings, but this loses the order of the intervals: as the plot below shows, the x-axis is not ordered from low to high. How can I preserve the order?

import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.colors
import numpy as np


n_tile = 5
np.random.seed(0)
x = np.random.normal(150, 70, 3000,)
y = np.random.normal(1, 0.3, 3000)
r = np.random.normal(0.4, 0.1, 3000)

plot_data = pd.DataFrame({
    'x': x,
    'y': y,
    'r': r
})
plot_data['x_group'] = pd.qcut(plot_data['x'], n_tile, duplicates='drop')
plot_data['y_group'] = pd.qcut(plot_data['y'], n_tile, duplicates='drop')
plot_data_grouped = plot_data.groupby(['x_group','y_group'], as_index=False).agg({'r':['mean','count']})
plot_data_grouped.columns = ['x','y','mean','count']

cmap = plt.cm.rainbow
norm = matplotlib.colors.Normalize(vmin=0, vmax=1)

plt.figure(figsize=(10,10))
plt.scatter(x=[str(x) for x in plot_data_grouped['x']], 
            y=[str(x) for x in plot_data_grouped['y']], 
            s=plot_data_grouped["count"], 
            c=plot_data_grouped['mean'], cmap="RdYlGn", edgecolors="black")
plt.show()

[Figure: the resulting scatter plot; the x-axis interval labels are not ordered from low to high]

asked Mar 03 '23 14:03 by wilsonm2



1 Answer

Sometimes it is better to upgrade your development packages first. If your virtual environment has its own local copy of matplotlib, activate the environment (by sourcing its activate script) and then upgrade matplotlib inside it.


To do this, open a terminal or command prompt with administrative privileges and upgrade pip and matplotlib with the following commands:

  • python -m pip install --upgrade pip
  • python -m pip install --upgrade matplotlib
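Before and after upgrading, a quick way to check whether the installed matplotlib places string categories in the order they first appear (older releases such as 2.1.x may sort them instead) is to plot three deliberately unsorted labels and ask the axis where each one landed. This is a minimal sketch, assuming the non-interactive Agg backend; categories_keep_order is a hypothetical helper name, and Axis.convert_units is used to map each label to its numeric axis position.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend, no window is opened
import matplotlib.pyplot as plt


def categories_keep_order():
    """Hypothetical helper: True if string x-values keep their order of first appearance."""
    labels = ["b", "c", "a"]  # deliberately not alphabetical
    fig, ax = plt.subplots()
    ax.scatter(labels, [0, 1, 2])
    # convert_units maps each category string to the numeric position matplotlib assigned it
    positions = [float(p) for p in ax.xaxis.convert_units(labels)]
    plt.close(fig)
    return positions == [0.0, 1.0, 2.0]


print("matplotlib", matplotlib.__version__)
print("keeps order of appearance:", categories_keep_order())
```

If this prints False, the string axis is being sorted, which matches the unordered x-axis in the question.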

Independently of upgrading, matplotlib lets you get or set the current tick locations and labels of either axis (the x-axis or the y-axis).
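For example, called with no arguments, plt.xticks() returns the current tick locations and labels; called with a list of positions and a list of labels, it sets them (plt.yticks() works the same way for the y-axis). A minimal sketch with made-up data; the canvas is drawn before the labels are read back so their text is filled in:

```python
import matplotlib
matplotlib.use("Agg")  # render off-screen
import matplotlib.pyplot as plt

plt.figure()
plt.plot([0, 1, 2], [10, 20, 15])  # made-up data

# Get: with no arguments, xticks() returns (locations, label Text objects)
locs, labels = plt.xticks()
print(list(locs))

# Set: positional (locations, labels) puts a custom label at each position
plt.xticks([0, 1, 2], ["low", "mid", "high"])
plt.yticks([10, 15, 20], ["10 units", "15 units", "20 units"])

plt.gcf().canvas.draw()  # make sure tick label text is populated
new_locs, new_labels = plt.xticks()
print(list(new_locs), [t.get_text() for t in new_labels])
```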


Here is a very simple example that takes your interval labels and places them in order along both axes. To preserve the order along the axes, you can use:

  • matplotlib.pyplot.xticks
  • matplotlib.pyplot.yticks

This technique solves the problem with or without upgrading matplotlib, including with the matplotlib==2.1.1 version you specified.


import matplotlib.pyplot as plt

x_axis_values = ['(-68.18100000000001, 89.754]', '(89.754, 130.42]', '(130.42, 165.601]', '(165.601, 205.456]',
                 '(205.456, 371.968]']

y_axis_values = ['(-0.123, 0.749]', '(0.749, 0.922]', '(0.922, 1.068]', '(1.068, 1.253]', '(1.253, 2.14]']

# Sort the values (or arrange them in whatever order you want along the axes)
# before passing them to xticks/yticks; the positional form works in old and new matplotlib
plt.xticks(range(len(x_axis_values)), x_axis_values)
plt.yticks(range(len(y_axis_values)), y_axis_values)

# plt.scatter(x_axis_values, y_axis_values)
plt.xlabel('Values')
plt.ylabel('Indices')

plt.show()
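If you only have the string labels, note that Python's sorted() compares them character by character, so '(89.754, 130.42]' ends up after '(205.456, 371.968]'. A minimal sketch that sorts such labels by their numeric left endpoint instead; it assumes every label has the '(left, right]' form that str() produces for a pandas Interval, and interval_sort_key is a hypothetical helper name:

```python
def interval_sort_key(label):
    """Hypothetical helper: numeric left endpoint of a label like '(89.754, 130.42]'."""
    left = label.strip("()[]").split(",")[0]
    return float(left)


# The x labels from the example above, deliberately shuffled
x_axis_values = ['(89.754, 130.42]', '(-68.18100000000001, 89.754]', '(205.456, 371.968]',
                 '(130.42, 165.601]', '(165.601, 205.456]']

print(sorted(x_axis_values))                         # character-by-character order
print(sorted(x_axis_values, key=interval_sort_key))  # numeric order, low to high
```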

Here is the output of this simple example. The interval labels appear in order along both the x-axis and the y-axis; the figure only illustrates the tick labels, so no data points are plotted:

[Figure: empty axes showing the x and y interval labels in ascending order]


Applied to your code, the updated version is:

import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.colors
import numpy as np

n_tile = 5
np.random.seed(0)
x = np.random.normal(150, 70, 3000, )
y = np.random.normal(1, 0.3, 3000)
r = np.random.normal(0.4, 0.1, 3000)

plot_data = pd.DataFrame({
    'x': x,
    'y': y,
    'r': r
})
plot_data['x_group'] = pd.qcut(plot_data['x'], n_tile, duplicates='drop')
plot_data['y_group'] = pd.qcut(plot_data['y'], n_tile, duplicates='drop')
plot_data_grouped = plot_data.groupby(['x_group', 'y_group'], as_index=False).agg({'r': ['mean', 'count']})
plot_data_grouped.columns = ['x', 'y', 'mean', 'count']

cmap = plt.cm.rainbow
norm = matplotlib.colors.Normalize(vmin=0, vmax=1)

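########################################################
##########  Updated Portion of the Code ################

# qcut returns ordered categories whose intervals are already sorted low -> high
x_categories = plot_data['x_group'].cat.categories
y_categories = plot_data['y_group'].cat.categories

# Position of each group's interval within its sorted categories (0 = lowest n-tile)
x_pos = {interval: i for i, interval in enumerate(x_categories)}
y_pos = {interval: i for i, interval in enumerate(y_categories)}
x_axis_values = [x_pos[interval] for interval in plot_data_grouped['x']]
y_axis_values = [y_pos[interval] for interval in plot_data_grouped['y']]

plt.figure(figsize=(10, 10))
plt.scatter(x=x_axis_values,
            y=y_axis_values,
            s=plot_data_grouped["count"],
            c=plot_data_grouped['mean'], cmap="RdYlGn", edgecolors="black")

# Label the numeric positions with the interval strings in ascending order
# (sorting the strings themselves would put '(89.754, 130.42]' last)
plt.xticks(range(len(x_categories)), [str(c) for c in x_categories])
plt.yticks(range(len(y_categories)), [str(c) for c in y_categories])

plt.show()
########################################################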

The output is now ordered as required:

[Figure: the updated scatter plot with the interval labels ordered from low to high on both axes]
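If you upgrade and your matplotlib places string categories in the order they first appear, another option is to keep the string axis and make sure the rows are sorted by interval before converting them to strings. The key point is to sort the ordered categorical column that qcut returns, not the strings. A minimal sketch on a small hypothetical series (not the question's data):

```python
import pandas as pd

# Hypothetical stand-in values, split into 5 quantile bins like the question's data
values = pd.Series([300.0, 10.0, 120.0, 95.0, 180.0, 50.0, 250.0, 140.0, 60.0, 200.0])
groups = pd.qcut(values, 5)

frame = pd.DataFrame({'x_group': groups})
# Shuffle the rows so the intervals no longer appear in low-to-high order
frame = frame.sample(frac=1, random_state=0)

# Sorting a categorical column sorts by category (interval) order, not by text
ordered = frame.sort_values('x_group')
x_labels = [str(interval) for interval in ordered['x_group']]

# Order of first appearance, which is how newer matplotlib assigns string positions
first_seen = list(dict.fromkeys(x_labels))
print(first_seen)
```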

answered Mar 05 '23 17:03 by Muhammad Usman Bashir
