
In Torch, how do I create a 1-hot tensor from a list of integer labels?

I have a byte tensor of integer class labels, e.g. from the MNIST data set.

 1
 7
 5
[torch.ByteTensor of size 3]

How do I use it to create a tensor of 1-hot vectors?

 1  0  0  0  0  0  0  0  0  0
 0  0  0  0  0  0  1  0  0  0
 0  0  0  0  1  0  0  0  0  0
[torch.DoubleTensor of size 3x10]

I know I could do this with a loop, but I'm wondering if there's any clever Torch indexing that will get it for me in a single line.

asked Aug 14 '15 at 15:08 by W.P. McNeill

2 Answers

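An alternative method is to select rows of an identity matrix, using the labels as row indices:

require 'torch'
indices = torch.LongTensor{1,7,5}
-- row i of the 10x10 identity matrix is the one-hot vector for label i
one_hot = torch.eye(10):index(1, indices)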

This was not my idea; I found it in karpathy/char-rnn.
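A minimal sketch of the identity-matrix approach applied directly to the ByteTensor from the question, assuming Torch7 and 10 classes. index() expects a LongTensor of indices, so the byte labels are converted with :long() first. Torch indexing is 1-based; if your labels are 0-based digits (0 to 9), add 1 before indexing.

```lua
require 'torch'

-- Class labels as in the question (1-based, matching Torch indexing)
local labels = torch.ByteTensor({1, 7, 5})
local numClasses = 10  -- assumption: 10 classes, as for MNIST

-- index() requires a LongTensor, so convert the ByteTensor first;
-- row i of the identity matrix is the one-hot vector for class i
local oneHot = torch.eye(numClasses):index(1, labels:long())

print(oneHot)  -- 3x10 DoubleTensor with a single 1 per row
```

Each row of the result is the matching row of the identity matrix, so row 2 has its 1 in column 7.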

answered Sep 23 '22 at 14:09 by Tarquinnn


require 'torch'
indices = torch.LongTensor{1,7,5}:view(-1,1)
one_hot = torch.zeros(3, 10)
-- along dim 2, write 1 into column indices[i][1] of row i
one_hot:scatter(2, indices, 1)

You can find the documentation for scatter in the torch/torch7 GitHub README (in the master branch).
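A minimal sketch of the scatter approach that sizes the output from the labels instead of hard-coding 3 rows, again assuming Torch7 and 10 classes. scatter() takes a LongTensor index with the same number of dimensions as the target, so the N labels are reshaped into an N x 1 column; an explicit view(N, 1) is used here in place of view(-1, 1).

```lua
require 'torch'

local labels = torch.ByteTensor({1, 7, 5})
local numClasses = 10  -- assumption: 10 classes, as for MNIST
local n = labels:size(1)

-- Convert to LongTensor and reshape the N labels into an N x 1 column
local indices = labels:long():view(n, 1)

-- Allocate N x numClasses zeros, then for each row i write 1
-- into column indices[i][1] (scatter along dimension 2, in place)
local oneHot = torch.zeros(n, numClasses)
oneHot:scatter(2, indices, 1)

print(oneHot)
```

Because scatter() writes in place, the same buffer can be cleared with oneHot:zero() and reused for later batches of the same size.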

answered Sep 20 '22 at 14:09 by smhx