 

How to get the total number of batch iteration from pytorch dataloader?

I have a question: how can I get the total number of batch iterations from a PyTorch dataloader?

The following is a common training loop:

for i, batch in enumerate(dataloader):

Is there any method to get the total number of iterations of this for loop?

In my NLP problem, the total number of iterations differs from int(n_train_samples / batch_size)...

For example, if I truncate the training data to only 10,000 samples and set the batch size to 1024, then 363 iterations occur in my NLP problem.

I wonder how to get the total number of iterations of the for loop.

Thank you.

asked Feb 04 '26 by Hyunseung Kim

2 Answers

len(dataloader) returns the total number of batches. It depends on the __len__ method of your dataset, so make sure that is set correctly.
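As a minimal sketch of this (using a synthetic TensorDataset rather than the asker's NLP data, with the 10,000-sample / 1024-batch setup from the question):

```python
import torch
from torch.utils.data import TensorDataset, DataLoader

# Synthetic dataset: 10,000 samples, 8 features each
ds = TensorDataset(torch.randn(10_000, 8))
dl = DataLoader(ds, batch_size=1024)

# len(dl) is the number of batches per epoch:
# ceil(10000 / 1024) = 10 (the last batch holds the 784 leftover samples)
print(len(dl))  # 10
```

If a dataset with 10,000 raw samples yields a different count (such as the 363 in the question), its __len__ is counting something other than raw samples.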

answered Feb 06 '26 by hkchengrex


There is one additional parameter to consider when creating the dataloader: drop_last.

If drop_last=True, the length is number_of_training_examples // batch_size. If drop_last=False, it is ceil(number_of_training_examples / batch_size), i.e. number_of_training_examples // batch_size + 1 whenever the last batch is partial.

import torchvision
from torchvision import transforms
from torch.utils.data import DataLoader

BS = 128
# t_train was not shown in the original answer; any transform works, e.g.:
t_train = transforms.ToTensor()

ds_train = torchvision.datasets.CIFAR10('/data/cifar10', download=True, train=True, transform=t_train)
dl_train = DataLoader(ds_train, batch_size=BS, drop_last=True, shuffle=True)

For predefined datasets, you can get the number of examples like this:

# number of examples
len(dl_train.dataset) 

The number of batches the dataloader will actually yield is always:

# number of batches
len(dl_train) 
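To show the drop_last effect concretely, here is a sketch with CIFAR-10 swapped for a random tensor of the same size (50,000 training examples), so it runs without a download:

```python
import torch
from torch.utils.data import TensorDataset, DataLoader

# Stand-in for the CIFAR-10 training set: 50,000 examples
ds = TensorDataset(torch.randn(50_000, 3))

dl_drop = DataLoader(ds, batch_size=128, drop_last=True)
dl_keep = DataLoader(ds, batch_size=128, drop_last=False)

print(len(ds))       # 50000 examples
print(len(dl_drop))  # 50000 // 128 = 390 batches (80 leftovers dropped)
print(len(dl_keep))  # ceil(50000 / 128) = 391 batches (last one has 80 examples)
```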
answered Feb 06 '26 by prosti


