How to compute the probability of a multi-class prediction using libsvm?

I'm using libsvm, and the documentation leads me to believe that there's a way to output the estimated probability that a predicted classification is correct. Is this so? And if so, can anyone provide a clear example of how to do it in code?

Currently, I'm using the Java libraries in the following manner:

    svm_model model = svm.svm_train(problem, parameters);
    svm_node[] x = getAnArrayOfSvmNodesForProblem(); // my own helper
    double predictedValue = svm.svm_predict(model, x);
Cuga asked Jan 22 '23 03:01


2 Answers

Given your code-snippet, I'm going to assume you want to use the Java API packaged with libSVM, rather than the more verbose one provided by jlibsvm.

To enable prediction with probability estimates, train the model with the svm_parameter field probability set to 1. Then change your code to call the static method svm.svm_predict_probability instead of svm.svm_predict; it fills an array you pass in with one probability per class.
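
Roughly, this is what probability = 1 adds: during training, libsvm fits, for every pair of classes, a sigmoid (Platt scaling) that turns that pair's decision value into a probability, using decision values obtained through an internal cross-validation, and stores the fitted parameters in the model's probA and probB arrays. The sketch below shows only that sigmoid step, modeled on the overflow-safe form libsvm uses internally; the class name PlattSigmoid and the A and B values are made up for illustration, not taken from a trained model.

```java
// Sketch of the per-pair Platt sigmoid. The A and B used in main are
// hypothetical; in a trained libsvm model they come from probA/probB.
final class PlattSigmoid {

    /** Maps one pairwise decision value f to 1 / (1 + exp(A*f + B)). */
    static double sigmoid(double decisionValue, double a, double b) {
        double fApB = decisionValue * a + b;
        // Two algebraically equal forms, chosen so Math.exp never overflows.
        if (fApB >= 0) {
            return Math.exp(-fApB) / (1.0 + Math.exp(-fApB));
        }
        return 1.0 / (1.0 + Math.exp(fApB));
    }

    public static void main(String[] args) {
        double a = -2.0, b = 0.0; // hypothetical fitted parameters
        for (double f : new double[] {-1.0, 0.0, 1.0}) {
            System.out.printf("decision value %.1f -> probability %.4f%n", f, sigmoid(f, a, b));
        }
    }
}
```

With a negative A, a larger decision value gives a probability closer to 1, and a decision value of 0 with B = 0 gives exactly 0.5.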

Modifying your snippet, we have:

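    parameters.probability = 1; // also train the probability model
    svm_model model = svm.svm_train(problem, parameters);

    svm_node[] x = problem.x[0]; // let's try the first data point in problem
    // one slot per class, filled in the same order as model.label
    double[] prob_estimates = new double[svm.svm_get_nr_class(model)];
    double predictedLabel = svm.svm_predict_probability(model, x, prob_estimates);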
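
Note that prob_estimates[i] is the probability of the class model.label[i] (also available through svm.svm_get_labels), not of the class whose label value is i. A minimal pure-Java sketch of pairing the two arrays, with hypothetical class and method names and made-up values standing in for what libsvm would return:

```java
import java.util.LinkedHashMap;
import java.util.Map;

// Sketch: pairs libsvm-style probability estimates with their class labels.
// labels[i] plays the role of model.label[i]; probEstimates[i] its probability.
final class LabelProbabilities {

    /** Maps each class label to its estimated probability, keeping libsvm's order. */
    static Map<Integer, Double> byLabel(int[] labels, double[] probEstimates) {
        if (labels.length != probEstimates.length) {
            throw new IllegalArgumentException("expected one probability per class label");
        }
        Map<Integer, Double> result = new LinkedHashMap<>();
        for (int i = 0; i < labels.length; i++) {
            result.put(labels[i], probEstimates[i]);
        }
        return result;
    }

    /** Returns the label with the highest probability (the first one wins on ties). */
    static int mostProbableLabel(int[] labels, double[] probEstimates) {
        int best = 0;
        for (int i = 1; i < probEstimates.length; i++) {
            if (probEstimates[i] > probEstimates[best]) {
                best = i;
            }
        }
        return labels[best];
    }

    public static void main(String[] args) {
        // Hypothetical label order; read the real one from model.label.
        int[] labels = {3, 1, 2};
        double[] probs = {0.2, 0.5, 0.3};
        System.out.println(byLabel(labels, probs));           // {3=0.2, 1=0.5, 2=0.3}
        System.out.println(mostProbableLabel(labels, probs)); // 1
    }
}
```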

It's worth knowing that training with multiclass probability estimates can change the predictions made by the classifier. For more on this, see the question Calculating Nearest Match to Mean/Stddev Pair With LibSVM.
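
One reason for the difference: svm_predict picks a class by one-against-one voting on the raw decision values, while svm_predict_probability returns the class with the largest combined probability. To get that combined probability, libsvm first turns each pair's decision value into a pairwise probability r[i][j] (kept strictly between 0 and 1, with r[j][i] = 1 - r[i][j]) and then couples all pairs into a single distribution, following the pairwise-coupling method of Wu, Lin and Weng. The sketch below mirrors that coupling iteration in plain Java; the class name and the input probabilities are made up, and it is an illustration rather than libsvm's own code.

```java
import java.util.Arrays;

// Sketch of libsvm-style pairwise coupling. r[i][j] estimates
// P(class i | x is in class i or j); off-diagonal entries must be strictly
// between 0 and 1 with r[j][i] = 1 - r[i][j]. The diagonal is ignored.
final class PairwiseCoupling {

    static double[] couple(double[][] r) {
        int k = r.length;
        double[][] q = new double[k][k];
        double[] qp = new double[k];
        double[] p = new double[k];
        int maxIter = Math.max(100, k);
        double eps = 0.005 / k;

        // Q[t][t] = sum over j != t of r[j][t]^2, Q[t][j] = -r[j][t] * r[t][j]; start from uniform p.
        for (int t = 0; t < k; t++) {
            p[t] = 1.0 / k;
            for (int j = 0; j < k; j++) {
                if (j != t) {
                    q[t][t] += r[j][t] * r[j][t];
                    q[t][j] = -r[j][t] * r[t][j];
                }
            }
        }

        for (int iter = 0; iter < maxIter; iter++) {
            // Recompute Qp and p^T Q p from scratch each sweep for numerical accuracy.
            double pQp = 0;
            for (int t = 0; t < k; t++) {
                qp[t] = 0;
                for (int j = 0; j < k; j++) {
                    qp[t] += q[t][j] * p[j];
                }
                pQp += p[t] * qp[t];
            }
            // At the optimum every (Qp)[t] equals p^T Q p; stop when all are within eps.
            double maxError = 0;
            for (int t = 0; t < k; t++) {
                maxError = Math.max(maxError, Math.abs(qp[t] - pQp));
            }
            if (maxError < eps) {
                break;
            }
            // Update one coordinate at a time, then rescale so p still sums to 1.
            for (int t = 0; t < k; t++) {
                double diff = (-qp[t] + pQp) / q[t][t];
                p[t] += diff;
                pQp = (pQp + diff * (diff * q[t][t] + 2 * qp[t])) / (1 + diff) / (1 + diff);
                for (int j = 0; j < k; j++) {
                    qp[j] = (qp[j] + diff * q[t][j]) / (1 + diff);
                    p[j] /= (1 + diff);
                }
            }
        }
        return p;
    }

    public static void main(String[] args) {
        // Made-up pairwise probabilities for three classes (indices, not libsvm labels).
        double[][] r = {
            {0.0, 0.7, 0.6},
            {0.3, 0.0, 0.8},
            {0.4, 0.2, 0.0},
        };
        System.out.println(Arrays.toString(couple(r)));
    }
}
```

For two classes the coupling simply returns the pairwise probability itself (about 0.8 and 0.2 for r[0][1] = 0.8), and when every r[i][j] is 0.5 it stays at the uniform distribution.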

dmcer answered Jan 24 '23 16:01


The accepted answer worked like a charm. Make sure to set probability = 1 during training.

If you want to discard a prediction whose confidence falls below a threshold, here is a code sample:

    // model, svmVector, text, classLoadMap and LOG come from the surrounding code
    double[] confidenceScores = new double[model.nr_class];
    svm.svm_predict_probability(model, svmVector, confidenceScores);

    /* Debug output, one line per class:
    System.out.println("text=" + text);
    for (int i = 0; i < model.nr_class; i++) {
        System.out.println("i=" + i + ", labelNum:" + model.label[i] + ", name=" + classLoadMap.get(model.label[i]) + ", score=" + confidenceScores[i]);
    }
    */

    // find the class with the highest confidence
    int maxConfidenceIndex = 0;
    double maxConfidence = confidenceScores[maxConfidenceIndex];
    for (int i = 1; i < confidenceScores.length; i++) {
        if (confidenceScores[i] > maxConfidence) {
            maxConfidenceIndex = i;
            maxConfidence = confidenceScores[i];
        }
    }

    double threshold = 0.3; // tune for your data and number of classes
    int labelNum = model.label[maxConfidenceIndex]; // scores are in model.label order
    // reverse-map the label number to its name
    String targetClassLabel = classLoadMap.get(labelNum);
    LOG.info("classNumber:{}, className:{}; confidence:{}; for text:{}",
            labelNum, targetClassLabel, maxConfidence, text);
    if (maxConfidence < threshold) {
        LOG.info("Not enough confidence; threshold={}", threshold);
        targetClassLabel = null; // reject the prediction
    }
    return targetClassLabel;
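
The same thresholding logic, stripped of the libsvm call, the label-name map and the logger so it can run on its own, might look like the sketch below. It returns an Optional instead of null; ConfidenceFilter and predictOrReject are hypothetical names, and the arrays stand in for model.label and the filled-in confidenceScores.

```java
import java.util.Optional;

// Sketch of "reject low-confidence predictions" without libsvm dependencies.
// labels[i] must correspond to probs[i], as model.label does for libsvm's scores.
final class ConfidenceFilter {

    /** Returns the most probable label, or Optional.empty() if its probability is below threshold. */
    static Optional<Integer> predictOrReject(int[] labels, double[] probs, double threshold) {
        int best = 0;
        for (int i = 1; i < probs.length; i++) {
            if (probs[i] > probs[best]) {
                best = i;
            }
        }
        return probs[best] < threshold ? Optional.empty() : Optional.of(labels[best]);
    }

    public static void main(String[] args) {
        int[] labels = {1, 2, 3};
        System.out.println(predictOrReject(labels, new double[] {0.1, 0.7, 0.2}, 0.3));   // Optional[2]
        System.out.println(predictOrReject(labels, new double[] {0.28, 0.27, 0.45}, 0.5)); // Optional.empty
    }
}
```

Because k class probabilities sum to 1, the top one is always at least 1/k. With two classes a threshold of 0.3 therefore never rejects anything, while with ten classes the same 0.3 is well above the uniform 0.1, which is why the threshold has to be chosen with the number of classes in mind.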
Thamme Gowda answered Jan 24 '23 16:01