How to save and load random number generator state in PyTorch?

I am training a deep learning model in PyTorch and want the training to be deterministic. As recommended in the official PyTorch reproducibility guide, I set the random seeds like this:

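import numpy as np
import torch

np.random.seed(0)  # seed NumPy's global RNG
torch.manual_seed(0)  # seed PyTorch's RNG on all devices (CPU and CUDA)
torch.backends.cudnn.deterministic = True  # make cuDNN use deterministic algorithms
torch.backends.cudnn.benchmark = False  # disable cuDNN autotuning, which can select different algorithms between runs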

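My training runs for a long time, so I want to save everything and load it again later, including the state of the random number generators (RNGs). For the model and the optimizer, I use torch.save to write their state dicts and torch.load together with load_state_dict to restore them.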
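For reference, the model/optimizer part described above might look like the following minimal sketch. The tiny nn.Linear model, the SGD optimizer, and the temporary checkpoint path are hypothetical stand-ins for the real training setup.

```python
import os
import tempfile

import torch
import torch.nn as nn

# Hypothetical tiny model and optimizer standing in for the real ones.
model = nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# Save both state dicts into a single checkpoint file (path is illustrative).
path = os.path.join(tempfile.mkdtemp(), "checkpoint.pt")
torch.save({"model": model.state_dict(), "optimizer": optimizer.state_dict()}, path)

# Later: rebuild the objects, then restore their state from the file.
restored_model = nn.Linear(4, 2)
restored_optimizer = torch.optim.SGD(restored_model.parameters(), lr=0.1)
checkpoint = torch.load(path)
restored_model.load_state_dict(checkpoint["model"])
restored_optimizer.load_state_dict(checkpoint["optimizer"])
```

Note that there is no torch.load_state_dict function; load_state_dict is a method on the model and optimizer objects, applied to whatever torch.load returns.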

How can the random number generators be saved & loaded?

like image 598
Asked Mar 11 '19 08:03 by hajduistvan


1 Answer

You can use torch.get_rng_state and torch.set_rng_state.

Calling torch.get_rng_state returns the state of the default (CPU) random number generator as a torch.ByteTensor.

You can save this tensor to a file, then later load it and pass it to torch.set_rng_state to restore the generator to exactly that state.
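A minimal sketch of that round trip, assuming the state is stored with torch.save in a temporary file (the file name is illustrative). After restoring, the generator should replay the same random numbers it produced right after the state was captured. The CUDA generators keep separate state, so the sketch also shows torch.cuda.get_rng_state_all / torch.cuda.set_rng_state_all for GPU training.

```python
import os
import tempfile

import torch

torch.manual_seed(0)

# Capture the CPU generator state; it is returned as a uint8 (ByteTensor) tensor.
state = torch.get_rng_state()
path = os.path.join(tempfile.mkdtemp(), "rng_state.pt")  # illustrative path
torch.save(state, path)

# Drawing numbers advances the generator past the saved point.
first_draw = torch.rand(3)

# Later (e.g. when resuming): load the saved state and restore it.
torch.set_rng_state(torch.load(path))
second_draw = torch.rand(3)

# The generator replays the same sequence from the saved point.
assert torch.equal(first_draw, second_draw)

# On GPU, each CUDA device has its own generator with separate state.
if torch.cuda.is_available():
    cuda_states = torch.cuda.get_rng_state_all()  # list, one ByteTensor per device
    torch.cuda.set_rng_state_all(cuda_states)
```

In a real checkpoint these states would typically be stored alongside the model and optimizer state dicts in the same file.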


If you use NumPy, you can do the same with numpy.random.get_state and numpy.random.set_state.
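A minimal sketch for NumPy's global generator, assuming the state tuple is stored with pickle (the file name is illustrative). pickle is used here only because get_state returns a plain Python tuple containing a NumPy array; any serializer that round-trips it would do.

```python
import os
import pickle
import tempfile

import numpy as np

np.random.seed(0)

# get_state() returns a tuple: ('MT19937', key array, pos, has_gauss, cached_gaussian).
state = np.random.get_state()
path = os.path.join(tempfile.mkdtemp(), "np_rng_state.pkl")  # illustrative path
with open(path, "wb") as f:
    pickle.dump(state, f)

# Drawing numbers advances the generator past the saved point.
first_draw = np.random.rand(3)

# Later: load the saved tuple and restore the generator.
with open(path, "rb") as f:
    np.random.set_state(pickle.load(f))
second_draw = np.random.rand(3)

# Same numbers as right after the state was saved.
assert np.array_equal(first_draw, second_draw)
```

If your code also uses Python's built-in random module, random.getstate and random.setstate work the same way.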

Answered Sep 29 '22 18:09 by MBT