
How do I display a single image in PyTorch?

How do I display a PyTorch Tensor of shape (3, 224, 224) representing a 224x224 RGB image? Using plt.imshow(image) gives the error:

TypeError: Invalid dimensions for image data

Tom Hale asked Dec 05 '18 00:12



3 Answers

A complete example given an image pathname img_path:

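import matplotlib.pyplot as plt
from PIL import Image
from torchvision import transforms
image = Image.open(img_path)
# PIL image -> (C, H, W) tensor -> PIL image, then display it
plt.imshow(transforms.ToPILImage()(transforms.ToTensor()(image)), interpolation="bicubic")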

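Note that transforms.ToTensor and transforms.ToPILImage are classes: the first pair of parentheses creates a transform object and the second calls it on the image, hence the double bracketing.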

Tom Hale answered Oct 16 '22 21:10


Matplotlib can plot a PyTorch tensor directly, without converting it to a NumPy array first. However, PyTorch image tensors are channel-first, (C, H, W), so to plot one with Matplotlib you need to move the channel axis to the end, giving (H, W, C):

Code:

from scipy.datasets import face  # scipy.misc.face on older SciPy; scipy.datasets needs pooch
import matplotlib.pyplot as plt
import torch

np_image = face()
print(type(np_image), np_image.shape)
tensor_image = torch.from_numpy(np_image)
print(type(tensor_image), tensor_image.shape)
# Convert to channel-first (C, H, W). Use permute, not view: view keeps the
# memory order and would scramble the pixels rather than move the axes.
tensor_image = tensor_image.permute(2, 0, 1)
print(type(tensor_image), tensor_image.shape)

# If you try to plot an image with shape (C, H, W),
# you will get a TypeError:
# plt.imshow(tensor_image)

# So permute it back to (H, W, C) before plotting:
tensor_image = tensor_image.permute(1, 2, 0)
print(type(tensor_image), tensor_image.shape)

plt.imshow(tensor_image)
plt.show()

Output:

<class 'numpy.ndarray'> (768, 1024, 3)
<class 'torch.Tensor'> torch.Size([768, 1024, 3])
<class 'torch.Tensor'> torch.Size([3, 768, 1024])
<class 'torch.Tensor'> torch.Size([768, 1024, 3])
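
One caveat worth illustrating: view and permute are not interchangeable when switching between (H, W, C) and (C, H, W). The sketch below uses a made-up 2x2 image (the pixel values are invented so each channel is easy to recognise) and suggests that permute really moves the axes, while view only relabels the shape over the same memory order.

```python
import torch

# A tiny 2x2 "image" in (H, W, C) layout; values are invented for illustration.
# Red values are 0..11, green are 100..111, blue are 200..211.
hwc = torch.tensor([[[0, 100, 200], [1, 101, 201]],
                    [[10, 110, 210], [11, 111, 211]]])

chw_permute = hwc.permute(2, 0, 1)  # axes reordered: channel axis moved to the front
chw_view = hwc.view(3, 2, 2)        # same flat memory, only the shape is relabelled

red_channel = hwc[:, :, 0]
permute_matches = torch.equal(chw_permute[0], red_channel)
view_matches = torch.equal(chw_view[0], red_channel)

print(permute_matches, view_matches)
print(chw_view[0])  # mixes red, green and blue values from neighbouring pixels
```

Both results have shape (3, 2, 2), so the shapes alone do not reveal the problem; only the permuted tensor has the red channel in position 0.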
trsvchn answered Oct 16 '22 19:10


PyTorch modules that process image data expect tensors in the format C × H × W [1], whereas Pillow and Matplotlib expect image arrays in the format H × W × C [2].

You can easily convert tensors to/from this format with a TorchVision transform:

import torchvision.transforms.functional as F

F.to_pil_image(image_tensor)

Or by directly permuting the axes:

image_tensor.permute(1,2,0)
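
Putting the permute route together for the (3, 224, 224) case from the question might look like the sketch below. The random tensor is a placeholder for a real image, and the "Agg" backend line is an assumption for headless runs; drop it and call plt.show() in an interactive session.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: no display available; remove for interactive use
import matplotlib.pyplot as plt
import torch

# Placeholder for a real RGB image tensor in (C, H, W) layout, values in [0, 1).
image_tensor = torch.rand(3, 224, 224)

# Move the channel axis last so Matplotlib sees (H, W, C).
hwc = image_tensor.permute(1, 2, 0)

fig, ax = plt.subplots()
shown = ax.imshow(hwc.numpy())  # .numpy() is optional here but makes the hand-off explicit
ax.axis("off")
# plt.show()  # uncomment when running with an interactive backend
```

If the tensor lives on a GPU or tracks gradients, it would first need something like .detach().cpu() before the conversion.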

  1. PyTorch modules dealing with image data require tensors to be laid out as C × H × W : channels, height, and width, respectively.

  2. Note how we have to use permute to change the order of the axes from C × H × W to H × W × C to match what Matplotlib expects.

    • Deep Learning with PyTorch
iacob answered Oct 16 '22 21:10