 

What does flatten_parameters() do?

Tags:

pytorch

I saw many PyTorch examples using flatten_parameters in the forward function of the RNN:

self.rnn.flatten_parameters()
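
For context, the pattern I keep seeing looks roughly like this (a minimal sketch; the module name and sizes are made up, not from any specific example):

import torch
import torch.nn as nn

class Encoder(nn.Module):
    def __init__(self, input_size=128, hidden_size=256):
        super().__init__()
        self.rnn = nn.LSTM(input_size, hidden_size, num_layers=2, batch_first=True)

    def forward(self, x):
        # Re-compact the RNN weights into one contiguous block
        # before running the (cuDNN) RNN kernel
        self.rnn.flatten_parameters()
        output, _ = self.rnn(x)
        return output

x = torch.randn(4, 10, 128)  # (batch, seq_len, input_size)
out = Encoder()(x)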

I looked at RNNBase, and it is written that it

Resets parameter data pointer so that they can use faster code paths

What does that mean?

asked Nov 09 '18 by floyd

1 Answer

It may not be a full answer to your question, but if you take a look at flatten_parameters's source code, you will notice that it calls _cudnn_rnn_flatten_weight:

...
NoGradGuard no_grad;
torch::_cudnn_rnn_flatten_weight(...)
...

This is the function that does the job. You will find that what it actually does is copy the model's weights into a vector<Tensor> (check the params_arr declaration):

  // Slice off views into weight_buf
  std::vector<Tensor> params_arr;
  size_t params_stride0;
  std::tie(params_arr, params_stride0) = get_parameters(handle, rnn, rnn_desc, x_desc, w_desc, weight_buf);

  MatrixRef<Tensor> weight{weight_arr, static_cast<size_t>(weight_stride0)},
                    params{params_arr, params_stride0};

And the weights are copied in:

  // Copy weights
  _copyParams(weight, params);
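
If you want to see the effect of this copy from the Python side, you can compare the parameters' data pointers after flattening (a rough sketch, assuming a cuDNN-enabled GPU; on CPU, flatten_parameters is effectively a no-op):

import torch
import torch.nn as nn

if torch.cuda.is_available():
    rnn = nn.LSTM(4, 8, num_layers=2).cuda()
    rnn.flatten_parameters()
    # After flattening, the parameters are views sliced out of a single
    # weight buffer, so their pointers are offsets into one allocation
    ptrs = sorted(p.data_ptr() for p in rnn.parameters())
    print([ptr - ptrs[0] for ptr in ptrs])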

Also note that they update (or reset, as the docs explicitly say) the original weight pointers to the new params pointers using the in-place operation .set_ (the trailing _ is PyTorch's notation for in-place operations), in orig_param.set_(new_param.view_as(orig_param)):

  // Update the storage
  for (size_t i = 0; i < weight.size(0); i++) {
    for (auto orig_param_it = weight[i].begin(), new_param_it = params[i].begin();
         orig_param_it != weight[i].end() && new_param_it != params[i].end();
         orig_param_it++, new_param_it++) {
      auto orig_param = *orig_param_it, new_param = *new_param_it;
      orig_param.set_(new_param.view_as(orig_param));
    }
  }
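
These .set_/view_as mechanics are also exposed at the Python level, which makes them easy to try out (a small standalone illustration, not the library's internal code):

import torch

orig = torch.zeros(2, 3)
new = torch.arange(6, dtype=torch.float32)

# set_ makes `orig` use `new`'s storage (reshaped to orig's shape),
# i.e. it resets the data pointer in place
orig.set_(new.view_as(orig))
assert orig.data_ptr() == new.data_ptr()  # same underlying memory
print(orig)  # tensor([[0., 1., 2.], [3., 4., 5.]])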

And according to N3092 (a draft of the C++0x standard):

© ISO/IEC N3092

23.3.6 Class template vector

A vector is a sequence container that supports random access iterators. In addition, it supports (amortized) constant time insert and erase operations at the end; insert and erase in the middle take linear time. Storage management is handled automatically, though hints can be given to improve efficiency. The elements of a vector are stored contiguously, meaning that if v is a vector<T, Allocator> where T is some type other than bool, then it obeys the identity &v[n] == &v[0] + n for all 0 <= n < v.size().
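
The same contiguity guarantee is easy to check for a torch tensor, which is essentially what the flattened weight buffer relies on (a small Python illustration of the pointer identity):

import torch

v = torch.arange(8, dtype=torch.float32)  # contiguous storage
n = 5
# analogue of &v[n] == &v[0] + n for a contiguous tensor
assert v[n].data_ptr() == v.data_ptr() + n * v.element_size()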


In some situations you will see the warning:

UserWarning: RNN module weights are not part of single contiguous chunk of memory. This means they need to be compacted at every call, possibly greatly increasing memory usage. To compact weights again call flatten_parameters().

So they explicitly advise, in a code warning, keeping the weights in a contiguous chunk of memory.
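
A common way to act on that advice is exactly what the question shows: call flatten_parameters() before running the RNN (a hedged sketch; the warning itself only appears on cuDNN setups, e.g. after the weights have been moved or replicated):

import torch
import torch.nn as nn

rnn = nn.GRU(10, 20, num_layers=2)
if torch.cuda.is_available():
    rnn = rnn.cuda()
    x = torch.randn(5, 3, 10, device="cuda")
    rnn.flatten_parameters()  # re-compacts the weights into one chunk
    out, _ = rnn(x)           # no UserWarning about non-contiguous weights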

answered Oct 06 '22 by ndrwnaguib