 

Retain pandas DataFrame structure after SMOTE oversampling in Python

Problem: when I apply SMOTE (a type of oversampling), my DataFrame is converted to a NumPy array.

Train/test split:

from sklearn.model_selection import train_test_split
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.20, random_state=66)
[IN]type(X_train)
[OUT]pandas.core.frame.DataFrame

After SMOTE, the type of X_train changes from a pandas DataFrame to a NumPy array:

from imblearn.over_sampling import SMOTE
sm = SMOTE(random_state=42)
X_train, y_train = sm.fit_resample(X_train, y_train)
[IN]type(X_train)
[OUT]numpy.ndarray
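For context, SMOTE creates each synthetic minority row by picking a minority sample, picking one of its k nearest neighbours from the same class, and moving a random fraction of the way between the two. That arithmetic runs on a plain numeric matrix, which is why a DataFrame has to be rebuilt around the result. Below is a minimal from-scratch sketch of the idea that returns a DataFrame and Series directly. smote_frame is a hypothetical helper, not imbalanced-learn's implementation; it assumes every feature column is numeric and uses brute-force distances, so it only suits small data.

```python
import numpy as np
import pandas as pd


def smote_frame(X: pd.DataFrame, y: pd.Series, k: int = 5, seed: int = 0):
    """Minimal SMOTE-style sketch: grow every minority class to the majority count."""
    rng = np.random.default_rng(seed)
    values = X.to_numpy(dtype=float)
    labels = y.to_numpy()
    counts = y.value_counts()
    target = int(counts.max())
    X_parts, y_parts = [X.reset_index(drop=True)], [y.reset_index(drop=True)]
    for label, n in counts.items():
        n = int(n)
        n_new = target - n
        if n_new == 0:
            continue
        if n < 2:
            raise ValueError(f"class {label!r} needs at least 2 samples")
        pts = values[labels == label]
        k_eff = min(k, n - 1)
        # Pairwise distances inside the class; inf on the diagonal so a point is not its own neighbour
        dist = np.linalg.norm(pts[:, None, :] - pts[None, :, :], axis=2)
        np.fill_diagonal(dist, np.inf)
        neighbours = np.argsort(dist, axis=1)[:, :k_eff]
        # Each synthetic row: random base sample, one of its k nearest neighbours, random gap in [0, 1)
        base = rng.integers(0, n, size=n_new)
        nb = neighbours[base, rng.integers(0, k_eff, size=n_new)]
        gap = rng.random((n_new, 1))
        synthetic = pts[base] + gap * (pts[nb] - pts[base])
        X_parts.append(pd.DataFrame(synthetic, columns=X.columns))
        y_parts.append(pd.Series([label] * n_new, name=y.name))
    # Stack originals and synthetic rows, keeping the column names and the target name
    return pd.concat(X_parts, ignore_index=True), pd.concat(y_parts, ignore_index=True)
```

Here k and seed only loosely mirror SMOTE's k_neighbors and random_state; for real work the imbalanced-learn sampler remains the better choice.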

Expected output: I want X_train and X_test to keep their DataFrame structure after SMOTE. How can I do that?

noob asked Dec 07 '22 10:12


2 Answers

I found a simpler answer:

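import pandas as pd
from imblearn.over_sampling import SMOTE
sm = SMOTE(random_state=42)
# fit_resample replaces the deprecated fit_sample
X_train_oversampled, y_train_oversampled = sm.fit_resample(X_train, y_train)
# Rebuild the DataFrame with the original column names (and the Series with the target name, assuming y_train is a Series)
X_train = pd.DataFrame(X_train_oversampled, columns=X_train.columns)
y_train = pd.Series(y_train_oversampled, name=y_train.name)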

This restores the DataFrame structure, with the original column names, after SMOTE.
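The same idea can be packaged as a small reusable helper that also rebuilds the target Series. This is a sketch; rewrap_resampled is a hypothetical name, and the np.asarray calls are there so it works whether the sampler hands back NumPy arrays or pandas objects.

```python
import numpy as np
import pandas as pd


def rewrap_resampled(X_res, y_res, X_template: pd.DataFrame, y_template: pd.Series):
    """Wrap resampler output back into a DataFrame/Series shaped like the originals."""
    X_out = pd.DataFrame(np.asarray(X_res), columns=X_template.columns)
    y_out = pd.Series(np.asarray(y_res), name=y_template.name)
    # Synthetic rows have no original index label, so both outputs get a fresh RangeIndex
    return X_out, y_out.astype(y_template.dtype)
```

Usage would be X_train, y_train = rewrap_resampled(X_train_oversampled, y_train_oversampled, X_train, y_train).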

noob answered Dec 09 '22 22:12


This function may help. df is your data with the target column still attached (for example, X_train joined with y_train), output is the name of that target column as a string, and SEED is an integer passed to SMOTE's random_state so the result is reproducible.

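You can call it before or after splitting your dataset. Oversampling only the training split, however, keeps synthetic rows interpolated from test samples out of your evaluation data.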
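A minimal sketch of the "split first, then oversample" ordering follows. To keep it runnable with pandas alone, it balances classes with plain random oversampling (duplicating minority rows via DataFrame.sample) as a stand-in for SMOTE; that step is not SMOTE, and split_then_oversample is a hypothetical helper that assumes df has a unique index.

```python
import pandas as pd


def split_then_oversample(df: pd.DataFrame, output: str, test_frac: float = 0.2, seed: int = 0):
    """Hold out a test set first, then balance only the remaining training rows."""
    test = df.sample(frac=test_frac, random_state=seed)
    train = df.drop(index=test.index)
    target = int(train[output].value_counts().max())
    parts = []
    for _, group in train.groupby(output):
        parts.append(group)
        extra = target - len(group)
        if extra > 0:
            # Stand-in resampler: draw extra minority rows with replacement (random oversampling, not SMOTE)
            parts.append(group.sample(n=extra, replace=True, random_state=seed))
    # The test frame is returned untouched, so it contains only original rows
    return pd.concat(parts, ignore_index=True), test
```

Swapping the stand-in for a call to smote_sampler(train, output) would keep the same ordering.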

import numpy as np
import pandas as pd
from imblearn.over_sampling import SMOTE

def smote_sampler(df, output, SEED=33):
    # Separate the features from the target column named by `output`
    X = df.drop([output], axis=1)
    y = df[output]
    col_names = X.columns.tolist() + [output]
    smt = SMOTE(random_state=SEED)
    # fit_resample replaces the deprecated fit_sample
    X_smote, y_smote = smt.fit_resample(X, y)
    # np.asarray works whether SMOTE returned NumPy arrays or pandas objects
    smote_array = np.concatenate([np.asarray(X_smote), np.asarray(y_smote).reshape(-1, 1)], axis=1)
    df_ = pd.DataFrame(smote_array, columns=col_names)
    # Restore the int64/float64 dtypes lost when features and target were stacked into one array
    org_int_cols = df.dtypes.index[df.dtypes == 'int64'].tolist()
    org_float_cols = df.dtypes.index[df.dtypes == 'float64'].tolist()
    for col in df_.columns:
        if col in org_float_cols:
            df_[col] = df_[col].astype('float64')
        elif col in org_int_cols:
            df_[col] = df_[col].astype('int64')
    return df_
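The dtype loop above casts synthetic values straight to int64, and astype truncates toward zero, so an interpolated 1.7 becomes 1. One possible variant, sketched below, rounds integer columns before casting and restores whatever integer or float dtype each column had in the original frame. restore_dtypes and its round_ints flag are hypothetical, not part of the original answer.

```python
import pandas as pd


def restore_dtypes(resampled: pd.DataFrame, original: pd.DataFrame, round_ints: bool = True) -> pd.DataFrame:
    """Cast each column of `resampled` back to the dtype it had in `original`."""
    out = resampled.copy()  # leave the caller's frame untouched
    for col, dtype in original.dtypes.items():
        if col not in out.columns:
            continue
        if round_ints and pd.api.types.is_integer_dtype(dtype):
            # Interpolated values of integer features are usually fractional; round instead of truncating
            out[col] = out[col].round().astype(dtype)
        else:
            out[col] = out[col].astype(dtype)
    return out
```

Whether rounding is appropriate depends on the feature; for count-like columns it keeps synthetic values closer to the interpolated point than truncation does.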

talatccan answered Dec 10 '22 00:12