Can GridSearchCV use predict_proba when using a custom score function?

Tags:

scikit-learn

I am trying to use a custom scoring function that calculates multi-class log loss from the ground-truth labels and the probability array returned by predict_proba. Is there a way to make GridSearchCV use this scoring function?

import numpy as np

def multiclass_log_loss(y_true, y_pred):
    """Multi-class version of the logarithmic loss metric.

    Parameters
    ----------
    y_true : array, shape = [n_samples]
        true class, integers in [0, n_classes - 1]
    y_pred : array, shape = [n_samples, n_classes]
        predicted class probabilities, e.g. from predict_proba

    Returns
    -------
    loss : float
    """
    eps = 1e-15
    # Clip probabilities away from 0 and 1 so that log() stays finite
    predictions = np.clip(np.asarray(y_pred, dtype=float), eps, 1 - eps)

    # normalize row sums to 1
    predictions /= predictions.sum(axis=1)[:, np.newaxis]

    # One-hot encode the true labels
    actual = np.zeros(predictions.shape)
    n_samples = actual.shape[0]
    actual[np.arange(n_samples), np.asarray(y_true).astype(int)] = 1
    vectsum = np.sum(actual * np.log(predictions))
    loss = -1.0 / n_samples * vectsum
    return loss
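As a sanity check on the function above, the same quantity can be computed without building the one-hot matrix: it is simply the mean negative log of the probability assigned to each sample's true class. The sketch below uses a hypothetical helper name, multiclass_log_loss_compact, and assumes integer labels in [0, n_classes - 1]; on well-behaved inputs it should agree with sklearn.metrics.log_loss, although the clipping details may differ between scikit-learn versions.

```python
import numpy as np

def multiclass_log_loss_compact(y_true, y_pred, eps=1e-15):
    """Hypothetical helper: mean negative log-probability of the true class."""
    y_true = np.asarray(y_true, dtype=int)
    # np.clip returns a new array, so the caller's y_pred is never modified
    p = np.clip(np.asarray(y_pred, dtype=float), eps, 1 - eps)
    # Renormalise each row so the probabilities sum to 1 again after clipping
    p /= p.sum(axis=1, keepdims=True)
    # Pick out the probability of the true class for every sample
    picked = p[np.arange(len(y_true)), y_true]
    return -np.mean(np.log(picked))

y_true = np.array([0, 1, 2])
y_pred = np.array([[0.7, 0.2, 0.1],
                   [0.1, 0.8, 0.1],
                   [0.2, 0.2, 0.6]])
loss = multiclass_log_loss_compact(y_true, y_pred)
print(loss)
```

Indexing with the arange/label pair avoids allocating an n_samples x n_classes array of zeros, which matters little here but scales better for many classes.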

I see that there are several options: score_func, loss_func and make_scorer. I tried using make_scorer with greater_is_better=False and also tried the loss_func parameter, but it still seems to use the .predict method. How can I get around this problem?

UPDATE: if I set needs_threshold=True, I get a multi-class error. Am I correct in understanding that multi-class is not supported in this case? If so, can someone suggest a workaround?
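For plain multi-class log loss, newer scikit-learn releases (0.18 and later, to the best of my knowledge; the question predates this) ship a built-in scorer string, "neg_log_loss", that calls predict_proba and negates the loss so that GridSearchCV's "higher is better" rule still holds. A minimal sketch, assuming the iris data set and a LogisticRegression grid over C purely for illustration:

```python
from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import GridSearchCV

X, y = load_iris(return_X_y=True)  # a 3-class problem

# "neg_log_loss" scores each fold with -log_loss(y, predict_proba(X)),
# so the best parameter setting is the one with the smallest log loss.
search = GridSearchCV(
    LogisticRegression(max_iter=1000),
    param_grid={"C": [0.1, 1.0, 10.0]},
    scoring="neg_log_loss",
    cv=5,
)
search.fit(X, y)
print(search.best_params_, search.best_score_)
```

best_score_ comes out negative; flip its sign to read it as a log loss.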

Thanks.

asked Nov 10 '22 20:11 by shankarmsy


1 Answer

The top answer to this question: Pass estimator to custom score function via sklearn.metrics.make_scorer

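might have what you need. One can define a scorer that takes a fitted classifier clf, a feature array X, and the targets y_true as arguments, passes the output of clf.predict_proba(X) to a scoring function, and returns the resulting score. GridSearchCV accepts such a callable directly as its scoring parameter; because it always maximizes the score, a loss such as log loss should be returned negated. As a hint, for binary classification, you probably need to use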

clf.predict_proba(X)[:,1]
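For a binary problem, predict_proba returns one column per entry of clf.classes_, so column 1 is the probability of the second ("positive") class. A minimal sketch of a scorer with the (estimator, X, y) signature, assuming synthetic data from make_classification and ROC AUC as the probability-based metric; the name positive_class_auc is hypothetical:

```python
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import roc_auc_score

def positive_class_auc(clf, X, y_true):
    """Hypothetical scorer with the (estimator, X, y) signature."""
    proba = clf.predict_proba(X)  # shape (n_samples, 2), columns follow clf.classes_
    # Column 1 holds the probability of clf.classes_[1], the positive class
    return roc_auc_score(y_true, proba[:, 1])

X, y = make_classification(n_samples=300, random_state=0)  # binary labels 0/1
clf = LogisticRegression().fit(X, y)
auc = positive_class_auc(clf, X, y)
print(auc)
```

Because AUC is already "higher is better", this scorer needs no sign flip.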

This worked for my needs (a normalized Gini score). For some reason, I couldn't get sklearn's metrics.make_scorer to work with my custom function that needs probabilities.
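Applied to the multi-class case from the question, the callable-scorer approach might look like the sketch below. It assumes sklearn.metrics.log_loss as the loss (the question's multiclass_log_loss could be substituted, given integer labels in [0, n_classes - 1]), the iris data set, and a hypothetical scorer name, neg_multiclass_log_loss. Passing labels=estimator.classes_ guards against a validation fold that happens to miss a class.

```python
from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import log_loss
from sklearn.model_selection import GridSearchCV

def neg_multiclass_log_loss(estimator, X, y_true):
    """Hypothetical scorer; GridSearchCV calls it as scorer(estimator, X, y)."""
    proba = estimator.predict_proba(X)  # shape (n_samples, n_classes)
    # Negate the loss because GridSearchCV keeps the highest-scoring setting
    return -log_loss(y_true, proba, labels=estimator.classes_)

X, y = load_iris(return_X_y=True)
search = GridSearchCV(
    LogisticRegression(max_iter=1000),
    param_grid={"C": [0.1, 1.0, 10.0]},
    scoring=neg_multiclass_log_loss,  # a plain callable, no make_scorer needed
    cv=5,
)
search.fit(X, y)
print(search.best_params_, -search.best_score_)
```

Passing the function itself sidesteps make_scorer and its version-dependent flags (needs_proba, needs_threshold) entirely, which is one way around the multi-class error mentioned in the question's update.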

answered Jan 04 '23 03:01 by Don Ernesto