
How to get the filename of a sample from a DataLoader?

I need to write a file with the test results of a Convolutional Neural Network that I trained. The data is a speech data collection. The file format needs to be "file name, prediction", but I am having a hard time extracting the file name. I load the data like this:

import os

import torch
import torchvision
from torchvision import transforms
from torch.utils.data import DataLoader

TEST_DATA_PATH = ...

trans = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

test_dataset = torchvision.datasets.MNIST(
    root=TEST_DATA_PATH,
    train=False,
    transform=trans,
    download=True
)

test_loader = DataLoader(dataset=test_dataset, batch_size=1, shuffle=False)

and I am trying to write to the file as follows:

f = open("test_y", "w")
with torch.no_grad():
    for i, (images, labels) in enumerate(test_loader, 0):
        outputs = model(images)
        _, predicted = torch.max(outputs.data, 1)
        file = os.listdir(TEST_DATA_PATH + "/all")[i]
        format = file + ", " + str(predicted.item()) + '\n'
        f.write(format)
f.close()

The problem with os.listdir(TEST_DATA_PATH + "/all")[i] is that its order is not synchronized with the order in which test_loader loads the files. What can I do?

Almog Levi asked Jun 21 '19 07:06



2 Answers

Well, it depends on how your Dataset is implemented. For instance, in the torchvision.datasets.MNIST(...) case, you cannot retrieve the filename simply because there is no such thing as the filename of a single sample (MNIST samples are loaded in a different way).

As you did not show your Dataset implementation, I'll tell you how this could be done with the torchvision.datasets.ImageFolder(...) (or any torchvision.datasets.DatasetFolder(...)):

with open("test_y", "w") as f, torch.no_grad():
    for i, (images, labels) in enumerate(test_loader):
        outputs = model(images)
        _, predicted = torch.max(outputs, 1)
        # Indexing samples[i] is only valid with batch_size=1 and shuffle=False.
        sample_fname, _ = test_loader.dataset.samples[i]
        f.write("{}, {}\n".format(sample_fname, predicted.item()))

The path of the file is looked up from self.samples during the __getitem__(self, index) call.

If you implemented your own Dataset (and perhaps would like to support shuffle and batch_size > 1), then I would return the sample_fname from the __getitem__(...) call and do something like this:

for i, (images, labels, sample_fname) in enumerate(test_loader, 0):
    # [...]

This way you wouldn't need to care about shuffle. And if the batch_size is greater than 1, you would need to change the body of the loop to something more generic, e.g.:

with open("test_y", "w") as f, torch.no_grad():
    for images, labels, samples_fname in test_loader:
        outputs = model(images)
        pred = torch.max(outputs, 1)[1]
        # Write one "file name, prediction" line per sample in the batch.
        f.write("\n".join(
            ", ".join(x)
            for x in zip(samples_fname, map(str, pred.cpu().tolist()))
        ) + "\n")
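
To make the custom Dataset idea concrete, here is a minimal sketch. The class name NamedTensorDataset, the file names and the labels are all made up for illustration, and a dummy tensor stands in for the data a real dataset would load from disk. It shows that when __getitem__ returns the file name as a third element, the DataLoader hands it back as a list of strings per batch, so the pairing survives shuffle=True and batch_size > 1.

```python
import torch
from torch.utils.data import Dataset, DataLoader


class NamedTensorDataset(Dataset):
    """Hypothetical dataset that returns (sample, label, filename)."""

    def __init__(self, filenames, labels):
        self.filenames = list(filenames)
        self.labels = list(labels)

    def __len__(self):
        return len(self.filenames)

    def __getitem__(self, index):
        # A real dataset would load and transform the file here;
        # a tensor filled with the index stands in for that data.
        sample = torch.full((1, 4, 4), float(index))
        return sample, self.labels[index], self.filenames[index]


# Made-up file names and labels, only to make the sketch runnable.
dataset = NamedTensorDataset(
    filenames=["a.wav", "b.wav", "c.wav", "d.wav", "e.wav"],
    labels=[0, 1, 0, 1, 1],
)
loader = DataLoader(dataset, batch_size=2, shuffle=True)

rows = []
for samples, labels, fnames in loader:
    # fnames is a list of str, one entry per sample in the batch.
    for fname, label in zip(fnames, labels.tolist()):
        rows.append((fname, label))
```

In a real test loop the label would be replaced by the model's prediction for that sample; the point is only that each file name travels with its own sample through the loader.
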
Berriel answered Oct 06 '22 00:10


In the general case, the DataLoader is there to provide you with batches from the Dataset(s) it has inside.

As @Berriel mentioned, in single- or multi-label classification problems the DataLoader doesn't have the image file names, just the tensors representing the images and the classes/labels.

However, the Dataset you pass to the DataLoader constructor can carry small extra items with each sample: together with the targets/labels you may pack the file names if you like, or even data from a dataframe.

This way, the DataLoader can hand you what you need.
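
One way to pack the file names together with an existing dataset is a thin wrapper. The sketch below is only an illustration: WithPaths and FakeFolder are hypothetical names, and it assumes the wrapped dataset exposes a samples list of (path, class_index) pairs, as ImageFolder-style datasets do. FakeFolder exists only so the snippet runs without any image files.

```python
class WithPaths:
    """Hypothetical wrapper: adds the file path to every item of a
    dataset that exposes `.samples` as (path, class_index) pairs."""

    def __init__(self, base):
        self.base = base

    def __len__(self):
        return len(self.base)

    def __getitem__(self, index):
        image, label = self.base[index]
        path, _ = self.base.samples[index]
        return image, label, path


class FakeFolder:
    """Stand-in for an ImageFolder-style dataset (illustration only)."""

    def __init__(self):
        self.samples = [("cat/1.png", 0), ("dog/2.png", 1)]

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, index):
        path, label = self.samples[index]
        # A string stands in for the decoded image tensor.
        return f"<pixels of {path}>", label


wrapped = WithPaths(FakeFolder())
item = wrapped[1]
```

Because the wrapper implements __len__ and __getitem__, it should be usable as the dataset argument of a DataLoader in place of the original dataset, with each batch then carrying the paths as a third element.
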

prosti answered Oct 06 '22 00:10