
How to add parameters in module class in pytorch custom model?

I tried to find the answer but couldn't.

I made a custom deep learning model using PyTorch. For example:

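import torch
import torch.nn as nn


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()

        self.nn_layers = nn.ModuleList()  # unused in this example
        self.layer = nn.Linear(2, 3).double()  # float64 weights

        # Extra learnable bias, created as float64 to match the layer
        self.bias = torch.nn.Parameter(torch.randn(3, dtype=torch.float64))

    def forward(self, x):
        activation = torch.tanh
        output = activation(self.layer(x)) + self.bias

        return output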

If I create the model with

model = Net()

and print it, the output does not contain model.bias, so optimizer = torch.optim.Adam(model.parameters()) does not update model.bias. How can I fix this? Thanks!

asked by CSH, Dec 08 '19 09:12
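A quick way to check what an optimizer will actually receive is to list model.named_parameters() rather than printing the model, because a module's printed repr shows its child modules but not parameters assigned directly on it. The sketch below is a trimmed copy of the question's Net; dropping the unused nn_layers list and creating the bias as float64 are editorial assumptions, not the asker's code.

```python
import torch
import torch.nn as nn


class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(2, 3).double()
        # Assigning an nn.Parameter as an attribute registers it automatically
        self.bias = nn.Parameter(torch.randn(3, dtype=torch.float64))

    def forward(self, x):
        return torch.tanh(self.layer(x)) + self.bias


model = Net()
print(model)  # repr lists the child module "layer" but not the direct parameter "bias"

# named_parameters() is what parameters() iterates over, so this is what Adam gets
param_names = [name for name, _ in model.named_parameters()]
print(param_names)
```

If "bias" appears in param_names, the optimizer will update it.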


People also ask

How do you load parameters into a model PyTorch?

Saving and Loading Model Weights: to load model weights, first create an instance of the same model, then load the parameters with the load_state_dict() method. Be sure to call model.eval() before inference to set dropout and batch normalization layers to evaluation mode.
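A minimal sketch of the save/load round trip described above, assuming a small nn.Sequential stand-in model and an in-memory buffer in place of a file on disk (a path such as "weights.pt" would work the same way with torch.save and torch.load):

```python
import io

import torch
import torch.nn as nn

# Stand-in model for illustration; any nn.Module subclass works the same way.
model = nn.Sequential(nn.Linear(2, 3), nn.Tanh())

# Save only the learned tensors (the state_dict), not the Python object itself.
buffer = io.BytesIO()
torch.save(model.state_dict(), buffer)

# To load, first build a fresh instance with the same architecture...
restored = nn.Sequential(nn.Linear(2, 3), nn.Tanh())
buffer.seek(0)
# ...then copy the saved tensors into it.
restored.load_state_dict(torch.load(buffer))

# Switch dropout/batch-norm layers to inference behavior before predicting.
restored.eval()
```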

How do you get parameters in PyTorch?

PyTorch doesn't have a utility function (at least at the moment!) to count the number of model parameters, but every module has a method you can use to get them: model.parameters(), which returns an iterator over all of the module's registered parameters.
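Since parameters() yields tensors, counting them is a one-line sum over numel(). A sketch assuming a single nn.Linear(2, 3) as the model, which has 2 × 3 weights plus 3 biases:

```python
import torch.nn as nn

model = nn.Linear(2, 3)  # stand-in model: 6 weights + 3 biases

# parameters() yields every registered nn.Parameter; numel() counts its elements.
total = sum(p.numel() for p in model.parameters())
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(total, trainable)  # 9 9
```

Filtering on requires_grad is useful when some layers are frozen and only the trainable count matters.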

What is model.parameters() in PyTorch?

nn.Parameter is a subclass of torch.Tensor that marks a tensor as a learnable parameter of a module. A Parameter that is assigned as an attribute of a custom nn.Module is automatically registered as a model parameter and is therefore returned by model.parameters(). (Older PyTorch versions described Parameter as a wrapper around Variable; since Variable was merged into Tensor, a Parameter is simply a special kind of tensor.)
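The difference between a registered parameter and an ordinary tensor attribute is easy to see directly. In this sketch, Demo, scale and offset are hypothetical names chosen for illustration:

```python
import torch
import torch.nn as nn


class Demo(nn.Module):  # hypothetical module for illustration
    def __init__(self):
        super().__init__()
        # nn.Parameter attribute: registered, returned by parameters()
        self.scale = nn.Parameter(torch.ones(3))
        # Plain tensor attribute: NOT registered, so an optimizer never sees it
        self.offset = torch.zeros(3)


demo = Demo()
names = [name for name, _ in demo.named_parameters()]
print(names)  # ['scale']
print(isinstance(demo.scale, torch.Tensor))  # True: Parameter subclasses Tensor
```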

1 Answer

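Register the parameter explicitly with register_parameter (assigning an nn.Parameter as a module attribute, as the question does, registers it the same way; print(model) lists only submodules, so a directly registered parameter such as bias does not appear there even though model.parameters() returns it):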

self.register_parameter(name='bias', param=torch.nn.Parameter(torch.randn(3)))
answered by Shai, Nov 05 '22 01:11
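To confirm that Adam really updates the bias, one can take a single optimization step and compare the bias before and after. This is a sketch, not a canonical recipe: the random seed, the learning rate of 0.1, the batch of 4 inputs and the squared-output loss are arbitrary assumptions used only to produce nonzero gradients.

```python
import torch
import torch.nn as nn


class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(2, 3).double()
        # Explicit registration; equivalent to `self.bias = nn.Parameter(...)`
        self.register_parameter(
            name="bias", param=nn.Parameter(torch.randn(3, dtype=torch.float64))
        )

    def forward(self, x):
        return torch.tanh(self.layer(x)) + self.bias


torch.manual_seed(0)
model = Net()
optimizer = torch.optim.Adam(model.parameters(), lr=0.1)

before = model.bias.detach().clone()
x = torch.randn(4, 2, dtype=torch.float64)
loss = model(x).pow(2).mean()  # arbitrary toy loss, only to get gradients
optimizer.zero_grad()
loss.backward()
optimizer.step()

print(torch.equal(before, model.bias.detach()))  # False: the bias was updated
```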
