 

Plot colour image from a numpy array that has 3 channels

In my Jupyter notebook I am trying to display the images that I am iterating over through Keras. The code I am using is below:

def plotImages(path, num):
    batchGenerator = file_utils.fileBatchGenerator(path+"train/", num)
    imgs, labels = next(batchGenerator)
    fig = plt.figure(figsize=(224, 224))
    plt.gray()
    for i in range(num):
        sub = fig.add_subplot(num, 1, i + 1)
        sub.imshow(imgs[i, 0], interpolation='nearest')

But this only plots a single channel, so my image comes out in grayscale. How do I use all 3 channels to plot a colour image?

asked Jan 25 '17 at 16:01 by Abhik


1 Answer

If you want to display an RGB image, you have to supply all three channels. In your code, imgs[i, 0] selects only the first channel of image i, which is a 2-D array, so matplotlib has no colour information to work with. Instead it maps the values through the current colormap, which is gray because you called plt.gray().

Instead, pass all channels of the image to imshow. When imshow receives an array of shape height x width x 3 (RGB) or height x width x 4 (RGBA), it draws the true colours and ignores the colormap.

sub.imshow(imgs, interpolation='nearest')
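A minimal sketch of that difference, using a small synthetic array as a hypothetical stand-in for one image that is already channels last (height x width x 3, float values in [0, 1]). The 2-D slice is drawn through the gray colormap, while the full 3-channel array is drawn in colour.

```python
import numpy as np
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs without a display
import matplotlib.pyplot as plt

# Hypothetical stand-in for one image: 4 x 4 pixels, channels last, floats in [0, 1].
rgb = np.zeros((4, 4, 3))
rgb[:, :2, 0] = 1.0  # left half red
rgb[:, 2:, 2] = 1.0  # right half blue

fig, (ax_gray, ax_rgb) = plt.subplots(1, 2)
plt.gray()  # sets the default colormap to gray for images created afterwards

# 2-D input (one channel): values are mapped through the colormap.
gray_im = ax_gray.imshow(rgb[:, :, 0], interpolation='nearest')

# 3-D input with 3 channels: drawn as true colour, the colormap is not used.
rgb_im = ax_rgb.imshow(rgb, interpolation='nearest')
```

Only the left subplot looks gray; the right one shows red and blue halves even though plt.gray() was called.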

Update

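Since imgs is actually a batch of shape 2 x 3 x 224 x 224 (images x channels x height x width), you'll want to index into imgs to pick one image and permute its dimensions to 224 x 224 x 3 before displaying it, because imshow expects the channels last.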

im2display = imgs[i].transpose((1, 2, 0))  # (3, 224, 224) -> (224, 224, 3)
sub.imshow(im2display, interpolation='nearest')
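Putting the update together with the loop from the question gives roughly the sketch below. The helper names to_displayable and plot_batch are hypothetical, and the rescaling step is an assumption about the generator: it presumes pixels are either uint8 or floats in the 0-255 range. imshow clips float RGB data to [0, 1] (and integer RGB data to [0, 255]), so floats in 0-255 would otherwise render almost entirely white; mean-subtracted or otherwise normalised data would need its own inverse transform. The sketch also uses a small per-image figsize, since figsize is given in inches and (224, 224) would request an enormous figure.

```python
import numpy as np


def to_displayable(img_chw):
    """Convert one channels-first image (3, H, W) to (H, W, 3) for imshow.

    Assumption: pixel values are uint8, floats already in [0, 1],
    or floats in the 0-255 range (rescaled to [0, 1] here).
    """
    img = np.asarray(img_chw)
    if img.ndim != 3 or img.shape[0] != 3:
        raise ValueError(f"expected shape (3, H, W), got {img.shape}")
    hwc = img.transpose((1, 2, 0))  # (3, H, W) -> (H, W, 3)
    if hwc.dtype != np.uint8 and hwc.max() > 1.0:
        # Assumed 0-255 float data: scale into the [0, 1] range imshow expects.
        hwc = np.clip(hwc / 255.0, 0.0, 1.0)
    return hwc


def plot_batch(imgs):
    """Draw each image of a (N, 3, H, W) batch in its own colour subplot."""
    import matplotlib.pyplot as plt  # imported here so to_displayable works without matplotlib

    num = imgs.shape[0]
    fig = plt.figure(figsize=(4, 4 * num))  # figsize is in inches
    for i in range(num):
        sub = fig.add_subplot(num, 1, i + 1)
        sub.imshow(to_displayable(imgs[i]), interpolation='nearest')
    return fig
```

With the batch from next(batchGenerator), calling plot_batch(imgs) should then draw one colour subplot per image; no plt.gray() call is needed.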
answered Nov 15 '22 at 16:11 by Suever