
Reading MNIST Database

I am currently exploring neural networks and machine learning, and I have implemented a basic neural network in C#. Now I want to test my backpropagation training algorithm with the MNIST database. However, I am having serious trouble reading the files correctly.

Spoiler: the code is currently very poorly optimised for performance. My aim for now is to grasp the subject and get a structured view of how things work before I start replacing my data structures with faster ones.

To train the network, I want to feed it a custom TrainingSet data structure:

[Serializable]
public class TrainingSet
{
    public Dictionary<List<double>, List<double>> data = new Dictionary<List<double>, List<double>>();
}

The keys will be my input data (784 pixels per entry, i.e. per image, representing greyscale values in the range 0 to 1). The values will be my output data (10 entries representing the digits 0-9, all set to 0 except the expected one, which is set to 1).
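The expected-output vector described above is a one-hot encoding of the label. A minimal sketch of building it with the same List<double> type that TrainingSet uses; the `Targets` class and `OneHot` method are hypothetical names, not part of the original code:

```csharp
using System;
using System.Collections.Generic;

// Hypothetical helper (not from the question or the answers)
public static class Targets
{
    // Returns `classes` zeros with a single 1.0 at the index of the labelled digit
    public static List<double> OneHot(int label, int classes = 10)
    {
        if (label < 0 || label >= classes)
            throw new ArgumentOutOfRangeException(nameof(label));

        // new double[classes] is all zeros; copy it into a List<double>
        var target = new List<double>(new double[classes]);
        target[label] = 1.0;
        return target;
    }
}
```

For example, `Targets.OneHot(7)` yields ten zeros with a 1 at index 7.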

Now I want to read the MNIST database according to this contract. My current (second) attempt is inspired by this blog post: https://jamesmccaffrey.wordpress.com/2013/11/23/reading-the-mnist-data-set-with-c/

Sadly, it still produces the same nonsense as my first attempt, scattering the pixels in a strange pattern: [pattern screenshot]
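For reference, the MNIST files use the IDX format described on the MNIST page: every header field is a 32-bit big-endian integer. The image file starts with the magic number 2051, followed by the image count, the number of rows and the number of columns; the label file starts with 2049, followed by the label count. The pixels then follow as unsigned bytes, stored row by row. A minimal sketch of reading and validating both headers; `IdxHeader` and its methods are hypothetical names, and the magic-number check is an extra safeguard (not in the original code) that catches a wrong file or a wrong byte order early:

```csharp
using System.IO;

// Hypothetical helper (not from the question or the answers)
public static class IdxHeader
{
    // Magic numbers defined by the MNIST IDX file format
    public const int ImagesMagic = 2051; // 0x00000803: unsigned bytes, 3 dimensions
    public const int LabelsMagic = 2049; // 0x00000801: unsigned bytes, 1 dimension

    // Reads one big-endian 32-bit integer, independent of the machine's byte order
    private static int ReadInt32BE(BinaryReader br)
    {
        byte[] b = br.ReadBytes(4);
        if (b.Length != 4)
            throw new EndOfStreamException("Truncated IDX header.");
        return (b[0] << 24) | (b[1] << 16) | (b[2] << 8) | b[3];
    }

    // Image file header: magic, image count, rows, columns (in that order)
    public static (int Count, int Rows, int Cols) ReadImageHeader(BinaryReader br)
    {
        int magic = ReadInt32BE(br);
        if (magic != ImagesMagic)
            throw new InvalidDataException($"Expected image magic {ImagesMagic}, got {magic}.");
        int count = ReadInt32BE(br);
        int rows = ReadInt32BE(br);
        int cols = ReadInt32BE(br);
        return (count, rows, cols);
    }

    // Label file header: magic, label count
    public static int ReadLabelHeader(BinaryReader br)
    {
        int magic = ReadInt32BE(br);
        if (magic != LabelsMagic)
            throw new InvalidDataException($"Expected label magic {LabelsMagic}, got {magic}.");
        return ReadInt32BE(br);
    }
}
```

Comparing the two counts (the image count should equal the label count) is a cheap way to confirm that both files belong together.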

My current reading algorithm:

    public static TrainingSet GenerateTrainingSet(FileInfo imagesFile, FileInfo labelsFile)
    {
        MnistImageView imageView = new MnistImageView();
        imageView.Show();

        TrainingSet trainingSet = new TrainingSet();

        List<List<double>> labels = new List<List<double>>();
        List<List<double>> images = new List<List<double>>();

        using (BinaryReader brLabels = new BinaryReader(new FileStream(labelsFile.FullName, FileMode.Open)))
        {
            using (BinaryReader brImages = new BinaryReader(new FileStream(imagesFile.FullName, FileMode.Open)))
            {
                int magic1 = brImages.ReadBigInt32(); //Reading as BigEndian
                int numImages = brImages.ReadBigInt32();
                int numRows = brImages.ReadBigInt32();
                int numCols = brImages.ReadBigInt32();

                int magic2 = brLabels.ReadBigInt32();
                int numLabels = brLabels.ReadBigInt32();

                byte[] pixels = new byte[numRows * numCols];

                // each image
                for (int imageCounter = 0; imageCounter < numImages; imageCounter++)
                {
                    List<double> imageInput = new List<double>();
                    List<double> exspectedOutput = new List<double>();

                    for (int i = 0; i < 10; i++) //generate empty expected output
                        exspectedOutput.Add(0);

                    //read image
                    for (int p = 0; p < pixels.Length; p++)
                    {
                        byte b = brImages.ReadByte();
                        pixels[p] = b;

                        imageInput.Add(b / 255.0f); //scale in 0 to 1 range
                    }

                    //read label
                    byte lbl = brLabels.ReadByte();
                    exspectedOutput[lbl] = 1; //mark the expected digit

                    labels.Add(exspectedOutput);
                    images.Add(imageInput);

                    //Debug view showing parsed image.......................
                    Bitmap image = new Bitmap(numCols, numRows);

                    for (int y = 0; y < numRows; y++)
                    {
                        for (int x = 0; x < numCols; x++)
                        {
                            image.SetPixel(x, y, Color.FromArgb(255 - pixels[x * y], 255 - pixels[x * y], 255 - pixels[x * y])); //invert colors to have 0,0,0 be white as specified by mnist
                        }
                    }

                    imageView.SetImage(image);
                    imageView.Refresh();
                    //.......................................................
                }

                brImages.Close();
                brLabels.Close();
            }
        }

        for (int i = 0; i < images.Count; i++)
        {
            trainingSet.data.Add(images[i], labels[i]);
        }

        return trainingSet;
    }

All images produce a pattern like the one shown above. It is never exactly the same pattern, but the pixels always seem to be "pulled" towards the bottom-right corner.
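The pattern described above is consistent with the debug loop reading `pixels[x * y]`: the product is 0 for the whole first row and first column, different positions collide on the same index (for example (2, 3) and (3, 2)), and for a 28 x 28 image it never exceeds 27 * 27 = 729, so the last bytes of each image are never shown. MNIST stores pixels row by row, so the pixel at column x of row y sits at index y * numCols + x. A minimal sketch of the corrected lookup, using a plain byte[,] grid instead of System.Drawing.Bitmap so it runs anywhere; `PixelLayout` and its methods are hypothetical names:

```csharp
using System;

// Hypothetical helper (not part of the question's code) showing how
// MNIST's row-by-row pixel buffer maps to (x, y) coordinates
public static class PixelLayout
{
    // Index of the pixel at column x of row y in a row-major buffer with `cols` columns
    public static int RowMajorIndex(int x, int y, int cols) => y * cols + x;

    // Copies a flat image buffer into a [rows, cols] grid, inverting each value
    // so the MNIST background (0) becomes 255 (white) for display
    public static byte[,] ToInvertedGrid(byte[] pixels, int rows, int cols)
    {
        if (pixels.Length < rows * cols)
            throw new ArgumentException("Buffer is smaller than rows * cols.", nameof(pixels));

        var grid = new byte[rows, cols];
        for (int y = 0; y < rows; y++)
            for (int x = 0; x < cols; x++)
                grid[y, x] = (byte)(255 - pixels[RowMajorIndex(x, y, cols)]);
        return grid;
    }
}
```

In the question's debug loop, the equivalent change would be to compute `int v = 255 - pixels[y * numCols + x];` and pass `Color.FromArgb(v, v, v)` to `SetPixel`.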

Robin B asked Mar 21 '18 13:03

2 Answers

That is how I did it:

using System.Collections.Generic;
using System.IO;

public static class MnistReader
{
    private const string TrainImages = "mnist/train-images.idx3-ubyte";
    private const string TrainLabels = "mnist/train-labels.idx1-ubyte";
    private const string TestImages = "mnist/t10k-images.idx3-ubyte";
    private const string TestLabels = "mnist/t10k-labels.idx1-ubyte";

    public static IEnumerable<Image> ReadTrainingData()
    {
        foreach (var item in Read(TrainImages, TrainLabels))
        {
            yield return item;
        }
    }

    public static IEnumerable<Image> ReadTestData()
    {
        foreach (var item in Read(TestImages, TestLabels))
        {
            yield return item;
        }
    }

    private static IEnumerable<Image> Read(string imagesPath, string labelsPath)
    {
        // 'using' closes both files once enumeration finishes (or the caller stops early)
        using (BinaryReader labels = new BinaryReader(new FileStream(labelsPath, FileMode.Open)))
        using (BinaryReader images = new BinaryReader(new FileStream(imagesPath, FileMode.Open)))
        {
            // IDX header fields are big-endian 32-bit integers
            int magicNumber = images.ReadBigInt32();
            int numberOfImages = images.ReadBigInt32();
            int height = images.ReadBigInt32(); // number of rows comes first
            int width = images.ReadBigInt32();  // then number of columns

            int magicLabel = labels.ReadBigInt32();
            int numberOfLabels = labels.ReadBigInt32();

            for (int i = 0; i < numberOfImages; i++)
            {
                var bytes = images.ReadBytes(width * height);
                var arr = new byte[height, width];

                // pixels are stored row by row: row j, column k is at j * width + k
                arr.ForEach((j, k) => arr[j, k] = bytes[j * width + k]);

                yield return new Image()
                {
                    Data = arr,
                    Label = labels.ReadByte()
                };
            }
        }
    }
}

Image class:

public class Image
{
    public byte Label { get; set; }
    public byte[,] Data { get; set; }
}

Some extension methods:

using System;
using System.IO;

public static class Extensions
{
    // Reads a 32-bit integer stored in big-endian order (as in the MNIST headers)
    public static int ReadBigInt32(this BinaryReader br)
    {
        var bytes = br.ReadBytes(sizeof(Int32));
        if (BitConverter.IsLittleEndian) Array.Reverse(bytes);
        return BitConverter.ToInt32(bytes, 0);
    }

    // Calls action(row, column) for every cell of a 2-D array
    public static void ForEach<T>(this T[,] source, Action<int, int> action)
    {
        for (int row = 0; row < source.GetLength(0); row++)
        {
            for (int col = 0; col < source.GetLength(1); col++)
            {
                action(row, col);
            }
        }
    }
}

Usage:

foreach (var image in MnistReader.ReadTrainingData())
{
    //use image here     
}

or

foreach (var image in MnistReader.ReadTestData())
{
    //use image here     
}
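To feed these images into the question's TrainingSet, each byte[,] can be flattened back into 784 values in the 0 to 1 range. A minimal sketch, assuming the Image.Data layout from this answer (first dimension rows, second dimension columns); `ImageConversion.ToInputVector` is a hypothetical helper, not part of either post:

```csharp
using System.Collections.Generic;

// Hypothetical helper (not from the question or the answers)
public static class ImageConversion
{
    // Flattens a [rows, cols] pixel grid row by row and scales each byte to 0..1,
    // matching the 784-input layout the question uses for TrainingSet keys
    public static List<double> ToInputVector(byte[,] data)
    {
        int rows = data.GetLength(0), cols = data.GetLength(1);
        var input = new List<double>(rows * cols);
        for (int r = 0; r < rows; r++)
            for (int c = 0; c < cols; c++)
                input.Add(data[r, c] / 255.0);
        return input;
    }
}
```

Inside the foreach above, `ImageConversion.ToInputVector(image.Data)` would then serve as the key, with a one-hot list built from `image.Label` as the value.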
koryakinp answered Oct 25 '22 14:10


Why not use a NuGet package?

  • MNIST.IO: just a data reader (disclaimer: my package)
  • Accord.DataSets: contains classes to download and parse machine learning datasets such as MNIST, News20 and Iris. This package is part of the Accord.NET Framework.
Guy Langston answered Oct 25 '22 12:10