How do I display a PyTorch Tensor of shape (3, 224, 224) representing a 224x224 RGB image? Using plt.imshow(image) gives the error:
TypeError: Invalid dimensions for image data
To display the image, first convert the image tensor to a PIL image, then display it.
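A minimal sketch of that approach (image here is a random stand-in for your (3, 224, 224) tensor with values in [0, 1]):
import matplotlib.pyplot as plt
import torch
from torchvision import transforms

image = torch.rand(3, 224, 224)  # stand-in for your (3, 224, 224) tensor
pil_image = transforms.ToPILImage()(image)  # (C, H, W) tensor -> PIL image
plt.imshow(pil_image)
plt.show()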
A complete example, given an image pathname img_path:
import matplotlib.pyplot as plt
from PIL import Image
from torchvision import transforms

image = Image.open(img_path)
plt.imshow(transforms.ToPILImage()(transforms.ToTensor()(image)), interpolation="bicubic")
plt.show()
Note that transforms.ToTensor and transforms.ToPILImage are classes; calling one returns a callable transform object, which is then applied to the image, hence the doubled parentheses.
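The same call unpacked into steps, with illustrative variable names:
to_tensor = transforms.ToTensor()   # instantiate the transform
to_pil = transforms.ToPILImage()    # instantiate the inverse transform
tensor = to_tensor(image)           # PIL image -> (C, H, W) float tensor in [0, 1]
plt.imshow(to_pil(tensor), interpolation="bicubic")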
As you can see, matplotlib works fine even without conversion to a numpy array. But PyTorch image tensors are channel-first, so to use them with matplotlib you need to permute the axes:
Code:
from scipy.misc import face  # on SciPy >= 1.10 use: from scipy.datasets import face
import matplotlib.pyplot as plt
import torch
np_image = face()
print(type(np_image), np_image.shape)
tensor_image = torch.from_numpy(np_image)
print(type(tensor_image), tensor_image.shape)
# permute to channel-first (C, H, W):
tensor_image = tensor_image.permute(2, 0, 1)
print(type(tensor_image), tensor_image.shape)
# If you try to plot an image with shape (C, H, W)
# you will get a TypeError:
# plt.imshow(tensor_image)
# So we permute it back to (H, W, C):
tensor_image = tensor_image.permute(1, 2, 0)
print(type(tensor_image), tensor_image.shape)
plt.imshow(tensor_image)
plt.show()
Output:
<class 'numpy.ndarray'> (768, 1024, 3)
<class 'torch.Tensor'> torch.Size([768, 1024, 3])
<class 'torch.Tensor'> torch.Size([3, 768, 1024])
<class 'torch.Tensor'> torch.Size([768, 1024, 3])
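Why permute rather than view here: view only reinterprets the same memory in a new shape, so it changes the shape but scrambles the pixel layout, while permute actually swaps the axes. A small sketch of the difference:
import torch

t = torch.arange(24).reshape(2, 3, 4)  # pretend (H, W, C)
a = t.permute(2, 0, 1)                 # channel-first: a[c, h, w] == t[h, w, c]
b = t.view(4, 2, 3)                    # same shape, memory merely reinterpreted
print(torch.equal(a, b))               # False: the layouts differ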
PyTorch modules processing image data expect tensors in the format C × H × W, whereas Pillow and Matplotlib expect image arrays in the format H × W × C.
You can easily convert tensors to/from this format with a TorchVision transform:
import torchvision.transforms.functional as F
F.to_pil_image(image_tensor)
Or by directly permuting the axes:
image_tensor.permute(1,2,0)
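Both options end-to-end, as a minimal sketch (image_tensor here is a random stand-in):
import matplotlib.pyplot as plt
import torch
import torchvision.transforms.functional as F

image_tensor = torch.rand(3, 224, 224)  # stand-in (C, H, W) tensor in [0, 1]

# Option 1: convert to a PIL image
plt.imshow(F.to_pil_image(image_tensor))
plt.show()

# Option 2: permute the axes to (H, W, C)
plt.imshow(image_tensor.permute(1, 2, 0))
plt.show()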
PyTorch modules dealing with image data require tensors to be laid out as C × H × W: channels, height, and width, respectively. Note how we have to use permute to change the order of the axes from C × H × W to H × W × C to match what Matplotlib expects.
- Deep Learning with PyTorch