I am new to PyTorch. I have been trying to learn how to view my input images before I begin training my CNN. I am having a very hard time converting the images into a form that can be used with matplotlib.
So far I have tried this:
from multiprocessing import freeze_support
import torch
from torch import nn
import torchvision
from torch.autograd import Variable
from torch.utils.data import DataLoader, Sampler
from torchvision import datasets
from torchvision.transforms import transforms
from torch.optim import Adam
import matplotlib.pyplot as plt
import numpy as np
import PIL
num_classes = 5
batch_size = 100
num_of_workers = 5
DATA_PATH_TRAIN = r'C:\Users\Aeryes\PycharmProjects\simplecnn\images\train'
DATA_PATH_TEST = r'C:\Users\Aeryes\PycharmProjects\simplecnn\images\test'
trans = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.Resize(32),
    transforms.CenterCrop(32),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
train_dataset = datasets.ImageFolder(root=DATA_PATH_TRAIN, transform=trans)
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_of_workers)
def imshow(img):
    img = img / 2 + 0.5  # unnormalize
    npimg = img.numpy()
    print(npimg)
    plt.imshow(np.transpose(npimg, (1, 2, 0, 1)))
def main():
    # get some random training images
    dataiter = iter(train_loader)
    images, labels = next(dataiter)
    # show images
    imshow(images)
    # print labels (ImageFolder stores the class names in label order)
    classes = train_dataset.classes
    print(' '.join('%5s' % classes[labels[j]] for j in range(4)))
if __name__ == "__main__":
    main()
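The `img / 2 + 0.5` line in `imshow` undoes `transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))`, which maps each channel value x in [0, 1] to (x - 0.5) / 0.5, i.e. into [-1, 1]. A minimal NumPy-only sketch of that round trip, using random data in place of real images (the helpers `normalize` and `unnormalize` are hypothetical names for illustration):

```python
import numpy as np

MEAN, STD = 0.5, 0.5  # the values passed to transforms.Normalize in the question

def normalize(x):
    # what Normalize does per channel: (x - mean) / std
    return (x - MEAN) / STD

def unnormalize(x):
    # inverse of normalize; with mean = std = 0.5 this equals x / 2 + 0.5
    return x * STD + MEAN

rng = np.random.default_rng(0)
original = rng.random((3, 32, 32), dtype=np.float32)  # values in [0, 1), like ToTensor output
normed = normalize(original)
restored = unnormalize(normed)

print(normed.min() >= -1.0, normed.max() <= 1.0)  # normalized values lie in [-1, 1]
print(np.allclose(restored, original))            # the round trip recovers the input
```

Skipping the unnormalize step does not change any shapes, but matplotlib clips float RGB data to [0, 1] (with a warning), so every negative value would be shown as black.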
However, this prints the array below and then throws an error:
[[0.27058825 0.18431371 0.31764707 ... 0.18823528 0.3882353
0.27450982]
[0.23137254 0.11372548 0.24313724 ... 0.16862744 0.14117646
0.40784314]
[0.25490198 0.19607842 0.30588236 ... 0.27450982 0.25882354
0.34509805]
...
[0.2784314 0.21960783 0.2352941 ... 0.5803922 0.46666667
0.25882354]
[0.26666668 0.16862744 0.23137254 ... 0.2901961 0.29803923
0.2509804 ]
[0.30980393 0.39607844 0.28627452 ... 0.1490196 0.10588235
0.19607842]]
[[0.2352941 0.06274509 0.15686274 ... 0.09411764 0.3019608
0.19215685]
[0.22745097 0.07843137 0.12549019 ... 0.07843137 0.10588235
0.3019608 ]
[0.20392156 0.13333333 0.1607843 ... 0.16862744 0.2117647
0.22745097]
...
[0.18039215 0.16862744 0.1490196 ... 0.45882353 0.36078432
0.16470587]
[0.1607843 0.10588235 0.14117646 ... 0.2117647 0.18039215
0.10980392]
[0.18039215 0.3019608 0.2117647 ... 0.11372548 0.06274509
0.04705882]]]
...
[[[0.8980392 0.8784314 0.8509804 ... 0.627451 0.627451
0.627451 ]
[0.8509804 0.8235294 0.7921569 ... 0.54901963 0.5568628
0.56078434]
[0.7921569 0.7529412 0.7176471 ... 0.47058824 0.48235294
0.49411765]
...
[0.3764706 0.38431373 0.3764706 ... 0.4509804 0.43137255
0.39607844]
[0.38431373 0.39607844 0.3882353 ... 0.4509804 0.43137255
0.39607844]
[0.3882353 0.4 0.39607844 ... 0.44313726 0.42352942
0.39215687]]
[[0.9254902 0.90588236 0.88235295 ... 0.60784316 0.6
0.5921569 ]
[0.88235295 0.85490197 0.8235294 ... 0.5411765 0.5372549
0.53333336]
[0.8235294 0.7882353 0.75686276 ... 0.47058824 0.47058824
0.47058824]
...
[0.50980395 0.5176471 0.5137255 ... 0.58431375 0.5647059
0.53333336]
[0.5137255 0.53333336 0.5254902 ... 0.58431375 0.5686275
0.53333336]
[0.5176471 0.53333336 0.5294118 ... 0.5764706 0.56078434
0.5294118 ]]
[[0.95686275 0.9372549 0.90588236 ... 0.18823528 0.19999999
0.20784312]
[0.9098039 0.8784314 0.8352941 ... 0.1607843 0.17254901
0.18039215]
[0.84313726 0.7921569 0.7490196 ... 0.1372549 0.14509803
0.15294117]
...
[0.03921568 0.05490196 0.05098039 ... 0.11764705 0.09411764
0.02745098]
[0.04705882 0.07843137 0.06666666 ... 0.12156862 0.10196078
0.03529412]
[0.05098039 0.0745098 0.07843137 ... 0.12549019 0.10196078
0.04705882]]]
[[[0.30588236 0.28627452 0.24313724 ... 0.2901961 0.26666668
0.21568626]
[0.8156863 0.6666667 0.5921569 ... 0.18039215 0.23921567
0.21568626]
[0.9019608 0.83137256 0.85490197 ... 0.21960783 0.36862746
0.23921567]
...
[0.7058824 0.83137256 0.85490197 ... 0.2627451 0.24313724
0.20784312]
[0.7137255 0.84313726 0.84705883 ... 0.26666668 0.29803923
0.21568626]
[0.7254902 0.8235294 0.8392157 ... 0.2509804 0.27058825
0.2352941 ]]
[[0.24705881 0.22745097 0.19215685 ... 0.2784314 0.25490198
0.19607842]
[0.59607846 0.37254903 0.29803923 ... 0.16470587 0.22745097
0.20392156]
[0.5921569 0.4509804 0.49803922 ... 0.20784312 0.3764706
0.2352941 ]
...
[0.42352942 0.4627451 0.42352942 ... 0.23921567 0.23137254
0.19999999]
[0.45882353 0.5176471 0.35686275 ... 0.23921567 0.26666668
0.19607842]
[0.41568628 0.44313726 0.34901962 ... 0.21960783 0.23921567
0.21568626]]
[[0.23137254 0.20784312 0.1490196 ... 0.30588236 0.28627452
0.19607842]
[0.61960787 0.3764706 0.26666668 ... 0.16470587 0.24313724
0.21568626]
[0.57254905 0.43137255 0.48235294 ... 0.2235294 0.40392157
0.25882354]
...
[0.4 0.42352942 0.37254903 ... 0.25490198 0.24705881
0.21568626]
[0.43137255 0.4509804 0.29411766 ... 0.25882354 0.28235295
0.20392156]
[0.38431373 0.3529412 0.25490198 ... 0.2352941 0.25490198
0.23137254]]]
[[[0.06274509 0.09019607 0.11372548 ... 0.5803922 0.5176471
0.59607846]
[0.09411764 0.14509803 0.1372549 ... 0.5294118 0.49803922
0.5058824 ]
[0.04705882 0.09411764 0.10196078 ... 0.45882353 0.42352942
0.38431373]
...
[0.15294117 0.12941176 0.1607843 ... 0.85882354 0.8509804
0.80784315]
[0.14509803 0.10588235 0.1607843 ... 0.8666667 0.85882354
0.8 ]
[0.1490196 0.10588235 0.16470587 ... 0.827451 0.8156863
0.7921569 ]]
[[0.06666666 0.12156862 0.17647058 ... 0.59607846 0.5529412
0.6039216 ]
[0.07058823 0.10588235 0.11764705 ... 0.56078434 0.5254902
0.5372549 ]
[0.03921568 0.0745098 0.09803921 ... 0.48235294 0.4392157
0.4117647 ]
...
[0.2117647 0.14509803 0.2784314 ... 0.43137255 0.3529412
0.34117648]
[0.2235294 0.11372548 0.2509804 ... 0.4509804 0.39607844
0.2509804 ]
[0.25490198 0.12156862 0.24705881 ... 0.38039216 0.36078432
0.3254902 ]]
[[0.05490196 0.09803921 0.12549019 ... 0.46666667 0.38039216
0.45490196]
[0.06274509 0.09803921 0.10196078 ... 0.44705883 0.41568628
0.3882353 ]
[0.03921568 0.06666666 0.0862745 ... 0.3764706 0.33333334
0.28235295]
...
[0.12156862 0.14509803 0.16862744 ... 0.15686274 0.0745098
0.09411764]
[0.10588235 0.11372548 0.16862744 ... 0.25882354 0.18431371
0.05490196]
[0.12156862 0.11372548 0.17254901 ... 0.2352941 0.17254901
0.14117646]]]]
Traceback (most recent call last):
File "image_loader.py", line 51, in <module>
main()
File "image_loader.py", line 46, in main
imshow(images)
File "image_loader.py", line 38, in imshow
plt.imshow(np.transpose(npimg, (1, 2, 0, 1)))
File "C:\Users\Aeryes\AppData\Local\Programs\Python\Python36\lib\site-packages\numpy\core\fromnumeric.py", line 598, in transpose
return _wrapfunc(a, 'transpose', axes)
File "C:\Users\Aeryes\AppData\Local\Programs\Python\Python36\lib\site-packages\numpy\core\fromnumeric.py", line 51, in _wrapfunc
return getattr(obj, method)(*args, **kwds)
ValueError: repeated axis in transpose
I tried to print out the arrays to get the dimensions but I do not know what to make of this. It is very confusing.
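Printing `.shape` is much easier to read than printing the values. The error itself comes from the axes tuple: `np.transpose` needs a permutation of the array's axes with each axis used exactly once, and `(1, 2, 0, 1)` lists axis 1 twice. A NumPy-only sketch, with a zero array standing in for `images.numpy()` (a batch of 100 RGB 32x32 images, matching the question's `batch_size` and crop size):

```python
import numpy as np

# stand-in for images.numpy(): [batch, channel, height, width]
npimg = np.zeros((100, 3, 32, 32), dtype=np.float32)
print(npimg.shape)

try:
    np.transpose(npimg, (1, 2, 0, 1))  # axis 1 appears twice -> ValueError
except ValueError as err:
    print("ValueError:", err)

# a valid 4-D permutation: channels last for every image in the batch
batch_hwc = np.transpose(npimg, (0, 2, 3, 1))
print(batch_hwc.shape)

# a single image in the layout matplotlib expects: [height, width, channel]
single_hwc = np.transpose(npimg[0], (1, 2, 0))
print(single_hwc.shape)
```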
Here is my direct question: How do I view the input images before training using the tensors in my DataLoader object?
Note that DataLoader does not convert PIL images into tensors automatically; the transforms.ToTensor() step in the transform pipeline does that.
First of all, the DataLoader outputs a 4-dimensional tensor: [batch, channel, height, width]. Matplotlib and other image processing libraries often require [height, width, channel]. You are right about using transpose, just not in the right way.
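The batch shape described above can be checked without any image files. This sketch, assuming only PyTorch, wraps random tensors in a `TensorDataset` (a stand-in for `ImageFolder`) and shows that the DataLoader's default collation stacks per-sample [channel, height, width] tensors into one [batch, channel, height, width] tensor:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# stand-in data: 10 RGB 32x32 images plus integer labels for 5 classes
images_all = torch.rand(10, 3, 32, 32)
labels_all = torch.randint(0, 5, (10,))
dataset = TensorDataset(images_all, labels_all)

print(dataset[0][0].shape)  # one sample: [3, 32, 32]

loader = DataLoader(dataset, batch_size=4, shuffle=True)
images, labels = next(iter(loader))
print(images.shape)  # one batch: [4, 3, 32, 32]
print(labels.shape)  # [4]
```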
There will be many images in your images batch, so first you need to pick one (or write a for loop to save all of them). This is simply images[i]; typically I use i=0.
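For the "write a for loop to save all of them" option, here is a sketch that assumes the batch has already been converted to a NumPy array shaped [batch, channel, height, width] (random data stands in for it) and uses `matplotlib.pyplot.imsave`. The temporary output directory and the `image_{i}.png` file names are made up for illustration:

```python
import os
import tempfile

import matplotlib
matplotlib.use("Agg")  # render to files only; no window needed
import matplotlib.pyplot as plt
import numpy as np

# stand-in for images.numpy() after unnormalizing: 4 RGB images with values in [0, 1)
batch = np.random.default_rng(0).random((4, 3, 32, 32))

out_dir = tempfile.mkdtemp()  # hypothetical output location
saved = []
for i in range(batch.shape[0]):
    img = np.transpose(batch[i], (1, 2, 0))  # [C, H, W] -> [H, W, C]
    path = os.path.join(out_dir, f"image_{i}.png")
    plt.imsave(path, np.clip(img, 0.0, 1.0))  # float RGB data must lie in [0, 1]
    saved.append(path)

print(len(saved), "images written to", out_dir)
```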
Then your transpose should convert the now [channel, height, width] tensor into a [height, width, channel] one. To do this, use np.transpose(image.numpy(), (1, 2, 0)), very much like yours.
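If you prefer to stay in PyTorch until the last moment, `Tensor.permute` reorders dimensions the same way. This sketch uses a random 3x32x32 tensor as a stand-in for `images[0]` and checks that both routes give the same [height, width, channel] array:

```python
import numpy as np
import torch

image = torch.rand(3, 32, 32)  # stand-in for images[0]: [channel, height, width]

via_numpy = np.transpose(image.numpy(), (1, 2, 0))
via_torch = image.permute(1, 2, 0).numpy()

print(via_numpy.shape)                       # [height, width, channel]
print(np.array_equal(via_numpy, via_torch))  # both routes agree
```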
Putting them together, you should have
plt.imshow(np.transpose(images[0].numpy(), (1, 2, 0)))
Sometimes you need to call .detach() (detach this part from the computational graph) and .cpu() (transfer data from GPU to CPU), depending on the use case. That would be:
plt.imshow(np.transpose(images[0].cpu().detach().numpy(), (1, 2, 0)))
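The reason for those calls: `Tensor.numpy()` refuses tensors that require gradients and tensors stored on a GPU. A sketch of a small helper (the name `to_hwc_numpy` is hypothetical) that applies both steps plus the channel reordering; the `requires_grad=True` tensor stands in for something like a model output, since images coming straight out of a DataLoader do not require gradients:

```python
import numpy as np
import torch

def to_hwc_numpy(image):
    """Convert a [C, H, W] tensor (possibly on a GPU or tracking gradients) to an [H, W, C] array."""
    return np.transpose(image.detach().cpu().numpy(), (1, 2, 0))

image = torch.rand(3, 32, 32, requires_grad=True)

try:
    image.numpy()  # fails because the tensor requires grad
except RuntimeError as err:
    print("RuntimeError:", err)

arr = to_hwc_numpy(image)
print(arr.shape)
```

On a CPU tensor `.cpu()` simply returns the same tensor, so the helper is safe to call everywhere.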
This did the trick for me when I faced the same problem. A PyTorch dataset behaves similarly to a regular list as far as NumPy is concerned, so this works:
train_np = np.array(train_loader.dataset)
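An alternative to `np.array(train_loader.dataset)` that does not rely on NumPy treating the dataset as a sequence is to index the dataset explicitly and stack the image tensors yourself. Each item of an `ImageFolder` (or of the `TensorDataset` used here as a stand-in) is an `(image, label)` pair, so images and labels are collected separately:

```python
import numpy as np
import torch
from torch.utils.data import TensorDataset

# stand-in dataset: 6 RGB 32x32 images with integer labels
dataset = TensorDataset(torch.rand(6, 3, 32, 32), torch.randint(0, 5, (6,)))

# index every item and split the (image, label) pairs
images_np = np.stack([dataset[i][0].numpy() for i in range(len(dataset))])
labels_np = np.array([int(dataset[i][1]) for i in range(len(dataset))])

print(images_np.shape)  # [N, C, H, W]
print(labels_np.shape)  # [N]
```

This loads the whole dataset into memory, and for an `ImageFolder` it runs the transform pipeline on every image, including random steps such as `RandomHorizontalFlip`.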