 

Do we need to use flatten and reshape in Theano if we use a matrix of indexes?

I am trying to understand the Theano implementation of LSTM (at the moment the link does not work for whatever reason, but I hope it will be back soon).

In the code I see the following part:

emb = tparams['Wemb'][x.flatten()].reshape([n_timesteps,
                                            n_samples,
                                            options['dim_proj']])

To make it "context independent", I rewrite it as follows:

e = W[x.flatten()].reshape([n1, n2, n3])

where x has shape (n1, n2) and W has shape (N, n3).

So my assumption is that the code can be shortened. In particular, we could just write:

e = W[x]

Or, using the original notation, it would be:

emb = tparams['Wemb'][x]

Am I right?

To provide a bit more context, x is a 2D array of integers representing words (for example, 27 means "word number 27"). W in my notation (tparams['Wemb'] in the original notation) is a 2D matrix in which each row corresponds to a word. So it is a word embedding matrix (Word2Vec) mapping each word to a real-valued vector.
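The lookup described above can be illustrated with a minimal NumPy sketch. It assumes a toy vocabulary of 10 words, uses plain NumPy as a stand-in for Theano, and reuses the names n_timesteps, n_samples and dim_proj from the original code purely for illustration (all sizes are hypothetical). Every entry x[t, s] is a word index, and the lookup replaces it with the corresponding row of the embedding matrix.

```python
import numpy as np

# Toy sizes (hypothetical values chosen for illustration)
vocab_size = 10   # N: number of words in the vocabulary
dim_proj = 4      # n3: length of each word vector
n_timesteps = 3   # n1
n_samples = 2     # n2

rng = np.random.default_rng(0)
# Word embedding matrix: row k is the vector for "word number k"
Wemb = rng.standard_normal((vocab_size, dim_proj))
# Word indices: one row per time step, one column per sample
x = rng.integers(0, vocab_size, size=(n_timesteps, n_samples))

emb = Wemb[x]  # shape (n_timesteps, n_samples, dim_proj)

print(emb.shape)  # (3, 2, 4)
# Each position holds the embedding of the word index stored there
for t in range(n_timesteps):
    for s in range(n_samples):
        assert np.array_equal(emb[t, s], Wemb[x[t, s]])
```

The explicit loop at the end only verifies the element-wise meaning of the lookup; the single expression Wemb[x] does all the work.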

Roman asked Oct 18 '22 10:10


1 Answer

Yes, you are right.

W[x.flatten()] gives you the rows of W (i.e. words) selected by the values of x, so the result has shape (n1*n2, n3). Let's call this a "list of words" (not a Python list, just a list in the everyday sense). Reshaping then gives you the desired size, with the list of words subdivided into n1 pages of n2 words each.
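To make the intermediate step concrete, here is a small NumPy sketch (the sizes and index values are arbitrary assumptions) that checks the shape of W[x.flatten()] and the row-major ordering that lets reshape split the flat list of words back into n1 pages of n2 words: word j of page i sits at flat position i*n2 + j.

```python
import numpy as np

N, n3 = 4, 5      # W has shape (N, n3)
n1, n2 = 2, 3     # x has shape (n1, n2)

W = np.arange(N * n3).reshape((N, n3))
x = np.array([[3, 0, 2],
              [1, 1, 0]])  # fixed word indices, each in [0, N)

flat = W[x.flatten()]      # "list of words": shape (n1*n2, n3)
print(flat.shape)          # (6, 5)

pages = flat.reshape([n1, n2, n3])
# flatten() and reshape() both default to row-major (C) order, so
# flat position i*n2 + j maps to page i, word j
for i in range(n1):
    for j in range(n2):
        assert np.array_equal(flat[i * n2 + j], pages[i, j])
        assert np.array_equal(pages[i, j], W[x[i, j]])
```

Because both operations use the same ordering, the reshape undoes the flatten exactly, which is why the round trip can be dropped.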

You achieve the same with W[x], since each of the n1 rows of x (each holding n2 word indices) gives you one of the n1 pages of the result.
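The page-by-page argument can be checked directly. In this NumPy sketch (sizes and indices are arbitrary assumptions), indexing W with the whole matrix x gives the same result as indexing it with each row of x separately and stacking the n1 resulting pages.

```python
import numpy as np

N, n3 = 4, 5
W = np.arange(N * n3).reshape((N, n3))
x = np.array([[3, 0, 2],
              [1, 1, 0]])  # shape (n1, n2) = (2, 3)

whole = W[x]  # shape (n1, n2, n3)

# Row i of x is a vector of n2 word indices; W[x[i]] is page i, shape (n2, n3)
pages = [W[row] for row in x]
stacked = np.stack(pages)

print(whole.shape)                     # (2, 3, 5)
print(np.array_equal(whole, stacked))  # True
```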

Here is an example program, written in plain NumPy, that shows both expressions are equivalent:

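import numpy as np

# Embedding matrix W of shape (N, n3): row k is the vector for word k
N = 4
n3 = 5
W = np.arange(n3*N).reshape((N, n3))

print("W = \n", W)

# Index matrix x of shape (n1, n2) with random word indices in [0, N)
n1 = 2
n2 = 3
x = np.random.randint(low=0, high=N, size=(n1, n2))

print("\nx = \n", x)

# Original approach: flatten, index, then reshape
e = W[x.flatten()].reshape([n1, n2, n3])
print("\ne = \n", e)

# Shorter approach: index W directly with the 2D matrix x
alternativeE = W[x]
print("\nalternativeE = \n", alternativeE)

print("\nequal:", np.array_equal(e, alternativeE))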
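As an extra cross-check (an addition, not part of the original answer), the same result can be built with explicit loops, which spells out what both vectorized expressions compute: e[i, j, :] = W[x[i, j], :]. The helper name lookup_loop is hypothetical.

```python
import numpy as np

def lookup_loop(W, x):
    """Hypothetical reference: embed every index in x with explicit loops."""
    n1, n2 = x.shape
    n3 = W.shape[1]
    e = np.empty((n1, n2, n3), dtype=W.dtype)
    for i in range(n1):
        for j in range(n2):
            e[i, j, :] = W[x[i, j], :]
    return e

N, n3, n1, n2 = 4, 5, 2, 3
W = np.arange(n3 * N).reshape((N, n3))
x = np.random.randint(low=0, high=N, size=(n1, n2))

ref = lookup_loop(W, x)
print(np.array_equal(ref, W[x.flatten()].reshape([n1, n2, n3])))  # True
print(np.array_equal(ref, W[x]))                                  # True
```

The loop version is far slower for realistic sizes; it is only useful as a readable specification to test the vectorized forms against.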
yar answered Oct 21 '22 06:10