
PyTorch : How to properly create a list of nn.Linear()

Tags: python, pytorch

I have created a class that subclasses nn.Module.

In my class, I have to create N linear transformations, where N is given as a class parameter.

I therefore proceed as follows:

    self.list_1 = []

    for i in range(N):
        self.list_1.append(nn.Linear(self.x, 1, bias=mlp_bias))

In the forward method, I call these layers (with list_1[i]) and concatenate the results.

Two things:

1)

Even though I use model.cuda(), these Linear layers stay on the CPU and I get the following error:

    RuntimeError: Expected object of type Variable[torch.cuda.FloatTensor] but found type Variable[torch.FloatTensor] for argument #1 'mat2'

To avoid it, I have to do

    self.list_1.append(nn.Linear(self.x, 1, bias=mlp_bias).cuda())

This is not required if I instead do:

    self.nn = nn.Linear(self.x, 1, bias=mlp_bias)

and then use self.nn directly.

2)

For a more obvious reason, when I print(model) in my main, the Linear layers in my list aren't printed.

Is there any other way, maybe using bmm? I find it less easy, and I actually want to have my N results separately.

Thank you in advance,

M

asked May 22 '18 09:05 by Mickey



1 Answer

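You can wrap your list of linear layers in nn.ModuleList. Unlike a plain Python list, it registers each layer as a submodule of your class, so model.cuda() moves the layers to the GPU and print(model) shows them:

    self.list_1 = nn.ModuleList(self.list_1)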
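
For context, here is a minimal sketch of how the whole class might look with nn.ModuleList. The class name MultiHead and the arguments in_features and n_heads are hypothetical stand-ins for the asker's class, self.x and N; only list_1 and mlp_bias come from the question.

```python
import torch
import torch.nn as nn


class MultiHead(nn.Module):  # hypothetical name, not from the question
    def __init__(self, in_features, n_heads, mlp_bias=True):
        super().__init__()
        # nn.ModuleList registers every layer as a submodule, so the
        # parameters are visible to .cuda(), .to(), print() and optimizers.
        self.list_1 = nn.ModuleList(
            [nn.Linear(in_features, 1, bias=mlp_bias) for _ in range(n_heads)]
        )

    def forward(self, x):
        # Apply each layer separately, then concatenate along the last dim.
        outputs = [layer(x) for layer in self.list_1]
        return torch.cat(outputs, dim=-1)


model = MultiHead(in_features=8, n_heads=3)
out = model(torch.randn(4, 8))                 # one column per linear layer
n_params = sum(1 for _ in model.parameters())  # one weight and one bias per layer
```

With the layers registered this way, print(model) should list each Linear under list_1, and a single call to model.cuda() (or model.to(device)) should move all of them, so the per-layer .cuda() call from the question is no longer needed.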
answered Nov 03 '22 06:11 by phi