 

Fixing the seed for torch random_split()

Tags: pytorch, torch

Is it possible to fix the seed for torch.utils.data.random_split() when splitting a dataset so that it is possible to reproduce the test results?

asked Apr 23 '19 by cerebrou


People also ask

What is seed in torch?

torch.seed() sets the seed for generating random numbers to a non-deterministic random number and returns the 64-bit number used to seed the RNG.

How do you test a torch seed?

You can use torch.initial_seed() to read back the seed currently in use (torch.seed() itself re-seeds the RNG non-deterministically and returns the new seed). You might want to check the reproducibility part of the docs, https://pytorch.org/docs/stable/notes/randomness.html, as having the seed most likely won't allow you to reproduce results if you're using a different machine or ops that are not deterministic.
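A small sketch of how the two functions relate, based on my reading of the docs:

import torch

s = torch.seed()                   # re-seeds non-deterministically, returns the new seed
assert s == torch.initial_seed()   # initial_seed() reports the seed currently in use
torch.manual_seed(s)               # re-seeding with the same value replays the stream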


2 Answers

You can use the torch.manual_seed function to seed the script globally:

import torch
torch.manual_seed(0)

See the reproducibility documentation for more information.
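Seeding globally is enough to make the split itself reproducible, since random_split draws from the default generator. A minimal sketch (TensorDataset is used here purely as a stand-in for your own dataset):

import torch
from torch.utils.data import TensorDataset, random_split

torch.manual_seed(0)  # seed the global RNG before splitting

dataset = TensorDataset(torch.arange(10))  # stand-in for your own dataset
train_set, test_set = random_split(dataset, [7, 3])
print(train_set.indices)  # same indices on every run of the script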

If you want to seed torch.utils.data.random_split specifically, you could "reset" the seed to its initial value afterwards. Simply use torch.initial_seed() like this:

torch.manual_seed(torch.initial_seed())
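One way to read that advice, as a hedged sketch: remember the seed the script started with, use a fixed seed only for the split, then re-seed with the saved value:

import torch
from torch.utils.data import TensorDataset, random_split

dataset = TensorDataset(torch.arange(10))  # stand-in dataset

saved_seed = torch.initial_seed()      # whatever seed the script started with
torch.manual_seed(42)                  # fixed seed, used only for the split
train_set, test_set = random_split(dataset, [7, 3])
torch.manual_seed(saved_seed)          # "reset" the seed for the rest of the script

Note that re-seeding restarts the RNG stream from the beginning rather than resuming it; if you need to restore the exact RNG state instead, torch.get_rng_state() and torch.set_rng_state() do that.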

AFAIK, PyTorch does not provide arguments like seed or random_state for random_split (which you may know from sklearn, for example).

answered Sep 23 '22 by Szymon Maszke


As you can see from the documentation, it is possible to pass a generator to random_split:

random_split(range(10), [3, 7], generator=torch.Generator().manual_seed(42))
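For example, two splits made with identically seeded generators produce the same indices (a minimal sketch):

import torch
from torch.utils.data import random_split

split_a = random_split(range(10), [3, 7], generator=torch.Generator().manual_seed(42))
split_b = random_split(range(10), [3, 7], generator=torch.Generator().manual_seed(42))
assert split_a[0].indices == split_b[0].indices  # identical first-subset indices

Passing a dedicated generator also has the advantage of leaving the global RNG untouched, so the rest of the script's randomness is unaffected by the split.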
answered Sep 26 '22 by Matteo Pennisi