Python - Pandas, Resample dataset to have balanced classes

Given the following data frame, which has only 2 possible labels:

   name  f1  f2  label
0     A   8   9      1
1     A   5   3      1
2     B   8   9      0
3     C   9   2      0
4     C   8   1      0
5     C   9   1      0
6     D   2   1      0
7     D   9   7      0
8     D   3   1      0
9     E   5   1      1
10    E   3   6      1
11    E   7   1      1

I've written code that groups the data by the 'name' column and pivots the result into a numpy array, so that each row holds all the samples of one group (padded with zeros to the size of the largest group). The labels go into a second numpy array:

Data:

[[8 9] [5 3] [0 0]] # A label = 1
[[8 9] [0 0] [0 0]] # B label = 0
[[9 2] [8 1] [9 1]] # C label = 0
[[2 1] [9 7] [3 1]] # D label = 0
[[5 1] [3 6] [7 1]] # E label = 1

Labels:

[[1]
 [0]
 [0]
 [0]
 [1]]

Code:

import pandas as pd
import numpy as np


def prepare_data(group_name):
    df = pd.read_csv("../data/tmp.csv")

    # Number the rows within each group, then pad every group with zeros
    # up to the size of the largest group.
    group_index = df.groupby(group_name).cumcount()
    data = (df.set_index([group_name, group_index])
            .unstack(fill_value=0).stack())

    # One label per group, taken from the group's first row.
    target = np.array(data['label'].groupby(level=0).apply(lambda x: [x.values[0]]).tolist())
    # Drop the label column and collect each group's feature rows.
    data = data.loc[:, data.columns != 'label']
    data = np.array(data.groupby(level=0).apply(lambda x: x.values.tolist()).tolist())
    print(data)
    print(target)


prepare_data('name')

I would like to resample and delete instances from the over-represented class.

For example:

[[8 9] [5 3] [0 0]] # A label = 1
[[8 9] [0 0] [0 0]] # B label = 0
[[9 2] [8 1] [9 1]] # C label = 0
# group D was deleted randomly from the '0' labels
[[5 1] [3 6] [7 1]] # E label = 1

would be an acceptable solution, since removing D (labeled '0') results in a balanced dataset of 2 groups with label '1' and 2 groups with label '0'.

Shlomi Schwartz asked Oct 10 '18 07:10


1 Answer

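A very simple approach, taken from the sklearn documentation and Kaggle. Note that it balances the classes by upsampling the minority class (sampling rows with replacement) rather than by deleting rows from the majority class.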

import pandas as pd
from sklearn.utils import resample

df = pd.read_csv("../data/tmp.csv")  # same file as in the question

df_majority = df[df.label == 0]
df_minority = df[df.label == 1]

# Upsample minority class
df_minority_upsampled = resample(df_minority,
                                 replace=True,                # sample with replacement
                                 n_samples=len(df_majority),  # to match majority class
                                 random_state=42)             # reproducible results

# Combine majority class with upsampled minority class
df_upsampled = pd.concat([df_majority, df_minority_upsampled])

# Display new class counts
print(df_upsampled.label.value_counts())
Joe9008 answered Sep 28 '22 02:09
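
The question asks for whole groups to be deleted from the over-represented class, which row-level resampling does not do. Below is a minimal sketch of one way that might be approached: pick an equal number of group names per label, then keep only the rows of those groups. The helper name downsample_groups and its parameters are hypothetical (they are not part of pandas or sklearn), the sample frame is rebuilt inline from the question, and the sketch assumes every row of a group carries the same label.

```python
import pandas as pd

# The sample frame from the question, rebuilt inline so the sketch is self-contained.
df = pd.DataFrame({
    "name":  list("AABCCCDDDEEE"),
    "f1":    [8, 5, 8, 9, 8, 9, 2, 9, 3, 5, 3, 7],
    "f2":    [9, 3, 9, 2, 1, 1, 1, 7, 1, 1, 6, 1],
    "label": [1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1],
})


def downsample_groups(df, group_col="name", label_col="label", random_state=0):
    """Drop whole groups at random until every label has the same number of groups.

    Hypothetical helper; assumes each group has exactly one label.
    """
    # One row per group: the group name and its label.
    groups = df.groupby(group_col)[label_col].first().reset_index()
    # Size of the smallest class, counted in groups rather than rows.
    n_keep = int(groups[label_col].value_counts().min())
    # Sample n_keep groups from each label, without replacement.
    kept = groups.groupby(label_col).sample(n=n_keep, random_state=random_state)
    # Keep only the rows that belong to the sampled groups.
    return df[df[group_col].isin(kept[group_col])]


balanced = downsample_groups(df)
print(balanced.groupby("name")["label"].first().value_counts())
```

The balanced frame can then go through the same pivot step as in prepare_data. Balancing before the pivot keeps the data and label arrays aligned, because both are built from the same filtered rows.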