 

sklearn: LogisticRegression - predict_proba(X) - calculation

I was wondering if someone could have a quick look at the following code snippet and point me to my misunderstanding in calculating the probability of a sample for each class in the model, and to the related bug in my code. I tried to reproduce by hand the results of the sklearn function lm.predict_proba(X). Sadly the results are different, so I made a mistake somewhere.

I think the bug is in part "d" of the following walkthrough, maybe in the math, but I could not see why.

a) Creating and training a logistic regression model (works fine)

from sklearn.linear_model import LogisticRegression

lm = LogisticRegression(random_state=413, multi_class='multinomial', solver='newton-cg')
lm.fit(X, train_labels)

b) Saving the coefficients and bias (works fine)

W = lm.coef_
b = lm.intercept_
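A quick way to confirm what `W` and `b` contain is to rebuild the per-class scores by hand and compare them with `decision_function`, which for a linear model returns `X @ coef_.T + intercept_`. The sketch below assumes synthetic 3-class data from `make_classification` in place of the question's images; `X_demo`, `y_demo`, `model` and `z_manual` are illustrative names.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression

# Synthetic stand-in for the question's data (assumption: 3 classes, 20 features).
X_demo, y_demo = make_classification(n_samples=200, n_features=20, n_informative=5,
                                     n_classes=3, random_state=0)
model = LogisticRegression(max_iter=1000).fit(X_demo, y_demo)

W_demo = model.coef_       # shape (n_classes, n_features), one row per class
b_demo = model.intercept_  # shape (n_classes,)

# The raw per-class scores are one matrix product plus the bias vector.
z_manual = X_demo @ W_demo.T + b_demo
z_sklearn = model.decision_function(X_demo)
print(np.allclose(z_manual, z_sklearn))  # True
```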

c) Using lm.predict_proba(X) (works fine)

import numpy as np
import pandas as pd

def reshape_single_element(x, num):
    # Flatten the 2-D image x[num] into a single row of shape (1, nx*ny)
    singleElement = x[num]
    nx, ny = singleElement.shape
    return singleElement.reshape((1, nx * ny))

select_image_number = 6
X_select_image_data = reshape_single_element(train_dataset, select_image_number)
Y_probabilities = lm.predict_proba(X_select_image_data)
Y_pandas_probabilities = pd.Series(Y_probabilities[0], index=['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j'])
print("estimate probabilities for each class: \n", Y_pandas_probabilities, "\n")
print("all probabilities by lm.predict_proba(..) sum up to ", np.sum(Y_probabilities), "\n")

The output was:

estimate probabilities for each class: 
a 0.595426
b 0.019244
c 0.001343
d 0.004033
e 0.017185
f 0.004193
g 0.160380
h 0.158245
i 0.003093
j 0.036860
dtype: float64 
all probabilities by lm.predict_proba(..) sum up to 1.0

d) Manually performing the calculation done by lm.predict_proba (no error/warning, but the results are not the same)

import math
import numpy as np

manual_calculated_probabilities = []
for select_class_k in range(0, 10):  # a=0, b=1, c=2, ...
    z_for_class_k = np.sum(W[select_class_k] * X_select_image_data) + b[select_class_k]
    p_for_class_k = 1 / (1 + math.exp(-z_for_class_k))
    manual_calculated_probabilities.append(p_for_class_k)

print("formula: ", manual_calculated_probabilities, "\n")

def softmax(x):
    """Compute softmax values for each set of scores in x."""
    e = np.exp(x)
    dist = e / np.sum(e, axis=0)
    return dist

abc = softmax(manual_calculated_probabilities)
print("softmax:", abc)

The output was:

formula: [0.9667598370531315, 0.48453459121301334, 0.06154496922245115, 0.16456194859398865, 0.45634781280053394, 0.16999340794727547, 0.8867996361191054, 0.8854473986336552, 0.13124464656251109, 0.642913996162282]

softmax: [ 0.15329642 0.09464644 0.0620015 0.0687293 0.0920159 0.06910361 0.14151607 0.14132483 0.06647715 0.11088877]
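Both printouts can be examined using the numbers copied from above. The sigmoid values are computed independently per class, so nothing makes them sum to 1 (they sum to about 4.85). And because every one of them lies in (0, 1), a softmax over them cannot separate the classes much: the ratio between any two outputs is exp(p_i - p_j), which stays below e ≈ 2.72, hence the flat distribution. A minimal check (the array is the question's own output; nothing else is assumed):

```python
import numpy as np

# Sigmoid outputs printed after "formula:" in part d.
sigmoid_outputs = np.array([0.9667598370531315, 0.48453459121301334, 0.06154496922245115,
                            0.16456194859398865, 0.45634781280053394, 0.16999340794727547,
                            0.8867996361191054, 0.8854473986336552, 0.13124464656251109,
                            0.642913996162282])

# Independent sigmoids do not form a distribution over the ten classes.
total = sigmoid_outputs.sum()
print(round(float(total), 2))  # 4.85

# Softmax of those values, as done in part d.
e = np.exp(sigmoid_outputs)
flat = e / e.sum()

# With every input in (0, 1), max/min of the softmax equals exp(max - min) < e.
ratio = flat.max() / flat.min()
print(round(float(ratio), 2))  # 2.47
```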

I used softmax because of this comment in logistic.py on GitHub:

For a multi_class problem, if multi_class is set to be "multinomial" the softmax function is used to find the predicted probability of each class.
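That comment matches how scikit-learn computes `predict_proba` for a multinomial model: a softmax over the raw scores from `decision_function`, not over sigmoid outputs. A sketch on synthetic data (an assumption: `make_classification` stands in for the question's images, and in recent scikit-learn versions the default `lbfgs` solver already fits a multinomial model when there are three or more classes; `softmax_rows` is an illustrative helper):

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression

def softmax_rows(z):
    """Row-wise softmax; subtracting each row's max keeps exp() from overflowing."""
    e = np.exp(z - z.max(axis=1, keepdims=True))
    return e / e.sum(axis=1, keepdims=True)

# Synthetic 3-class data (assumption; stands in for the question's images).
X_demo, y_demo = make_classification(n_samples=200, n_features=20, n_informative=5,
                                     n_classes=3, random_state=0)
model = LogisticRegression(max_iter=1000).fit(X_demo, y_demo)

# Raw scores for five samples, shape (5, 3), then softmax across the classes.
z = X_demo[:5] @ model.coef_.T + model.intercept_
manual = softmax_rows(z)
print(np.allclose(manual, model.predict_proba(X_demo[:5])))  # True
```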

Note:

print("shape of X: ", X_select_image_data.shape)
print("shape of W: ", W.shape)
print("shape of b: ", b.shape)

shape of X:  (1, 784)
shape of W:  (10, 784)
shape of b:  (10,)
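Given these shapes, the loop in part d can be replaced by a single matrix product that yields all ten scores at once: (1, 784) @ (784, 10) gives (1, 10). The sketch uses random arrays with the printed shapes (the values are made up; `X_row`, `W_demo` and `b_demo` are hypothetical names) and checks the result against the per-class loop:

```python
import numpy as np

rng = np.random.default_rng(0)
# Random arrays with the shapes printed above (values are made up).
X_row = rng.normal(size=(1, 784))
W_demo = rng.normal(scale=0.01, size=(10, 784))
b_demo = rng.normal(size=10)

# All ten class scores in one step.
z = X_row @ W_demo.T + b_demo

# The same scores computed one class at a time, as in part d.
z_loop = np.array([np.sum(W_demo[k] * X_row) + b_demo[k] for k in range(10)])
print(z.shape)                    # (1, 10)
print(np.allclose(z[0], z_loop))  # True
```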

I found a very similar question here, but sadly I could not adapt it to my code so that the predictions matched. I tried many different ways of calculating the variables 'z_for_class_k' and 'p_for_class_k', but none of them reproduced the values from 'predict_proba(X)'.

Ted Frank asked Nov 21 '22 15:11


1 Answer

I think the problem is with

p_for_class_k = 1/ (1 + math.exp(-z_for_class_k))

1 / (1 + exp(-logit)) is a simplification that works only on binary problems.
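The binary formula is the two-class case of the softmax with one class's score fixed at 0, since exp(z) / (exp(0) + exp(z)) = 1 / (1 + exp(-z)). A numerical check with an arbitrary score `z` (the value is made up):

```python
import math

z = 1.7  # arbitrary score for the positive class

sigmoid = 1 / (1 + math.exp(-z))
# Two-class softmax where the reference class has score 0, so exp(0) = 1 sits in the denominator.
two_class_softmax = math.exp(z) / (math.exp(0) + math.exp(z))
print(math.isclose(sigmoid, two_class_softmax))  # True
```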

The real equation, before being simplified, looks like this:

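p_for_class_k = exp(logit_k) / [exp(logit_a) + exp(logit_b) + ... + exp(logit_j)]

where the sum in the denominator runs over all ten classes. (A "1 +" term appears in the denominator only when one reference class has its logit fixed at 0; the binary sigmoid is that case with two classes. A multinomial sklearn model has its own row in W for every class, so there is no 1.)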
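When computing a softmax over all class scores, exponentiating large scores directly can overflow. Subtracting the largest score from every score first cancels out in the ratio, so the probabilities are unchanged. A minimal sketch (`stable_softmax` is an illustrative name and the scores are made up):

```python
import numpy as np

def stable_softmax(z):
    """Softmax of a 1-D score vector; shifting by max(z) cancels in the ratio but keeps exp() finite."""
    z = np.asarray(z, dtype=float)
    e = np.exp(z - z.max())
    return e / e.sum()

scores = np.array([1000.0, 1001.0, 999.0])  # np.exp(scores) alone would overflow to inf
probs = stable_softmax(scores)
print(probs)  # approximately [0.245, 0.665, 0.090]
```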

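In other words, when calculating the probability for a specific class, you must incorporate ALL the weights and biases from the other classes into your formula. In your code, that means applying softmax to the raw scores z_for_class_k rather than to the sigmoid outputs p_for_class_k.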

I didn't have the data to test this out, but hopefully this points you in the right direction.
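The question's own printout is enough to test this. Inverting the sigmoid values from part d with z = log(p / (1 - p)) recovers the raw scores, and a softmax over those scores should reproduce the `predict_proba` output from part c up to its six printed decimals. A sketch (both arrays are copied from the question; nothing else is assumed):

```python
import numpy as np

# Sigmoid outputs printed by part d of the question.
sig = np.array([0.9667598370531315, 0.48453459121301334, 0.06154496922245115,
                0.16456194859398865, 0.45634781280053394, 0.16999340794727547,
                0.8867996361191054, 0.8854473986336552, 0.13124464656251109,
                0.642913996162282])
# Probabilities printed by lm.predict_proba in part c (rounded to 6 decimals by pandas).
proba = np.array([0.595426, 0.019244, 0.001343, 0.004033, 0.017185,
                  0.004193, 0.160380, 0.158245, 0.003093, 0.036860])

# Undo the sigmoid to recover the raw scores z_k = W[k]·x + b[k].
z = np.log(sig / (1 - sig))

# Softmax over the raw scores: exp(z_k) divided by the sum over all ten classes.
e = np.exp(z - z.max())
manual = e / e.sum()
print(np.allclose(manual, proba, atol=1e-5))  # True
```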

Josh answered Dec 29 '22 08:12