Logo Questions Linux Laravel Mysql Ubuntu Git Menu
 

Making your own set of MNIST data (identical to MNIST format)

I'm trying to create my own version of the MNIST dataset. I've converted my training and test data into the following files:

test-images-idx3-ubyte.gz
test-labels-idx1-ubyte.gz
train-images-idx3-ubyte.gz
train-labels-idx1-ubyte.gz

(For anyone interested, I did this using JPG-PNG-to-MNIST-NN-Format, which seems to get me close to what I'm aiming for.)

However, this isn't quite the same file type and format as the MNIST data (mnist.pkl.gz). I understand that .pkl means the data has been pickled, but I don't really understand the process of pickling the data — is there a specific order to the pickling? Can someone provide code that I could use to pickle my data?

like image 421
user6916458 Avatar asked Oct 03 '17 23:10

user6916458


1 Answer

import gzip
import os
import struct

import numpy as np
import six
from six.moves.urllib import request

# Location of the original MNIST distribution; nothing else in this script
# references it.
parent = 'http://yann.lecun.com/exdb/mnist'

# Gzipped IDX inputs: images are idx3 files, labels are idx1 files.
# NOTE(review): these are the standard MNIST names ("t10k-*" for the test
# split); rename them to match the files you generated.
train_images, train_labels = (
    'train-images-idx3-ubyte.gz',
    'train-labels-idx1-ubyte.gz',
)
test_images, test_labels = (
    't10k-images-idx3-ubyte.gz',
    't10k-labels-idx1-ubyte.gz',
)

# Examples to read from each split and pixels per flattened image
# (32x32 here; standard MNIST images are 28x28).
num_train, num_test = 17010, 3010
dim = 32 * 32


def _read_idx_header(stream, n_fields, path):
    """Read ``n_fields`` big-endian uint32 header fields from ``stream``."""
    size = 4 * n_fields
    raw = stream.read(size)
    if len(raw) != size:
        raise ValueError('%s: truncated IDX header' % (path,))
    return struct.unpack('>%dI' % n_fields, raw)


def load_mnist(images, labels, num=None):
    """Load image/label pairs from a pair of gzipped IDX files.

    The image file must begin with magic number 2051 followed by the item
    count, row count and column count. The label file must begin with magic
    number 2049 followed by the item count. All header fields are big-endian
    unsigned 32-bit integers. Image size is taken from the header rather
    than assumed, and the payload is read in one bulk call instead of one
    byte at a time.

    Args:
        images: path (or file object) of the gzipped idx3 image file.
        labels: path (or file object) of the gzipped idx1 label file.
        num: number of leading examples to read. ``None`` reads every
            example present in both files.

    Returns:
        ``(data, target)``. ``data`` is a writable ``uint8`` array of shape
        ``(num, rows * cols)`` with one flattened image per row. ``target``
        is a writable ``uint8`` array of shape ``(num,)``.

    Raises:
        ValueError: if a header is short or has the wrong magic number, if
            ``num`` is negative or exceeds the examples available, or if
            the image/label payload is truncated.
    """
    with gzip.open(images, 'rb') as f_images, \
            gzip.open(labels, 'rb') as f_labels:
        img_magic, img_count, rows, cols = _read_idx_header(
            f_images, 4, images)
        lbl_magic, lbl_count = _read_idx_header(f_labels, 2, labels)
        if img_magic != 2051:
            raise ValueError('%s: not an idx3-ubyte image file (magic %d)'
                             % (images, img_magic))
        if lbl_magic != 2049:
            raise ValueError('%s: not an idx1-ubyte label file (magic %d)'
                             % (labels, lbl_magic))

        available = min(img_count, lbl_count)
        if num is None:
            num = available
        elif num < 0 or num > available:
            raise ValueError('requested %d examples, but only %d are '
                             'available' % (num, available))

        pixels = rows * cols
        data = np.frombuffer(f_images.read(num * pixels), dtype=np.uint8)
        target = np.frombuffer(f_labels.read(num), dtype=np.uint8)

    # The headers can promise more data than the files actually contain.
    if data.size != num * pixels or target.size != num:
        raise ValueError('IDX payload is truncated')
    # np.frombuffer returns read-only views; copy so callers can modify them.
    return data.reshape((num, pixels)).copy(), target.copy()


def download_mnist_data():
    """Convert the gzipped IDX train/test files into a single ``mnist.pkl``.

    Despite its name, this function downloads nothing. It reads the local
    files named by ``train_images``/``train_labels`` and
    ``test_images``/``test_labels`` through ``load_mnist``. The training
    examples are stacked above the test examples, and the result is
    pickled to ``mnist.pkl`` as ``{'data': images, 'target': labels}``.
    """
    print('Converting training data...')
    train_x, train_y = load_mnist(train_images, train_labels, num_train)
    print('Done')

    print('Converting test data...')
    test_x, test_y = load_mnist(test_images, test_labels, num_test)
    # One array per field: training rows first, then test rows.
    combined = {
        'data': np.concatenate((train_x, test_x), axis=0),
        'target': np.concatenate((train_y, test_y), axis=0),
    }
    print('Done')

    print('Save output...')
    pickler = six.moves.cPickle
    with open('mnist.pkl', 'wb') as out_file:
        pickler.dump(combined, out_file, pickler.HIGHEST_PROTOCOL)
    print('Done')
    print('Convert completed')


def load_mnist_data():
    """Return the unpickled contents of ``mnist.pkl``, building it if absent.

    If ``mnist.pkl`` does not exist in the working directory, it is created
    first by calling ``download_mnist_data()``. The object stored in the
    file is returned. Presumably this is the ``{'data', 'target'}`` dict
    written by that function; verify if the file may come from elsewhere.
    """
    pickle_path = 'mnist.pkl'
    if not os.path.exists(pickle_path):
        download_mnist_data()
    # NOTE: unpickling can execute arbitrary code; only load a mnist.pkl
    # that this script produced or that you otherwise trust.
    with open(pickle_path, 'rb') as stream:
        return six.moves.cPickle.load(stream)
# Script entry point. The conversion runs whenever this file is executed,
# and also whenever it is imported, because the call below is unguarded.
# NOTE(review): consider wrapping it in `if __name__ == '__main__':` so that
# importing this module just to call load_mnist_data() does not re-convert.
download_mnist_data()
like image 103
Shivam Chawla Avatar answered Nov 14 '22 22:11

Shivam Chawla