
3D Scatterplot with strings in Python

I tried to make a 3D scatter plot in Python with string categories (activation functions and solvers for a neural network) on the x and y axes and floating-point numbers (the network's accuracy score) on the z axis.

The following example raises the error: ValueError: could not convert string to float: 'str1'

I followed this documentation for 3D plots: https://matplotlib.org/mpl_toolkits/mplot3d/tutorial.html

Any ideas what the problem might be? Many thanks in advance!

import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
xs=['str1', 'str2']
print(type(xs))
ys=['str3', 'str4']
print(type(ys))
zs=[1,2]
ax.scatter(xs, ys, zs)
S.Maria asked Jan 09 '19 15:01




1 Answer

You are passing categorical values (strings) as the x and y arguments. That can work for an ordinary 2D scatter plot, but in 3D you need to give numeric Cartesian coordinates on every axis. What you actually want is for the strings to appear as the x- and y-axis tick labels. To get that plot, first plot numeric positions and then assign your strings as the tick labels.

import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')

xs=['str1', 'str2']
ys=['str3', 'str4']
zs=[1,2]

# Plot integer positions, then label the ticks with the original strings
ax.scatter(range(len(xs)), range(len(ys)), zs)
ax.set(xticks=range(len(xs)), xticklabels=xs,
       yticks=range(len(ys)), yticklabels=ys)

You can also set the tick labels using

plt.xticks(range(len(xs)), xs)
plt.yticks(range(len(ys)), ys)

The first option, using ax.set, lets you do the same in a single call.
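If the same category shows up more than once (several solvers per activation function, as in the question), range(len(xs)) would give every point its own tick. One way around that, sketched below, is to map each distinct string to an integer position first. The category_codes helper, the activation and solver names, and the accuracy values are all made up for illustration.

```python
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D  # noqa: F401  (only needed on older matplotlib)


def category_codes(labels):
    """Hypothetical helper: return (sorted unique labels, integer position of each label)."""
    unique = sorted(set(labels))
    index = {label: i for i, label in enumerate(unique)}
    return unique, [index[label] for label in labels]


# Made-up data: one accuracy score per (activation, solver) pair
activations = ["relu", "relu", "tanh", "tanh"]
solvers = ["adam", "sgd", "adam", "sgd"]
accuracy = [0.91, 0.85, 0.88, 0.80]

x_labels, x_codes = category_codes(activations)
y_labels, y_codes = category_codes(solvers)

fig = plt.figure()
ax = fig.add_subplot(111, projection="3d")

# Scatter the integer codes, then put the strings back as tick labels
ax.scatter(x_codes, y_codes, accuracy)
ax.set(xticks=range(len(x_labels)), xticklabels=x_labels,
       yticks=range(len(y_labels)), yticklabels=y_labels)
```

Call plt.show() afterwards to display the figure. Sorting the unique labels is just one choice; any fixed ordering works as long as the tick positions and the codes come from the same mapping.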

Sheldore answered Oct 20 '22 07:10