Reading an IDX file type in Java

I have built an image classifier in Java that I would like to test against the images provided here: http://yann.lecun.com/exdb/mnist/

Unfortunately, if you download train-images-idx3-ubyte.gz or any of the other 3 files and decompress them, the resulting files have the type .idx3-ubyte (images) or .idx1-ubyte (labels).

First question: can anyone give me instructions on how to convert these files into bitmap (.bmp) files?

Second question: or, failing that, how can I read these files in general?

Information about the IDX file format: the IDX file format is a simple format for vectors and multidimensional matrices of various numerical types. The basic format is:

magic number 
size in dimension 0 
size in dimension 1 
size in dimension 2 
..... 
size in dimension N 
data

The magic number is an integer (MSB first). The first 2 bytes are always 0.

The third byte codes the type of the data:

0x08: unsigned byte 
0x09: signed byte 
0x0B: short (2 bytes) 
0x0C: int (4 bytes) 
0x0D: float (4 bytes) 
0x0E: double (8 bytes)

The fourth byte codes the number of dimensions of the vector/matrix: 1 for vectors, 2 for matrices, and so on.

The sizes in each dimension are 4-byte integers (MSB first, i.e. big-endian, like in most non-Intel processors).

The data is stored like in a C array, i.e. the index in the last dimension changes the fastest.
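
As a rough illustration of the layout described above, the header can be parsed with a `DataInputStream`, whose `readInt()` is big-endian (MSB first). This is only a sketch: the class name `IdxHeaderSketch`, the nested `IdxHeader` type, and the method names are made up for this example and are not part of any MNIST tooling.

```java
import java.io.ByteArrayInputStream;
import java.io.DataInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.util.Arrays;

/** Sketch of a generic IDX header parser (all names here are illustrative). */
final class IdxHeaderSketch {

    /** Parsed header fields: the element type code and the size of each dimension. */
    static final class IdxHeader {
        final int typeCode;
        final int[] dimensions;

        IdxHeader(int typeCode, int[] dimensions) {
            this.typeCode = typeCode;
            this.dimensions = dimensions;
        }
    }

    /** Reads the magic number and dimension sizes; the stream is left at the first data byte. */
    static IdxHeader readHeader(InputStream in) throws IOException {
        DataInputStream data = new DataInputStream(in);
        int magic = data.readInt();
        if ((magic >>> 16) != 0) {
            throw new IOException("Not an IDX file: the first two bytes are not zero");
        }
        int typeCode = (magic >>> 8) & 0xFF; // third byte: element type
        int numDimensions = magic & 0xFF;    // fourth byte: number of dimensions
        int[] dimensions = new int[numDimensions];
        for (int d = 0; d < numDimensions; d++) {
            dimensions[d] = data.readInt();
        }
        return new IdxHeader(typeCode, dimensions);
    }

    /** Returns the size in bytes of one element for the given type code. */
    static int elementSize(int typeCode) {
        switch (typeCode) {
            case 0x08: // unsigned byte
            case 0x09: // signed byte
                return 1;
            case 0x0B: // short
                return 2;
            case 0x0C: // int
            case 0x0D: // float
                return 4;
            case 0x0E: // double
                return 8;
            default:
                throw new IllegalArgumentException("Unknown IDX type code: " + typeCode);
        }
    }

    public static void main(String[] args) throws IOException {
        // Hand-written header resembling an image file: magic 0x00000803, then 2 x 28 x 28.
        byte[] header = {
            0, 0, 0x08, 0x03,
            0, 0, 0, 2,
            0, 0, 0, 28,
            0, 0, 0, 28
        };
        IdxHeader h = readHeader(new ByteArrayInputStream(header));
        System.out.println("type=0x" + Integer.toHexString(h.typeCode)
            + " dims=" + Arrays.toString(h.dimensions)
            + " elementSize=" + elementSize(h.typeCode));
    }
}
```

Because the last dimension changes fastest, the data that follows the header of a 3-dimensional file can be read image by image, row by row.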

Asked by Wang-Zhao-Liu Q on Jun 24 '13 at 15:06


2 Answers

Pretty straightforward; as WPrecht said: "The URL describes the format you have to decode". This is my ImageSet exporter for the IDX files. It is not very clean, but it does what it has to do.

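import java.awt.image.BufferedImage;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.IOException;

import javax.imageio.ImageIO;

public class IdxReader {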

    public static void main(String[] args) {
        // Input streams for the decompressed image and label files
        FileInputStream inImage = null;
        FileInputStream inLabel = null;

        String inputImagePath = "CBIR_Project/imagesRaw/MNIST/train-images-idx3-ubyte";
        String inputLabelPath = "CBIR_Project/imagesRaw/MNIST/train-labels-idx1-ubyte";

        String outputPath = "CBIR_Project/images/MNIST_Database_ARGB/";

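        // Counts how many images have been written for each digit (0-9);
        // used to number the output files per label
        int[] hashMap = new int[10];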

        try {
            inImage = new FileInputStream(inputImagePath);
            inLabel = new FileInputStream(inputLabelPath);

            // IDX header fields are big-endian: assemble each 32-bit value from four bytes, MSB first
            int magicNumberImages = (inImage.read() << 24) | (inImage.read() << 16) | (inImage.read() << 8) | (inImage.read());
            int numberOfImages = (inImage.read() << 24) | (inImage.read() << 16) | (inImage.read() << 8) | (inImage.read());
            int numberOfRows  = (inImage.read() << 24) | (inImage.read() << 16) | (inImage.read() << 8) | (inImage.read());
            int numberOfColumns = (inImage.read() << 24) | (inImage.read() << 16) | (inImage.read() << 8) | (inImage.read());

            int magicNumberLabels = (inLabel.read() << 24) | (inLabel.read() << 16) | (inLabel.read() << 8) | (inLabel.read());
            int numberOfLabels = (inLabel.read() << 24) | (inLabel.read() << 16) | (inLabel.read() << 8) | (inLabel.read());

            BufferedImage image = new BufferedImage(numberOfColumns, numberOfRows, BufferedImage.TYPE_INT_ARGB);
            int numberOfPixels = numberOfRows * numberOfColumns;
            int[] imgPixels = new int[numberOfPixels];

            for(int i = 0; i < numberOfImages; i++) {

                if(i % 100 == 0) {System.out.println("Number of images extracted: " + i);}

                for(int p = 0; p < numberOfPixels; p++) {
                    // MNIST stores 0 for background and 255 for foreground; invert so digits are dark on white
                    int gray = 255 - inImage.read();
                    imgPixels[p] = 0xFF000000 | (gray<<16) | (gray<<8) | gray;
                }

                image.setRGB(0, 0, numberOfColumns, numberOfRows, imgPixels, 0, numberOfColumns);

                int label = inLabel.read();

                hashMap[label]++;
                File outputfile = new File(outputPath + label + "_0" + hashMap[label] + ".png");

                ImageIO.write(image, "png", outputfile);
            }

        } catch (FileNotFoundException e) {
            e.printStackTrace();
        } catch (IOException e) {
            e.printStackTrace();
        } finally {
            if (inImage != null) {
                try {
                    inImage.close();
                } catch (IOException e) {
                    e.printStackTrace();
                }
            }
            if (inLabel != null) {
                try {
                    inLabel.close();
                } catch (IOException e) {
                    e.printStackTrace();
                }
            }
        }
    }

}
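
The exporter above writes PNG files. Since the question asks for .bmp, here is a hedged sketch of a BMP variant. It uses `TYPE_BYTE_GRAY` instead of `TYPE_INT_ARGB` because BMP has no alpha channel and, as far as I know, the standard `ImageIO` BMP writer does not accept ARGB images (`ImageIO.write` then returns `false`). The class `BmpExportSketch` and its method names are invented for this example.

```java
import java.awt.image.BufferedImage;
import java.awt.image.WritableRaster;
import java.io.File;
import java.io.IOException;

import javax.imageio.ImageIO;

/** Sketch: turn one image worth of IDX pixel bytes into a grayscale BMP file. */
final class BmpExportSketch {

    /**
     * Builds a grayscale image from row-major pixel bytes.
     * If invert is true, each value v becomes 255 - v (dark digits on white).
     */
    static BufferedImage toGrayImage(byte[] pixels, int rows, int cols, boolean invert) {
        if (pixels.length != rows * cols) {
            throw new IllegalArgumentException(
                "Expected " + (rows * cols) + " pixels but got " + pixels.length);
        }
        BufferedImage image = new BufferedImage(cols, rows, BufferedImage.TYPE_BYTE_GRAY);
        WritableRaster raster = image.getRaster();
        for (int y = 0; y < rows; y++) {
            for (int x = 0; x < cols; x++) {
                // Java bytes are signed; mask to recover the unsigned value in [0, 255]
                int value = pixels[y * cols + x] & 0xFF;
                raster.setSample(x, y, 0, invert ? 255 - value : value);
            }
        }
        return image;
    }

    /** Writes the image as BMP and fails loudly if no writer accepts it. */
    static void writeBmp(BufferedImage image, File file) throws IOException {
        // ImageIO.write returns false (rather than throwing) when no suitable writer is found
        if (!ImageIO.write(image, "bmp", file)) {
            throw new IOException("No BMP writer accepted the image: " + file);
        }
    }
}
```

In the loop of the exporter, the pixel bytes of one image would be collected into a `byte[]` and passed to `toGrayImage`, and the result written with `writeBmp` instead of `ImageIO.write(image, "png", outputfile)`.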

Answered by RayDeeA on Oct 20 '22 at 02:10


I created some classes for reading the MNIST handwritten digits data set with Java. The classes can read the files after they have been decompressed (unzipped) from the files that are available at the download site. Classes that allow reading the original (compressed) files are part of a small MnistReader project.

The following classes are standalone (they have no dependencies on third-party libraries) and are essentially in the public domain, so they can simply be copied into your own projects. (Attribution would be appreciated, but is not required.)

The MnistDecompressedReader class:

import java.io.DataInputStream;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.nio.file.Path;
import java.util.Objects;
import java.util.function.Consumer;

/**
 * A class for reading the MNIST data set from the <b>decompressed</b> 
 * (unzipped) files that are published at
 * <a href="http://yann.lecun.com/exdb/mnist/">
 * http://yann.lecun.com/exdb/mnist/</a>. 
 */
public class MnistDecompressedReader
{
    /**
     * Default constructor
     */
    public MnistDecompressedReader()
    {
        // Default constructor
    }

    /**
     * Read the MNIST training data from the given directory. The data is 
     * assumed to be located in files with their default names,
     * <b>decompressed</b> from the original files: 
     * <code>train-images.idx3-ubyte</code> and
     * <code>train-labels.idx1-ubyte</code>.
     * 
     * @param inputDirectoryPath The input directory
     * @param consumer The consumer that will receive the resulting 
     * {@link MnistEntry} instances
     * @throws IOException If an IO error occurs
     */
    public void readDecompressedTraining(Path inputDirectoryPath, 
        Consumer<? super MnistEntry> consumer) throws IOException
    {
        String trainImagesFileName = "train-images.idx3-ubyte";
        String trainLabelsFileName = "train-labels.idx1-ubyte";
        Path imagesFilePath = inputDirectoryPath.resolve(trainImagesFileName);
        Path labelsFilePath = inputDirectoryPath.resolve(trainLabelsFileName);
        readDecompressed(imagesFilePath, labelsFilePath, consumer);
    }

    /**
     * Read the MNIST testing data from the given directory. The data is 
     * assumed to be located in files with their default names,
     * <b>decompressed</b> from the original files: 
     * <code>t10k-images.idx3-ubyte</code> and
     * <code>t10k-labels.idx1-ubyte</code>.
     * 
     * @param inputDirectoryPath The input directory
     * @param consumer The consumer that will receive the resulting 
     * {@link MnistEntry} instances
     * @throws IOException If an IO error occurs
     */
    public void readDecompressedTesting(Path inputDirectoryPath, 
        Consumer<? super MnistEntry> consumer) throws IOException
    {
        String testImagesFileName = "t10k-images.idx3-ubyte";
        String testLabelsFileName = "t10k-labels.idx1-ubyte";
        Path imagesFilePath = inputDirectoryPath.resolve(testImagesFileName);
        Path labelsFilePath = inputDirectoryPath.resolve(testLabelsFileName);
        readDecompressed(imagesFilePath, labelsFilePath, consumer);
    }


    /**
     * Read the MNIST data from the specified (decompressed) files.
     * 
     * @param imagesFilePath The path of the images file
     * @param labelsFilePath The path of the labels file
     * @param consumer The consumer that will receive the resulting 
     * {@link MnistEntry} instances
     * @throws IOException If an IO error occurs
     */
    public void readDecompressed(Path imagesFilePath, Path labelsFilePath, 
        Consumer<? super MnistEntry> consumer) throws IOException
    {
        try (InputStream decompressedImagesInputStream = 
            new FileInputStream(imagesFilePath.toFile());
            InputStream decompressedLabelsInputStream = 
                new FileInputStream(labelsFilePath.toFile()))
        {
            readDecompressed(
                decompressedImagesInputStream, 
                decompressedLabelsInputStream, 
                consumer);
        }
    }

    /**
     * Read the MNIST data from the given (decompressed) input streams.
     * The caller is responsible for closing the given streams.
     * 
     * @param decompressedImagesInputStream The decompressed input stream
     * containing the image data 
     * @param decompressedLabelsInputStream The decompressed input stream
     * containing the label data
     * @param consumer The consumer that will receive the resulting 
     * {@link MnistEntry} instances
     * @throws IOException If an IO error occurs
     */
    public void readDecompressed(
        InputStream decompressedImagesInputStream, 
        InputStream decompressedLabelsInputStream, 
        Consumer<? super MnistEntry> consumer) throws IOException
    {
        Objects.requireNonNull(consumer, "The consumer may not be null");

        DataInputStream imagesDataInputStream = 
            new DataInputStream(decompressedImagesInputStream);
        DataInputStream labelsDataInputStream = 
            new DataInputStream(decompressedLabelsInputStream);

        int magicImages = imagesDataInputStream.readInt();
        if (magicImages != 0x803)
        {
            throw new IOException("Expected magic header of 0x803 "
                + "for images, but found " + magicImages);
        }

        int magicLabels = labelsDataInputStream.readInt();
        if (magicLabels != 0x801)
        {
            throw new IOException("Expected magic header of 0x801 "
                + "for labels, but found " + magicLabels);
        }

        int numberOfImages = imagesDataInputStream.readInt();
        int numberOfLabels = labelsDataInputStream.readInt();

        if (numberOfImages != numberOfLabels)
        {
            throw new IOException("Found " + numberOfImages 
                + " images but " + numberOfLabels + " labels");
        }

        int numRows = imagesDataInputStream.readInt();
        int numCols = imagesDataInputStream.readInt();

        for (int n = 0; n < numberOfImages; n++)
        {
            byte label = labelsDataInputStream.readByte();
            byte imageData[] = new byte[numRows * numCols];
            read(imagesDataInputStream, imageData);

            MnistEntry mnistEntry = new MnistEntry(
                n, label, numRows, numCols, imageData);
            consumer.accept(mnistEntry);
        }
    }

    /**
     * Read bytes from the given input stream, filling the given array
     * 
     * @param inputStream The input stream
     * @param data The array to be filled
     * @throws IOException If the input stream does not contain enough bytes
     * to fill the array, or any other IO error occurs
     */
    private static void read(InputStream inputStream, byte data[]) 
        throws IOException
    {
        int offset = 0;
        while (true)
        {
            int read = inputStream.read(
                data, offset, data.length - offset);
            if (read < 0)
            {
                break;
            }
            offset += read;
            if (offset == data.length)
            {
                return;
            }
        }
        throw new IOException("Tried to read " + data.length
            + " bytes, but only found " + offset);
    }
}

The MnistEntry class:

import java.awt.image.BufferedImage;
import java.awt.image.DataBuffer;
import java.awt.image.DataBufferByte;

/**
 * An entry of the MNIST data set. Instances of this class will be passed
 * to the consumer that is given to the {@link MnistCompressedReader} and
 * {@link MnistDecompressedReader} reading methods.
 */
public class MnistEntry
{
    /**
     * The index of the entry
     */
    private final int index;

    /**
     * The class label of the entry
     */
    private final byte label;

    /**
     * The number of rows of the image data
     */
    private final int numRows;

    /**
     * The number of columns of the image data
     */
    private final int numCols;

    /**
     * The image data 
     */
    private final byte[] imageData;        

    /**
     * Creates a new entry
     * 
     * @param index The index
     * @param label The label
     * @param numRows The number of rows
     * @param numCols The number of columns
     * @param imageData The image data
     */
    MnistEntry(int index, byte label, int numRows, int numCols,
        byte[] imageData)
    {
        this.index = index;
        this.label = label;
        this.numRows = numRows;
        this.numCols = numCols;
        this.imageData = imageData;
    }

    /**
     * Returns the index of the entry
     * 
     * @return The index
     */
    public int getIndex()
    {
        return index;
    }

    /**
     * Returns the class label of the entry. This is a value in [0,9], 
     * indicating which digit is shown in the entry
     * 
     * @return The class label
     */
    public byte getLabel()
    {
        return label;
    }

    /**
     * Returns the number of rows of the image data. 
     * This will usually be 28.
     * 
     * @return The number of rows
     */
    public int getNumRows()
    {
        return numRows;
    }

    /**
     * Returns the number of columns of the image data. 
     * This will usually be 28.
     * 
     * @return The number of columns
     */
    public int getNumCols()
    {
        return numCols;
    }

    /**
     * Returns a <i>reference</i> to the image data. This will be an array
     * of length <code>numRows * numCols</code>, containing the brightness
     * of the pixels as unsigned values in [0,255]. Since Java bytes are
     * signed, use {@code value & 0xFF} to obtain a pixel value as an int.
     * 
     * @return The image data
     */
    public byte[] getImageData()
    {
        return imageData;
    }

    /**
     * Creates a new buffered image from the image data that is stored
     * in this entry.
     * 
     * @return The image
     */
    public BufferedImage createImage()
    {
        BufferedImage image = new BufferedImage(getNumCols(),
            getNumRows(), BufferedImage.TYPE_BYTE_GRAY);
        DataBuffer dataBuffer = image.getRaster().getDataBuffer();
        DataBufferByte dataBufferByte = (DataBufferByte) dataBuffer;
        byte data[] = dataBufferByte.getData();
        System.arraycopy(getImageData(), 0, data, 0, data.length);
        return image;
    }


    @Override
    public String toString()
    {
        String indexString = String.format("%05d", index);
        return "MnistEntry[" 
        + "index=" + indexString + "," 
        + "label=" + label + "]";
    }

}
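
One detail to keep in mind when feeding `getImageData()` into a classifier: Java's `byte` is signed, so a stored brightness of 200 shows up as -56 unless it is masked. A minimal sketch of the conversion follows; the class `UnsignedPixelSketch` and its methods are hypothetical helpers, not part of the classes above.

```java
/** Sketch: convert signed Java bytes holding MNIST pixels to usable numeric values. */
final class UnsignedPixelSketch {

    /** Returns the pixel values as ints in [0, 255]. */
    static int[] toUnsigned(byte[] imageData) {
        int[] result = new int[imageData.length];
        for (int i = 0; i < imageData.length; i++) {
            // Masking with 0xFF drops the sign extension; Byte.toUnsignedInt does the same
            result[i] = imageData[i] & 0xFF;
        }
        return result;
    }

    /** Returns the pixel values scaled to floats in [0, 1]. */
    static float[] toNormalized(byte[] imageData) {
        float[] result = new float[imageData.length];
        for (int i = 0; i < imageData.length; i++) {
            result[i] = (imageData[i] & 0xFF) / 255.0f;
        }
        return result;
    }
}
```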

The reader can be used to read the uncompressed files. The result will be MnistEntry instances that are passed to a consumer:

MnistDecompressedReader mnistReader = new MnistDecompressedReader();
mnistReader.readDecompressedTraining(Paths.get("./data"), mnistEntry -> 
{
    System.out.println("Read entry " + mnistEntry);
    BufferedImage image = mnistEntry.createImage();
    ...
});

The MnistReader project contains several examples of how these classes may be used to read the compressed or uncompressed data, or to generate PNG images from the MNIST entries.
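
For reading the downloaded .gz files without unzipping them first, one possible approach is to wrap the file stream in the JDK's `GZIPInputStream`. This is only a sketch of that idea and not necessarily how the MnistReader project does it; the class `GzipIdxSketch` and the method `openDecompressed` are names chosen for this example.

```java
import java.io.BufferedInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.zip.GZIPInputStream;

/** Sketch: open a gzip-compressed IDX file as a stream of decompressed bytes. */
final class GzipIdxSketch {

    /**
     * Opens the given .gz file and returns a stream that yields its
     * decompressed contents. The caller is responsible for closing it.
     */
    static InputStream openDecompressed(Path gzFile) throws IOException {
        InputStream raw = Files.newInputStream(gzFile);
        try {
            // Buffer the file reads; GZIPInputStream checks the gzip header on construction
            return new GZIPInputStream(new BufferedInputStream(raw));
        } catch (IOException e) {
            raw.close(); // do not leak the file handle if the gzip header is invalid
            throw e;
        }
    }
}
```

Two such streams (one for the images file, one for the labels file) can then be passed to the `readDecompressed(InputStream, InputStream, Consumer)` method shown above, ideally inside a try-with-resources block so that both are closed afterwards.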

Answered by Marco13 on Oct 20 '22 at 00:10