I implemented an image classification network for a dataset of 100 classes, using pretrained AlexNet and replacing the final output layers. I noticed that when I was loading my data like
trainloader = torch.utils.data.DataLoader(train_data, batch_size=32, shuffle=False)
, I was getting around 2-3% accuracy on the validation dataset for about 10 epochs, but when I simply changed to shuffle=True
and retrained the network, the accuracy jumped to 70% in the first epoch.
I was wondering whether this happened because, in the first case, the network was shown examples of only one class continuously for several iterations, causing it to generalize poorly during training, or whether there is some other reason behind it.
Either way, I did not expect it to have such a drastic impact.
P.S.: All the code and parameters were exactly the same for both cases except for the shuffle option.
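For reference, here is a minimal, self-contained sketch of the two configurations being compared; the dummy TensorDataset below only stands in for the real image dataset, and the names are illustrative:

import torch
from torch.utils.data import DataLoader, TensorDataset

# Dummy stand-in for the real dataset: 256 samples, 100 classes.
train_data = TensorDataset(torch.randn(256, 3, 64, 64),
                           torch.randint(0, 100, (256,)))

# The only difference between the two runs is the shuffle flag.
loader_no_shuffle = DataLoader(train_data, batch_size=32, shuffle=False)
loader_shuffle = DataLoader(train_data, batch_size=32, shuffle=True)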
PyTorch does many things well, and one of them is the DataLoader
class.
The DataLoader
takes the dataset, the batch_size
(how many samples per batch to load), and selects a sampler from one of the following classes (see the sketch after this list):
DistributedSampler
SequentialSampler
RandomSampler
SubsetRandomSampler
WeightedRandomSampler
BatchSampler
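A quick way to see which sampler the DataLoader picks, using only the documented shuffle and sampler arguments, is a sketch like this:

import torch
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, TensorDataset

data = TensorDataset(torch.arange(10))

# shuffle=False selects SequentialSampler, shuffle=True selects RandomSampler.
print(type(DataLoader(data, shuffle=False).sampler))  # SequentialSampler
print(type(DataLoader(data, shuffle=True).sampler))   # RandomSampler

# Passing an explicit sampler overrides the shuffle-based choice.
loader = DataLoader(data, sampler=SequentialSampler(data), batch_size=4)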
The key difference between samplers is how they implement the __iter__()
method.
In the case of SequentialSampler
it looks like this:
def __iter__(self):
    # Yield indices 0, 1, 2, ... in the original dataset order.
    return iter(range(len(self.data_source)))
This returns an iterator over every index of the data_source, in order.
When you set shuffle=True
, the DataLoader does not use the SequentialSampler
, but the RandomSampler
instead.
And this may improve the learning process.
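For contrast, a simplified sketch of what RandomSampler does in its __iter__() (omitting details such as custom generators and sampling with replacement) could look like this:

import torch

class SimpleRandomSampler:
    """Simplified stand-in for RandomSampler (without replacement)."""
    def __init__(self, data_source):
        self.data_source = data_source

    def __iter__(self):
        # Same indices as SequentialSampler, but in a random permutation.
        return iter(torch.randperm(len(self.data_source)).tolist())

    def __len__(self):
        return len(self.data_source)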
Yes, it can totally affect the result! Shuffling the order of the data we use to fit the classifier is important, because with it the batches between epochs do not look alike.
Checking the DataLoader documentation, it says: "shuffle (bool, optional) – set to True to have the data reshuffled at every epoch"
In any case, it will make the model more robust and help avoid over/underfitting.
In your case, without knowing the dataset, this drastic increase in accuracy is probably due to how the dataset is "organised": if, for example, the examples are sorted by class, then with shuffle=False every batch contains a single class in every epoch, the network sees one class at a time, and it generalises poorly, which results in very bad accuracy when you test it.
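You can reproduce that effect with a toy dataset whose labels are sorted by class (a hypothetical stand-in for how the real dataset might be organised):

import torch
from torch.utils.data import DataLoader, TensorDataset

# Labels sorted by class: 10 classes, 32 samples each.
labels = torch.arange(10).repeat_interleave(32)
features = torch.randn(len(labels), 8)
data = TensorDataset(features, labels)

for shuffle in (False, True):
    loader = DataLoader(data, batch_size=32, shuffle=shuffle)
    _, y = next(iter(loader))
    print(f"shuffle={shuffle}: classes in first batch = {y.unique().tolist()}")

# shuffle=False -> one class per batch; shuffle=True -> a mix of classes.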