t-SNE map into 2D or 3D plot

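import pandas as pd
from influxdb import InfluxDBClient
from sklearn.preprocessing import StandardScaler
# Assumption: scikit-learn's TSNE; its n_jobs argument needs scikit-learn >= 0.22.
from sklearn.manifold import TSNE

# username, password, FROMT and TOT are defined elsewhere in the script.
features = ["Ask1", "Bid1", "smooth_midprice", "BidSize1", "AskSize1"]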

client = InfluxDBClient(host='127.0.0.1', port=8086, database='data',
                        username=username, password=password)

series = "DCIX_2016_11_15"

sql = "SELECT * FROM {} where time  >= '{}' AND time <= '{}' ".format(series,FROMT,TOT)

df = pd.DataFrame(client.query(sql).get_points())

# Separate out the features
X = df.loc[:, features].values

# Standardizing the features
X = StandardScaler().fit_transform(X)

tsne = TSNE(n_components=3, n_jobs=5).fit_transform(X)

I would like to map my 5 features into a 2D or 3D plot, but I am a bit confused about how to do that. How can I build a plot from this information?

Jeremie asked Jan 03 '23 03:01


1 Answer

You already have most of the work done. t-SNE is a common way to visualize high-dimensional data, and right now the variable tsne is a NumPy array of shape (n_samples, 3): each row holds the (x, y, z) coordinates of one sample in the embedding. You could use other visualizations if you like, but t-SNE is a good starting place.
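If you want to check that shape without the database, here is a minimal sketch. It assumes scikit-learn's TSNE and uses a random DataFrame with your five column names as a hypothetical stand-in for the InfluxDB query; the sample count and random_state=0 are my choices, not part of your setup:

```python
import numpy as np
import pandas as pd
from sklearn.manifold import TSNE
from sklearn.preprocessing import StandardScaler

# Hypothetical stand-in for the InfluxDB query: random values in the
# same five columns the question uses.
features = ["Ask1", "Bid1", "smooth_midprice", "BidSize1", "AskSize1"]
rng = np.random.default_rng(0)
df = pd.DataFrame(rng.normal(size=(200, len(features))), columns=features)

# Same preprocessing as the question: select the columns, then standardize.
X = StandardScaler().fit_transform(df.loc[:, features].values)

# 3-component embedding (n_jobs omitted so older scikit-learn versions work).
tsne = TSNE(n_components=3, random_state=0).fit_transform(X)

print(X.shape)     # (200, 5)
print(tsne.shape)  # (200, 3)
x, y, z = tsne[0]  # (x, y, z) coordinates of the first sample
```

Each of the 200 input rows gets exactly one 3D point, which is what you will hand to the plotting calls.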

To actually see the results, you still need to plot those coordinates. The matplotlib library is a good option, and that's what we'll use here.

To plot in 2D you have two options. You can keep most of your code the same and simply compute a 2D t-SNE with

tsne = TSNE(n_components=2, n_jobs=5).fit_transform(X)

Or you can keep the 3D embedding you already have and look at two of its components at a time. The following snippet handles either case:

import matplotlib.pyplot as plt

# Scatter the first two embedding components (x against y)
plt.scatter(*zip(*tsne[:,:2]))
plt.show()

The zip(*...) transposes your data so that you can pass the x coordinates and the y coordinates separately to scatter(), and the [:,:2] piece selects which two coordinates to view. You can drop it if your embedding is already 2D, or replace it with something like [:,[0,2]] to view, for example, the 0th and 2nd components of a higher-dimensional embedding rather than the first two.
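To see what the zip trick and the [0,2] selection actually produce, here is a small sketch on a hypothetical 4-point embedding. It also shows the equivalent call that indexes the columns directly, which avoids zip entirely; the Agg backend is only there so it runs without a display:

```python
import numpy as np
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt

# Hypothetical 3-component embedding with 4 points (one row per point).
tsne = np.array([[0.0, 1.0, 2.0],
                 [3.0, 4.0, 5.0],
                 [6.0, 7.0, 8.0],
                 [9.0, 10.0, 11.0]])

# zip(*rows) transposes rows into columns: all x values, then all y values.
xs, ys = zip(*tsne[:, :2])
assert np.allclose(xs, tsne[:, 0]) and np.allclose(ys, tsne[:, 1])

# A list index such as [0, 2] picks non-adjacent components.
xs02, zs02 = zip(*tsne[:, [0, 2]])
assert np.allclose(zs02, tsne[:, 2])

# Equivalent plot without zip: index the columns directly.
plt.scatter(tsne[:, 0], tsne[:, 2])
plt.xlabel("component 0")
plt.ylabel("component 2")
```

Both forms plot the same points; column indexing is just more direct with NumPy arrays.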

To plot in 3D the code looks much the same, at least for a minimal version.

from mpl_toolkits.mplot3d import Axes3D  # registers the '3d' projection on Matplotlib < 3.2
import matplotlib.pyplot as plt

fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')

ax.scatter(*zip(*tsne))
plt.show()

The main differences are importing the 3D toolkit and creating the subplot with projection='3d'.
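For a slightly fuller 3D version, here is a sketch that labels the axes and writes the figure to a file instead of opening a window. A random array stands in for your tsne embedding, the file name tsne_3d.png is arbitrary, and the Agg backend is chosen only so it runs headless:

```python
import numpy as np
import matplotlib
matplotlib.use("Agg")  # headless backend; drop this line to use plt.show() interactively
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D  # noqa: F401 (registers '3d' on old Matplotlib)

rng = np.random.default_rng(0)
tsne = rng.normal(size=(200, 3))  # hypothetical stand-in for the 3D embedding

fig = plt.figure()
ax = fig.add_subplot(projection="3d")  # projection='3d' creates an Axes3D
ax.scatter(tsne[:, 0], tsne[:, 1], tsne[:, 2], s=5)
ax.set_xlabel("t-SNE 1")
ax.set_ylabel("t-SNE 2")
ax.set_zlabel("t-SNE 3")
fig.savefig("tsne_3d.png")
```

In an interactive session you can call plt.show() instead and rotate the view with the mouse, which often helps more than a static image.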

Adding color: t-SNE plots are usually more informative when the points are color-coded by some variable. One candidate is the smooth midprice, which you currently have stored (standardized) in X[:,2]. For exploratory visualizations I find 2D plots more helpful, so I'll use that as the example:

plt.scatter(*zip(*tsne[:,:2]), c=X[:,2])

You still need the imports from before, but passing the keyword argument c color-codes the scatter plot by those values. To change how the numeric data is mapped to colors, pass a different colormap like so:

plt.scatter(*zip(*tsne[:,:2]), c=X[:,2], cmap='RdBu')

As the name suggests, this colormap is a gradient from red to blue: lower values of X[:,2] map to red and higher values map to blue.
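To make the color scale readable, you can also add a colorbar. The sketch below uses random arrays as hypothetical stand-ins for tsne[:, :2] and X[:, 2], and checks the red-to-blue ordering by sampling both ends of the RdBu colormap; the output file name is arbitrary:

```python
import numpy as np
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt

rng = np.random.default_rng(0)
tsne = rng.normal(size=(200, 2))  # hypothetical 2D embedding
midprice = rng.normal(size=200)   # hypothetical stand-in for X[:, 2]

fig, ax = plt.subplots()
sc = ax.scatter(tsne[:, 0], tsne[:, 1], c=midprice, cmap="RdBu", s=8)
fig.colorbar(sc, ax=ax, label="smooth_midprice (standardized)")
fig.savefig("tsne_2d_colored.png")

# RdBu maps the low end of the range to red and the high end to blue.
r_lo, g_lo, b_lo, _ = plt.get_cmap("RdBu")(0.0)
r_hi, g_hi, b_hi, _ = plt.get_cmap("RdBu")(1.0)
assert r_lo > b_lo and b_hi > r_hi
```

The colorbar shows which values the reds and blues correspond to, so you can tell at a glance whether nearby points in the embedding share similar midprices.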

Hans Musgrave answered Jan 13 '23 09:01