How can I align pandas get_dummies across training and test data?

Tags:

python

pandas

This question helped me realize that I can split training and validation data. Here is the code I use to load my training and test data.

import numpy as np
import pandas as pd

def load_data(datafile):
    training_data = pd.read_csv(datafile, header=0, low_memory=False)
    training_y = training_data[['job_performance']]
    training_x = training_data.drop(['job_performance'], axis=1)

    training_x.replace([np.inf, -np.inf], np.nan, inplace=True)
    training_x.fillna(training_x.mean(), inplace=True)
    training_x.fillna(0, inplace=True)
    categorical_data = training_x.select_dtypes(
        include=['category', object]).columns

    training_x = pd.get_dummies(training_x, columns=categorical_data)
    return training_x, training_y

Here, datafile is my training file. I have another file, test.csv, that has the same columns as the training file, except that it may be missing some categories. How can I run get_dummies on the test file and ensure the categories are encoded the same way?

Additionally, my test data is missing the job_performance column. How can I handle this in the function?
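To make the mismatch concrete, here is a small sketch with hypothetical toy frames (not the asker's data): calling get_dummies separately on each frame produces different column sets whenever the categories differ.

```python
import pandas as pd

# Hypothetical toy data: the test set lacks 'b' and 'c' and adds an unseen 'd'.
train = pd.DataFrame({'color': ['a', 'b', 'c']})
test = pd.DataFrame({'color': ['a', 'd']})

# Encoding each frame on its own yields one column per category *present* in it.
train_dummies = pd.get_dummies(train, columns=['color'])
test_dummies = pd.get_dummies(test, columns=['color'])

print(list(train_dummies.columns))  # ['color_a', 'color_b', 'color_c']
print(list(test_dummies.columns))   # ['color_a', 'color_d']
```

A model trained on the first layout cannot be applied directly to the second.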

Shamoon asked Sep 01 '25 23:09


2 Answers

When dealing with multiple columns, it is best to use sklearn.preprocessing.OneHotEncoder. It keeps track of the categories it has seen and, with handle_unknown='ignore', handles unknown categories gracefully.
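A minimal sketch of what "handles unknown categories gracefully" means, using hypothetical toy data: the encoder learns its categories from the training frame only, and a value it never saw becomes an all-zero row instead of an extra column.

```python
import pandas as pd
from sklearn.preprocessing import OneHotEncoder

train = pd.DataFrame({'color': ['red', 'green', 'blue']})
test = pd.DataFrame({'color': ['green', 'purple']})  # 'purple' never seen in training

ohe = OneHotEncoder(handle_unknown='ignore')
ohe.fit(train)

# Categories come from the training data only, stored in sorted order.
print(ohe.categories_[0].tolist())  # ['blue', 'green', 'red']

# transform() always emits the same 3 columns; the unknown 'purple' row is all zeros.
encoded = ohe.transform(test).toarray()
print(encoded.tolist())  # [[0.0, 1.0, 0.0], [0.0, 0.0, 0.0]]
```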

sys.version
# '3.6.0 (v3.6.0:41df79263a11, Dec 22 2016, 17:23:13) \n[GCC 4.2.1 (Apple Inc. build 5666) (dot 3)]'
sklearn.__version__
# '0.20.0'
np.__version__
# '1.15.0'
pd.__version__
# '0.24.2'

import pandas as pd
from sklearn.preprocessing import OneHotEncoder

df = pd.DataFrame({
    'data': [1, 2, 3],
    'cat1': ['a', 'b', 'c'],
    'cat2': ['dog', 'cat', 'bird']
})

ohe = OneHotEncoder(handle_unknown='ignore')
categorical_columns = df.select_dtypes(['category', object]).columns
dummies = pd.DataFrame(ohe.fit_transform(df[categorical_columns]).toarray(), 
                       index=df.index, 
                       dtype=int)

df_ohe = pd.concat([df.drop(categorical_columns, axis=1), dummies], axis=1)
df_ohe

   data  0  1  2  3  4  5
0     1  1  0  0  0  0  1
1     2  0  1  0  0  1  0
2     3  0  0  1  1  0  0

You can see the categories and their ordering:

ohe.categories_
# [array(['a', 'b', 'c'], dtype=object),
#  array(['bird', 'cat', 'dog'], dtype=object)]

Now, to apply the same encoding to new data, we only need the categories learned before. No need to pickle or unpickle any models here.

df2 = pd.DataFrame({
    'data': [1, 2, 1],
    'cat1': ['b', 'a', 'b'],
    'cat2': ['cat', 'dog', 'cat']
})

# Build a new encoder from the categories learned on the training data.
ohe2 = OneHotEncoder(categories=ohe.categories_)

dummies = pd.DataFrame(ohe2.fit_transform(df2[categorical_columns]).toarray(), 
                       index=df2.index, 
                       dtype=int)
pd.concat([df2.drop(categorical_columns, axis=1), dummies], axis=1)

   data  0  1  2  3  4  5
0     1  0  1  0  0  1  0
1     2  1  0  0  0  0  1
2     1  0  1  0  0  1  0
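If the fitted encoder is still in memory, calling its transform() method on the new data is an alternative to rebuilding an encoder from categories_. The sketch below recreates the toy frames df and df2 from above and checks that both routes give the same matrix.

```python
import numpy as np
import pandas as pd
from sklearn.preprocessing import OneHotEncoder

df = pd.DataFrame({'data': [1, 2, 3],
                   'cat1': ['a', 'b', 'c'],
                   'cat2': ['dog', 'cat', 'bird']})
df2 = pd.DataFrame({'data': [1, 2, 1],
                    'cat1': ['b', 'a', 'b'],
                    'cat2': ['cat', 'dog', 'cat']})
categorical_columns = ['cat1', 'cat2']

# Route 1: fit once on the training frame, then only transform the new frame.
ohe = OneHotEncoder(handle_unknown='ignore')
ohe.fit(df[categorical_columns])
via_transform = ohe.transform(df2[categorical_columns]).toarray().astype(int)

# Route 2: rebuild an encoder from the saved categories and fit it on the new frame.
ohe2 = OneHotEncoder(categories=ohe.categories_)
via_categories = ohe2.fit_transform(df2[categorical_columns]).toarray().astype(int)

print(np.array_equal(via_transform, via_categories))  # True
```

Reusing categories_ is handy when only plain data (the category lists) is kept between runs; reusing the fitted object is simpler when it is available.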

So what does this mean for you? You'll want to change your function to work for both training and test data. Add an extra parameter, categories, to your function.

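import numpy as np
import pandas as pd
from sklearn.preprocessing import OneHotEncoder

def load_data(datafile, categories=None):
    data = pd.read_csv(datafile, header=0, low_memory=False)
    # The test file has no target column, so only split it off when present.
    if 'job_performance' in data.columns:
        data_y = data[['job_performance']]
        data_x = data.drop(['job_performance'], axis=1)
    else:
        data_x = data
        data_y = None

    data_x.replace([np.inf, -np.inf], np.nan, inplace=True)
    # numeric_only=True: newer pandas raises on the mean of string columns.
    data_x.fillna(data_x.mean(numeric_only=True), inplace=True)
    data_x.fillna(0, inplace=True)

    # Learn categories from the training data; reuse them for the test data.
    ohe = OneHotEncoder(handle_unknown='ignore', 
                        categories=categories if categories is not None else 'auto')

    categorical_data = data_x.select_dtypes(object)
    dummies = pd.DataFrame(
                ohe.fit_transform(categorical_data.astype(str)).toarray(), 
                index=data_x.index,
                dtype=int)

    data_x = pd.concat([
        data_x.drop(categorical_data.columns, axis=1), dummies], axis=1)

    # Also return the learned categories, but only when none were passed in.
    return (data_x, data_y) + ((ohe.categories_, ) if categories is None else ())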

Your function can then be called like this:

# Load training data.
X_train, y_train, categories = load_data('train.csv')
...
# Load test data.
X_test, y_test = load_data('test.csv', categories=categories)

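Because test.csv has no job_performance column, y_test will be None, and X_test will have exactly the same dummy columns as X_train.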

cs95 answered Sep 03 '25 12:09


If you want to use pandas get_dummies, you will need to manually add columns for values that appear in train but not in test, and drop columns for values that appear in test but not in train.

You could use the dummies column names ('origcolumn_value' by default) to do that, and use separate functions for train and test.
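Another pandas-only option, a swapped-in technique rather than what the code in this answer does: give each categorical column a CategoricalDtype holding the training categories. get_dummies then emits one column per declared category whether or not it occurs, and unseen values become NaN and hence an all-zero row. The frames and the encode helper below are hypothetical.

```python
import pandas as pd

train = pd.DataFrame({'data': [1, 2, 3], 'cat1': ['a', 'b', 'c']})
test = pd.DataFrame({'data': [4, 5], 'cat1': ['a', 'd']})  # no 'b'/'c', unseen 'd'

# Record each categorical column's training categories as a fixed dtype.
categorical_columns = ['cat1']
dtypes = {c: pd.CategoricalDtype(sorted(train[c].unique())) for c in categorical_columns}

def encode(frame):
    # Values outside the training categories become NaN; get_dummies
    # (dummy_na=False by default) encodes them as an all-zero row.
    frame = frame.astype(dtypes)
    return pd.get_dummies(frame, columns=categorical_columns, dtype=int)

train_x = encode(train)
test_x = encode(test)
print(list(test_x.columns))  # ['data', 'cat1_a', 'cat1_b', 'cat1_c']
```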

Something along these lines (haven't tested it):

import numpy as np
import pandas as pd

def load_and_clean(datafile_path, labeled=False):
    data = pd.read_csv(datafile_path, header=0, low_memory=False)

    if labeled:
        job_performance = data['job_performance']
        data = data.drop(['job_performance'], axis=1)

    data.replace([np.inf, -np.inf], np.nan, inplace=True)
    data.fillna(data.mean(numeric_only=True), inplace=True)
    data.fillna(0, inplace=True)

    if labeled:
        data['job_performance'] = job_performance

    return data

def dummies_train(training_data):
    training_y = training_data[['job_performance']]
    training_x = training_data.drop(['job_performance'], axis=1)
    categorical_data = training_x.select_dtypes(
        include=['category', object]).columns
    training_x = pd.get_dummies(training_x, columns=categorical_data)
    return training_x, training_y, training_x.columns

def dummies_test(test_data, model_columns):
    categorical_data = test_data.select_dtypes(
        include=['category', object]).columns
    test_data = pd.get_dummies(test_data, columns=categorical_data)
    for c in model_columns:
        if c not in test_data.columns:
            test_data[c] = 0
    return test_data[model_columns]

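# train_data_path and test_data_path are the paths to your training and test CSV files.
# labeled=True belongs to load_and_clean (the training file has job_performance).
training_x, training_y, model_columns = dummies_train(load_and_clean(train_data_path, labeled=True))
test_x = dummies_test(load_and_clean(test_data_path), model_columns)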
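
The add-missing/drop-extra step in dummies_test can also be written with DataFrame.reindex, which does both in one call and keeps the training column order. A sketch with a hypothetical helper name and toy data; it picks categorical columns as the non-numeric ones, which is an assumption slightly different from the select_dtypes call above.

```python
import pandas as pd

def dummies_test_reindex(test_data, model_columns):
    # Hypothetical helper: same goal as dummies_test, expressed with reindex.
    categorical_data = test_data.select_dtypes(exclude='number').columns
    test_data = pd.get_dummies(test_data, columns=categorical_data, dtype=int)
    # reindex adds training columns missing from test (filled with 0)
    # and drops test-only columns, in the training column order.
    return test_data.reindex(columns=model_columns, fill_value=0)

model_columns = pd.Index(['data', 'cat1_a', 'cat1_b', 'cat1_c'])
test = pd.DataFrame({'data': [4, 5], 'cat1': ['a', 'd']})  # 'd' unseen in training
result = dummies_test_reindex(test, model_columns)
print(result)
```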

Ezer K answered Sep 03 '25 11:09