 

Scikit-learn SVM digit recognition

I want to make a program to recognize the digit in an image. I followed the tutorial on scikit-learn.

I can train and fit the SVM classifier as follows.

First, I import the libraries and the dataset.

from sklearn import datasets, svm, metrics

digits = datasets.load_digits()
n_samples = len(digits.images)
data = digits.images.reshape((n_samples, -1))

Second, I create the SVM model and train it with the dataset.

classifier = svm.SVC(gamma=0.001)
classifier.fit(data[:n_samples], digits.target[:n_samples])

Then I try to read my own image and use predict() to recognize the digit.

Here is my image: [image of my handwritten digit]

I read the image, resize it to (8, 8), keep a single colour channel, and later flatten it to a 1D array.

from scipy import misc

img = misc.imread("w1.jpg")        # read the image
img = misc.imresize(img, (8, 8))   # resize it to 8x8
img = img[:, :, 0]                 # keep only the first colour channel

Finally, when I print out the prediction, it returns [1]

predicted = classifier.predict(img.reshape((1, img.shape[0] * img.shape[1])))
print(predicted)

Whichever other images I use, it still returns [1].


When I print out the dataset's own image of the digit "9", it looks like: [array printout]

My image of the digit "9": [array printout]

You can see that the non-zero values are much larger in my image than in the dataset.
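For reference, the value ranges can be compared with a quick check like this (a small sketch reusing the variables from above):

print(digits.images[0].max())   # dataset pixels go from 0 to 16
print(img.max())                # my JPEG is read as uint8, so values can go up to 255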

I don't know why. I am looking for help to solve my problem. Thanks.

asked Jul 22 '16 by VICTOR




2 Answers

My best bet would be that there is a problem with your data types and array shapes.

It looks like you are training on numpy arrays of type np.float64 (or possibly np.float32 on 32-bit systems, I don't remember), where each image has the shape (64,).

Meanwhile your input image for prediction, after the resize and the reshape in your code, is of type uint8 and has shape (1, 64).
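A quick way to verify this is to print the dtypes and shapes on both sides (a small sketch reusing the variable names from your code):

print(data.dtype, data.shape)   # training data: float64, (n_samples, 64)
print(img.dtype, img.shape)     # your image after imresize and channel selection: uint8, (8, 8)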

I would first try changing the shape of your input image since dtype conversions often just work as you would expect. So change this line:

predicted = classifier.predict(img.reshape((1,img.shape[0]*img.shape[1] )))

to this:

predicted = classifier.predict(img.reshape(img.shape[0]*img.shape[1]))

If that doesn't fix it, you can always try recasting the data type as well with

img = img.astype(digits.images.dtype).

I hope that helps. Debugging by proxy is a lot harder than actually sitting in front of your computer :)

Edit: According to the scikit-learn documentation, the training data contains integer values from 0 to 16. The values in your input image should be scaled to fit the same interval. (http://scikit-learn.org/stable/modules/generated/sklearn.datasets.load_digits.html#sklearn.datasets.load_digits)
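For example, assuming your resized image is a uint8 array with values in the 0-255 range, something along these lines should bring it onto the same scale:

img = img / 255.0 * 16                  # rescale 0-255 onto the 0-16 scale used by load_digits()
img = img.astype(digits.images.dtype)   # match the training data dtype (float64)
# if your digit is dark ink on a light background, you may also need to invert it
# first (img = 255 - img), since the dataset stores high values for the strokes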

answered Oct 06 '22 by Dr K


1) You need to create your own training set, based on data similar to what you will be making predictions on. The call to datasets.load_digits() in scikit-learn loads a preprocessed version of the UCI hand-written digits dataset, which, for all we know, could contain very different images from the ones you are trying to recognise.

2) You need to set the parameters of your classifier properly. The call to svm.SVC(gamma=0.001) just picks an arbitrary value for the gamma parameter of SVC, which may not be the best option. In addition, you are not setting the C parameter, which is quite important for SVMs. I'd bet that this is one of the reasons why your output is always [1].

3) Whatever final settings you choose for your model, you'll need to use a cross-validation scheme to make sure that the algorithm is effectively learning.

There's a lot of machine learning theory behind this, but, as a good start, I would really recommend having a look at SVM - scikit-learn for a more in-depth description of how the SVC implementation in scikit-learn works, and at GridSearchCV for a simple technique for parameter tuning.
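As a rough sketch of what parameter tuning with cross-validation could look like on the digits data (using sklearn.model_selection from recent scikit-learn versions; the grid values below are only illustrative starting points):

from sklearn import datasets, svm
from sklearn.model_selection import GridSearchCV

digits = datasets.load_digits()
X = digits.images.reshape((len(digits.images), -1))
y = digits.target

# search over C and gamma with 5-fold cross-validation
param_grid = {"C": [0.1, 1, 10, 100], "gamma": [0.0001, 0.001, 0.01]}
search = GridSearchCV(svm.SVC(), param_grid, cv=5)
search.fit(X, y)

print(search.best_params_)   # best parameter combination found
print(search.best_score_)    # mean cross-validated accuracy for that combination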

answered Oct 06 '22 by carrdelling