 

Visualize MNIST dataset using OpenCV or Matplotlib/Pyplot

I have the MNIST dataset and I am trying to visualise it using pyplot. The dataset is in CSV format, where each row is one image of 784 pixels. I want to display each row as a 28 x 28 image using pyplot or OpenCV. I tried this directly:

plt.imshow(X[2:],cmap =plt.cm.gray_r, interpolation = "nearest") 

but it's not working. Any ideas on how I should approach this?
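X[2:] is a slice, so it keeps every row from index 2 onward and stays two-dimensional; imshow then draws that whole (n - 2) x 784 matrix as one wide strip instead of a digit. Indexing a single row and reshaping it to 28 x 28 is what's needed. A minimal sketch, assuming X is a NumPy array of shape (n, 784) holding only pixel columns (no label column); the random X below is a hypothetical stand-in for the real data.

```python
import numpy as np

# Hypothetical stand-in for the asker's X: 5 images, one flattened 784-pixel row each
X = np.random.randint(0, 256, size=(5, 784), dtype=np.uint8)

rows = X[2:]                  # a slice: still 2-D, shape (3, 784)
image = X[2].reshape(28, 28)  # a single row, reshaped into a 28 x 28 image

print(rows.shape)   # (3, 784)
print(image.shape)  # (28, 28)
```

With the real data, the original call would then become plt.imshow(X[2].reshape(28, 28), cmap=plt.cm.gray_r, interpolation="nearest"). If X still has the label in its first column, use X[2, 1:].reshape(28, 28) instead.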

asked May 14 '16 at 15:05 by decipher




1 Answer

Assuming you have a CSV file in the following format, which is one of the formats the MNIST dataset is available in:

label, pixel_1_1, pixel_1_2, ...
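If you want all rows at once rather than one at a time, NumPy can load the whole file into arrays in a single call. A sketch, assuming every row is a label followed by 784 integer pixel values; some CSV copies of MNIST start with a header line, which the hypothetical skip_header flag handles. load_mnist_csv and the in-memory fake_csv are illustrative names, not part of any library.

```python
import io
import numpy as np

def load_mnist_csv(source, skip_header=False):
    """Load an MNIST-style CSV (label, then 784 pixel values per row).

    `source` can be a path or a file-like object.
    Returns (labels, images) with shapes (n,) and (n, 28, 28).
    """
    data = np.loadtxt(source, delimiter=",", dtype=np.uint8,
                      skiprows=1 if skip_header else 0, ndmin=2)
    labels = data[:, 0]                       # first column: the digit label
    images = data[:, 1:].reshape(-1, 28, 28)  # remaining 784 columns: the pixels
    return labels, images

# Hypothetical two-row CSV built in memory so the sketch runs without the real file
rows = []
for label in (7, 2):
    pixels = [str((label * 10 + i) % 256) for i in range(784)]
    rows.append(",".join([str(label)] + pixels))
fake_csv = io.StringIO("\n".join(rows))

labels, images = load_mnist_csv(fake_csv)
print(labels)        # [7 2]
print(images.shape)  # (2, 28, 28)
```

With the mnist_test_10.csv file used in the code below, this would be labels, images = load_mnist_csv('mnist_test_10.csv'), after which any images[i] can be passed straight to plt.imshow or cv2.imshow.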

Here's how you can visualize it in Python, first with Matplotlib and then with OpenCV.

Matplotlib / Pyplot

import numpy as np
import csv
import matplotlib.pyplot as plt

with open('mnist_test_10.csv', 'r') as csv_file:
    for data in csv.reader(csv_file):
        # The first column is the label
        label = data[0]

        # The rest of the columns are pixels
        pixels = data[1:]

        # Make those columns into an array of 8-bit pixels
        # This array will be 1D with length 784
        # The pixel intensity values are integers from 0 to 255
        pixels = np.array(pixels, dtype='uint8')

        # Reshape the array into 28 x 28 array (2-dimensional array)
        pixels = pixels.reshape((28, 28))

        # Plot
        plt.title('Label is {label}'.format(label=label))
        plt.imshow(pixels, cmap='gray')
        plt.show()

        break  # Stop the loop after the first image; I just want to see one

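The loop above stops after one image. To compare several digits at once, one option is a grid of subplots, each showing one image with its label as the title. A minimal sketch, assuming images is an array of shape (n, 28, 28) and labels an array of length n (for example, built by reshaping the rows read above); show_digit_grid is a hypothetical helper and the random arrays are placeholders. It uses cmap="gray_r", as in the question, which draws dark digits on a white background; cmap="gray" gives the white-on-black look of the example above.

```python
import numpy as np
import matplotlib.pyplot as plt

def show_digit_grid(images, labels, n_rows=2, n_cols=5):
    """Plot the first n_rows * n_cols digits in a grid, each titled with its label."""
    # squeeze=False keeps `axes` 2-D even for a 1 x 1 grid
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(2 * n_cols, 2 * n_rows),
                             squeeze=False)
    for ax in axes.flat:
        ax.axis("off")  # hide ticks, and blank out cells left without an image
    for ax, image, label in zip(axes.flat, images, labels):
        ax.imshow(image, cmap="gray_r", interpolation="nearest")
        ax.set_title(f"Label is {label}")
    fig.tight_layout()
    return fig

# Hypothetical random data standing in for the real MNIST arrays
images = np.random.randint(0, 256, size=(10, 28, 28), dtype=np.uint8)
labels = np.arange(10)

fig = show_digit_grid(images, labels)
plt.show()
```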

OpenCV

You can take the pixels NumPy array from above, which has dtype='uint8' (unsigned 8-bit integer) and shape 28 x 28, and display it with cv2.imshow():

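import cv2

# pixels and label come from the Matplotlib example above
title = 'Label is {label}'.format(label=label)

cv2.imshow(title, pixels)
cv2.waitKey(0)  # block until any key is pressed
cv2.destroyAllWindows()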
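A 28 x 28 image opens in a very small cv2.imshow window. One way to enlarge it without blurring is nearest-neighbour upscaling, i.e. repeating each pixel in a square block. A NumPy-only sketch of that idea; in OpenCV itself, cv2.resize(pixels, (280, 280), interpolation=cv2.INTER_NEAREST) should give an equivalent result. upscale_nearest and the factor of 10 are illustrative choices, not part of the original answer.

```python
import numpy as np

def upscale_nearest(image, factor=10):
    """Enlarge a 2-D image by repeating each pixel in a factor x factor block."""
    return np.repeat(np.repeat(image, factor, axis=0), factor, axis=1)

# Hypothetical 28 x 28 uint8 digit standing in for the `pixels` array above
pixels = np.zeros((28, 28), dtype=np.uint8)
pixels[8:20, 12:16] = 255  # a crude vertical stroke

big = upscale_nearest(pixels)
print(big.shape, big.dtype)  # (280, 280) uint8
```

Then pass big instead of pixels to cv2.imshow(title, big). The dtype stays uint8, so 0 is still shown as black and 255 as white.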
answered Sep 30 '22 at 13:09 by bakkal