 

MNIST data download from sklearn datasets gives Timeout error

I am new to ML and trying to download MNIST data. The code I am using is:

from sklearn.datasets import fetch_mldata
mnist = fetch_mldata('MNIST original')

But it fails with this error:

TimeoutError: [WinError 10060] A connection attempt failed because the connected party did not properly respond after a period of time, or established connection failed because connected host has failed to respond

Can anyone help me figure out what needs to be done to fix this?

Swaraj Shekhar asked Nov 01 '18 07:11


2 Answers

Here is the GitHub issue, along with some workarounds that people suggested:

https://github.com/scikit-learn/scikit-learn/issues/8588

The easiest one is to download the MNIST .mat file from this link:

download MNIST.mat

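After downloading, put the file inside the ~/scikit_learn_data/mldata folder; if this folder doesn't exist, create it first. Note that fetch_mldata('MNIST original') looks for a file named mnist-original.mat, so rename the file if needed. When the file is available locally, scikit-learn won't try to download it and will use that file instead.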
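The manual steps above can be scripted. This is a minimal sketch, assuming the default data home (~/scikit_learn_data) and that you have already downloaded the .mat file yourself. The helper names mldata_filename_sketch and install_mnist_mat are hypothetical, and the name normalization is a reimplementation of what fetch_mldata does to the dataset name, not scikit-learn's own code.

```python
import re
import shutil
from pathlib import Path


def mldata_filename_sketch(dataname):
    """Hypothetical re-creation of how fetch_mldata maps a dataset name to a
    file name: lowercase, spaces become hyphens, '(', ')' and '.' are dropped."""
    name = dataname.lower().replace(" ", "-")
    return re.sub(r"[().]", "", name)


def install_mnist_mat(downloaded_file, data_home="~/scikit_learn_data",
                      dataname="MNIST original"):
    """Copy a manually downloaded .mat file to <data_home>/mldata/<name>.mat."""
    target_dir = Path(data_home).expanduser() / "mldata"
    target_dir.mkdir(parents=True, exist_ok=True)  # create the folder if missing
    target = target_dir / (mldata_filename_sketch(dataname) + ".mat")
    shutil.copyfile(downloaded_file, target)
    return target
```

For example, install_mnist_mat("Downloads/mnist-original.mat") copies the file to ~/scikit_learn_data/mldata/mnist-original.mat, after which fetch_mldata('MNIST original') should pick it up locally (only on scikit-learn versions that still include fetch_mldata).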

leo answered Nov 01 '22 05:11


Since fetch_mldata has been deprecated, we have to move to fetch_openml. Make sure to update scikit-learn to version 0.20.0 or later, because fetch_openml was added in that release.

  1. OpenML currently has 5 different datasets related to MNIST. Here is one example from the scikit-learn documentation, using the mnist_784 dataset:
from sklearn.datasets import fetch_openml
# Load data from https://www.openml.org/d/554
X, y = fetch_openml('mnist_784', version=1, return_X_y=True)
  2. Or, if you don't need a very large dataset, you can use load_digits, which returns 1,797 8x8 digit images rather than the 28x28 MNIST images:
from sklearn.datasets import load_digits
digits = load_digits()
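Putting option 1 to use, here is a minimal sketch, assuming scikit-learn 0.22 or later (for the as_frame parameter) and network access to openml.org. Two details are worth knowing: fetch_openml returns the labels as strings such as '5', and mnist_784 keeps the conventional order where the first 60,000 samples are the training set and the last 10,000 the test set. The helper names load_mnist_784 and split_mnist are hypothetical.

```python
import numpy as np


def split_mnist(X, y, n_train=60000):
    """Split into the conventional MNIST train/test portions and cast labels."""
    X = np.asarray(X)
    y = np.asarray(y).astype(np.uint8)  # labels arrive as strings like '5'
    return X[:n_train], X[n_train:], y[:n_train], y[n_train:]


def load_mnist_784():
    """Download mnist_784 from OpenML as NumPy arrays and split it."""
    # Imported here so split_mnist can be used without scikit-learn installed.
    from sklearn.datasets import fetch_openml

    X, y = fetch_openml("mnist_784", version=1, return_X_y=True, as_frame=False)
    return split_mnist(X, y)
```

Calling X_train, X_test, y_train, y_test = load_mnist_784() should give X_train with shape (60000, 784) and integer labels in y_train.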

Note that if you are following the book Hands-On Machine Learning with Scikit-Learn and TensorFlow using the mnist_784 dataset, you may notice that the code

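import numpy as np
import matplotlib.pyplot as plt

# np.asarray works whether fetch_openml returned a NumPy array or a DataFrame
some_digit = np.asarray(X)[36000]
some_digit_image = some_digit.reshape(28, 28)
plt.imshow(some_digit_image, cmap="binary", interpolation="nearest")
plt.axis('off')
plt.show()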

returns a picture of a 9 instead of a 5. My guess is that either mnist_784 and MNIST original are two different subsets of the NIST data, or the samples are ordered differently in the two datasets.
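One way to check the second guess is to reorder the data by label and look at index 36000 again. This is a minimal sketch, assuming X and y come from fetch_openml('mnist_784', ...) and assuming the older MNIST original copy was sorted by label within its training and test portions; that ordering is an assumption, and sort_by_target is a hypothetical helper name.

```python
import numpy as np


def sort_by_target(X, y, n_train=60000):
    """Hypothetical helper: stable-sort the train and test portions by label,
    separately, returning new arrays (the inputs are left unchanged)."""
    X = np.asarray(X).copy()
    y = np.asarray(y).astype(np.int64)  # OpenML labels are strings like '5'
    for start, stop in ((0, n_train), (n_train, len(y))):
        # argsort gives positions within the slice; shift them back to absolute indices
        order = start + np.argsort(y[start:stop], kind="stable")
        X[start:stop] = X[order]
        y[start:stop] = y[order]
    return X, y
```

After X_sorted, y_sorted = sort_by_target(X, y), you can rerun the plotting code above on X_sorted and compare the result with the book.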

PS: I also ran into an SSL error when trying to load the data; in my case, updating OpenSSL resolved the problem.
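If you hit a similar SSL error, a quick first step is to check which OpenSSL build your Python interpreter is linked against, since HTTPS downloads made through Python's standard library go through it. A minimal sketch using the standard ssl module; whether an upgrade fixes the error depends on your setup.

```python
import ssl


def openssl_version():
    """Return the OpenSSL (or LibreSSL) version string Python was built with."""
    return ssl.OPENSSL_VERSION


print(openssl_version())
```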

Madmint answered Nov 01 '22 05:11