Why can't statsmodels reproduce my R logistic regression results?

I'm confused about why my logistic regression models in R and statsmodels do not agree.

If I prepare some data in R with

# From https://courses.edx.org/c4x/MITx/15.071x/asset/census.csv
library(caTools) # for sample.split
census = read.csv("census.csv")
set.seed(2000)
split = sample.split(census$over50k, SplitRatio = 0.6)
censusTrain = subset(census, split==TRUE)
censusTest = subset(census, split==FALSE)

and then run a logistic regression with

CensusLog1 = glm(over50k ~., data=censusTrain, family=binomial)

I see results like

                                           Estimate Std. Error z value Pr(>|z|)    
(Intercept)                              -8.658e+00  1.379e+00  -6.279 3.41e-10 ***
age                                       2.548e-02  2.139e-03  11.916  < 2e-16 ***
workclass Federal-gov                     1.105e+00  2.014e-01   5.489 4.03e-08 ***
workclass Local-gov                       3.675e-01  1.821e-01   2.018 0.043641 *  
workclass Never-worked                   -1.283e+01  8.453e+02  -0.015 0.987885    
workclass Private                         6.012e-01  1.626e-01   3.698 0.000218 ***
workclass Self-emp-inc                    7.575e-01  1.950e-01   3.884 0.000103 ***
workclass Self-emp-not-inc                1.855e-01  1.774e-01   1.046 0.295646    
workclass State-gov                       4.012e-01  1.961e-01   2.046 0.040728 *  
workclass Without-pay                    -1.395e+01  6.597e+02  -0.021 0.983134   
...

but if I use the same data in Python, by first exporting it from R with

write.csv(censusTrain,file="traincensus.csv")
write.csv(censusTest,file="testcensus.csv")

and then importing into Python with

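import pandas as pd

census = pd.read_csv("census.csv")
# R's write.csv() writes row names as the first column; use it as the index
census_train = pd.read_csv("traincensus.csv", index_col=0)
census_test = pd.read_csv("testcensus.csv", index_col=0)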

I get errors and strange results that bear no relationship to the ones I get in R.

If I simply try

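import statsmodels.api as sm

# formula: over50k regressed on every other column
f = 'over50k ~ ' + ' + '.join(list(census.columns)[:-1])
census_log_1 = sm.Logit.from_formula(f, census_train).fit()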

I get an error:

ValueError: operands could not be broadcast together with shapes (19187,2) (19187,) 

Even if I prepare the data with patsy using

import patsy
f = 'over50k ~ ' + ' + '.join(list(census.columns)[:-1])
y, X = patsy.dmatrices(f, census_train, return_type='dataframe')

trying

census_log_1 = sm.Logit(y, X).fit()

results in the same error. The only way I can avoid errors is to use GLM

census_log_1 = sm.GLM(y, X, family=sm.families.Binomial()).fit()

but this produces results that are entirely different from those produced by (what I thought was) the equivalent R API:

                                                   coef    std err          t      P>|t|      [95.0% Conf. Int.]
----------------------------------------------------------------------------------------------------------------
Intercept                                       10.6766      5.985      1.784      0.074        -1.055    22.408
age                                             -0.0255      0.002    -11.916      0.000        -0.030    -0.021
workclass[T. Federal-gov]                       -0.9775      4.498     -0.217      0.828        -9.794     7.839
workclass[T. Local-gov]                         -0.2395      4.498     -0.053      0.958        -9.055     8.576
workclass[T. Never-worked]                       8.8346    114.394      0.077      0.938      -215.374   233.043
workclass[T. Private]                           -0.4732      4.497     -0.105      0.916        -9.288     8.341
workclass[T. Self-emp-inc]                      -0.6296      4.498     -0.140      0.889        -9.446     8.187
workclass[T. Self-emp-not-inc]                  -0.0576      4.498     -0.013      0.990        -8.873     8.758
workclass[T. State-gov]                         -0.2733      4.498     -0.061      0.952        -9.090     8.544
workclass[T. Without-pay]                       10.0745     85.048      0.118      0.906      -156.616   176.765
...

Why is logistic regression in Python producing errors and different results from those produced by R? Are these APIs not in fact equivalent (I've had them work before to produce identical results)? Is there some additional processing of the datasets required to make them usable by statsmodels?

asked Oct 02 '22 04:10 by orome

1 Answer

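The error occurs because patsy expands a categorical left-hand-side (response) variable into a full set of Treatment-coded indicator columns, one per level, so your response has two columns; that is the (19187, 2) in the error message. Logit expects a one-dimensional 0/1 response, as its docstring indicates, whereas GLM with the binomial family accepts the two-column form, as you found.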
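To see the shape problem in isolation, here is a minimal sketch on a hypothetical five-row DataFrame whose labels only mimic the census ones (including their leading spaces). patsy turns the string response into one indicator column per level, and taking just the ` >50K` column gives the 1-D 0/1 response that Logit expects.

```python
import pandas as pd
import patsy

# Hypothetical toy data mimicking the census labels (note the leading spaces)
df = pd.DataFrame({
    "over50k": [" <=50K", " >50K", " <=50K", " >50K", " <=50K"],
    "age": [25, 50, 30, 45, 38],
})

y, X = patsy.dmatrices("over50k ~ age", df, return_type="dataframe")

# A string (categorical) response becomes one indicator column per level
print(y.shape)            # (5, 2)
print(list(y.columns))

# Keep only the ' >50K' indicator as a 1-D 0/1 response for sm.Logit
y_1d = y.iloc[:, 1]
print(y_1d.shape)         # (5,)
```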

I can't speak to the difference in the results without the full output. Most likely it comes from different default handling of the categorical variables, or you are using different variables; not all of them are listed in your output.

You can use Logit if you first recode the response to 0/1 with the following pre-processing step:

census = census.replace(to_replace={'over50k' : {' <=50K' : 0, ' >50K' : 1}})
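Here is a minimal sketch of the recode-then-Logit route on simulated data (all values hypothetical). It deliberately swaps in `Series.map` followed by `astype(int)` instead of `replace`: depending on the pandas version, `replace` can leave an object-typed column behind, and the formula interface would then treat the response as categorical again, while `astype(int)` fails loudly if any label was left unmapped.

```python
import numpy as np
import pandas as pd
import statsmodels.formula.api as smf

# Simulated data (hypothetical) with census-style labels
rng = np.random.default_rng(1)
n = 2000
age = rng.uniform(20, 70, n)
p = 1 / (1 + np.exp(-(-5 + 0.1 * age)))
df = pd.DataFrame({
    "over50k": np.where(rng.random(n) < p, " >50K", " <=50K"),
    "age": age,
})

# Recode the response to a single integer 0/1 column
df["over50k"] = df["over50k"].map({" <=50K": 0, " >50K": 1}).astype(int)

res = smf.logit("over50k ~ age", data=df).fit(disp=0)
print(res.params)  # age coefficient should be near the simulated 0.1
```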

Note also that the default solver for Logit (Newton's method) does not work well for this problem: it runs into a singular-matrix error. Indeed, the condition number of the design matrix is huge, so what you get in R might not be a fully converged model. You might try reducing the number of dummy variables.

In [73]: np.linalg.cond(mod.exog)
Out[73]: 4.5139498536894682e+17
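For context, a condition number of about 4.5e17 exceeds 1/eps ≈ 4.5e15 for float64, so the design matrix is numerically rank-deficient, which is consistent with the singular-matrix error from Newton's method. One common way dummy-coded designs get there is exact collinearity among columns; the tiny hypothetical matrix below (an intercept plus two dummies that always sum to one) shows the symptom, and it does not establish that this is what happens in the census data. Comparing `np.linalg.matrix_rank` with the number of columns is a quick companion check.

```python
import numpy as np

# Hypothetical design: intercept + one dummy per level of a two-level factor.
# The two dummy columns always sum to the intercept column -> exact collinearity.
X = np.column_stack([
    np.ones(6),
    [1, 0, 0, 1, 0, 1],   # level A
    [0, 1, 1, 0, 1, 0],   # level B
])

print(np.linalg.cond(X))            # enormous (or inf): numerically singular
print(np.linalg.matrix_rank(X))     # 2, fewer than the 3 columns
print(1 / np.finfo(float).eps)      # ~4.5e15; beyond this no accurate digits remain
```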

To get a solution, I had to switch to the BFGS optimizer and raise the iteration limit:

import statsmodels.api as sm

mod = sm.formula.logit(f, data=census)  # census with over50k recoded to 0/1
res = mod.fit(method='bfgs', maxiter=1000)
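Since the concern is that a fit may not be fully converged, it is worth checking convergence explicitly rather than trusting the printed summary. A minimal sketch on simulated, well-conditioned data (values hypothetical): fit with `method="bfgs"` as above, read the `converged` flag that statsmodels stores in `mle_retvals`, and cross-check against the default Newton solver. On the R side, the object returned by `glm` carries an analogous `converged` component.

```python
import numpy as np
import pandas as pd
import statsmodels.formula.api as smf

# Simulated, well-conditioned data (hypothetical)
rng = np.random.default_rng(2)
n = 2000
age = rng.normal(40, 10, n)
p = 1 / (1 + np.exp(-(-4 + 0.1 * age)))
df = pd.DataFrame({"over50k": (rng.random(n) < p).astype(int), "age": age})

mod = smf.logit("over50k ~ age", data=df)
res = mod.fit(method="bfgs", maxiter=1000, disp=0)

# statsmodels records optimizer diagnostics in mle_retvals
print(res.mle_retvals["converged"])

# Cross-check against the default Newton solver
res_newton = mod.fit(disp=0)
print(res.params - res_newton.params)
```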

Some of your cells (category levels) contain very few observations; for example, Armed-Forces has only 9 rows and Priv-house-serv 143. This is compounded by the other sparse dummy variables.

In [81]: pd.Categorical(census.occupation).describe()
Out[81]:
                    counts     freqs
levels                              
?                    1816  0.056789
Adm-clerical         3721  0.116361
Armed-Forces            9  0.000281
Craft-repair         4030  0.126024
Exec-managerial      3992  0.124836
Farming-fishing       989  0.030928
Handlers-cleaners    1350  0.042217
Machine-op-inspct    1966  0.061480
Other-service        3212  0.100444
Priv-house-serv       143  0.004472
Prof-specialty       4038  0.126274
Protective-serv       644  0.020139
Sales                3584  0.112077
Tech-support          912  0.028520
Transport-moving     1572  0.049159
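One way to act on the suggestion to reduce the number of dummy variables is to pool sparse levels such as Armed-Forces and Priv-house-serv into a single level before building the formula. The helper below is hypothetical (the name `collapse_rare_levels`, the `"Other"` label and the threshold are illustrative choices, not part of pandas or statsmodels), and pooled levels lose their individual estimates, so pick the threshold with that trade-off in mind.

```python
import pandas as pd


def collapse_rare_levels(s: pd.Series, min_count: int, other: str = "Other") -> pd.Series:
    """Hypothetical helper: replace levels seen fewer than min_count times with `other`."""
    counts = s.value_counts()
    rare = counts[counts < min_count].index
    # keep values whose level is not rare; substitute `other` everywhere else
    return s.where(~s.isin(rare), other)


# Toy occupation column (hypothetical counts)
occ = pd.Series(["Sales"] * 5 + ["Craft-repair"] * 4
                + ["Armed-Forces"] * 1 + ["Priv-house-serv"] * 2)
collapsed = collapse_rare_levels(occ, min_count=3)
print(collapsed.value_counts())
```

Applied to the real data this would look like `census['occupation'] = collapse_rare_levels(census['occupation'], min_count=...)`, with the threshold left to your judgment.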
answered Oct 13 '22 12:10 by jseabold