I would like to ask whether the current API of datasets allows for the implementation of an oversampling algorithm. I am dealing with a highly imbalanced class problem, and I was thinking it would be nice to oversample specific classes during dataset parsing, i.e. online generation. I have seen the implementation of the rejection_resample function; however, it removes samples instead of duplicating them, and it slows down batch generation (when the target distribution is much different from the initial one). The thing I would like to achieve is: take an example, look at its class probability, and decide whether or not to duplicate it. Then call dataset.shuffle(...) and dataset.batch(...) and get an iterator. The best approach (in my opinion) would be to oversample low-probability classes and subsample the most probable ones. I would like to do it online since that is more flexible.
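For comparison, here is roughly how the rejection_resample transformation mentioned above is used (a minimal sketch; I am assuming the tf.data.experimental API, and the class_func helper and toy labels are made up for illustration):

import tensorflow as tf

# toy dataset of dict examples with an imbalanced 'label' field
dataset = tf.data.Dataset.from_tensor_slices({'label': [0, 0, 0, 1]})

# hypothetical: maps each example to its integer class label
def class_func(example):
    return example['label']

target_dist = [0.5, 0.5]  # desired class distribution

dataset = dataset.apply(
    tf.data.experimental.rejection_resample(class_func, target_dist=target_dist)
)
# rejection_resample yields (class_label, example) pairs, so drop the label again
dataset = dataset.map(lambda _, example: example)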
To iterate over the dataset several times, use .repeat(). Each batch can be enumerated either with Python's enumerate or with the dataset's built-in .enumerate() method; the latter produces the index as a tensor, which is recommended inside a tf.data pipeline.
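A minimal sketch of both calls (assuming a TF version where Dataset.enumerate is available; the toy data is made up):

import tensorflow as tf

dataset = tf.data.Dataset.from_tensor_slices([10, 20, 30])
dataset = dataset.repeat(2)    # iterate over the data twice
dataset = dataset.enumerate()  # yields (index, element) pairs, index as an int64 tensor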
Random oversampling involves randomly selecting examples from the minority class, with replacement, and adding them to the training dataset. Random undersampling involves randomly selecting examples from the majority class and deleting them from the training dataset.
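A quick offline illustration of these two definitions (a numpy sketch; the toy labels and sample counts are made up):

import numpy as np

rng = np.random.default_rng(0)
y = np.array([0, 0, 0, 0, 0, 0, 1, 1])  # imbalanced labels: six 0s, two 1s

# random oversampling: draw minority indices with replacement
minority = np.flatnonzero(y == 1)
extra = rng.choice(minority, size=4, replace=True)

# random undersampling: keep only a random subset of majority indices
majority = np.flatnonzero(y == 0)
kept = rng.choice(majority, size=2, replace=False)

balanced_idx = np.concatenate([kept, minority, extra])  # roughly balanced now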
SMOTE is an oversampling method that synthesizes new plausible examples in the minority class, while Tomek Links refers to a method for identifying pairs of nearest neighbors that belong to different classes. The two are often combined for an imbalanced dataset: first SMOTE is applied to create new synthetic minority samples and reach a balanced distribution, then Tomek links are removed as an undersampling step to clean the class boundary.
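Offline, this combination is available in the third-party imbalanced-learn package (a sketch, assuming pip install imbalanced-learn; the toy dataset is made up):

from sklearn.datasets import make_classification
from imblearn.combine import SMOTETomek

# toy imbalanced dataset: roughly 90% class 0, 10% class 1
X, y = make_classification(n_samples=1000, weights=[0.9, 0.1], random_state=42)

# SMOTE oversampling followed by Tomek-links cleaning in one step
X_res, y_res = SMOTETomek(random_state=42).fit_resample(X, y)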
This problem has been solved in issue #14451. I am just posting the answer here to make it more visible to other developers.
The sample code below oversamples low-frequency classes and undersamples high-frequency ones, where class_target_prob
is just the uniform distribution in my case. I wanted to check some conclusions from the recent manuscript A systematic study of the class imbalance problem in convolutional neural networks.
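The snippet assumes each example is a dict that already carries a 'class_prob' field (the empirical probability of its class) and a 'class_target_prob' field. A minimal sketch of how such fields might be attached (the class_probs table, the attach_probs helper, and the toy data are my assumptions, not part of the original answer):

import tensorflow as tf

# assumption: three classes, with empirical frequencies precomputed offline
class_probs = tf.constant([0.7, 0.2, 0.1])
num_classes = 3

# toy (features, label) dataset with the 0.7/0.2/0.1 imbalance
dataset = tf.data.Dataset.from_tensor_slices(
    (tf.random_uniform([10, 4]), tf.constant([0, 0, 0, 0, 0, 0, 0, 1, 1, 2]))
)

def attach_probs(features, label):
    return {
        'features': features,
        'label': label,
        'class_prob': class_probs[label],        # empirical probability of this class
        'class_target_prob': 1.0 / num_classes,  # uniform target distribution
    }

dataset = dataset.map(attach_probs)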
The oversampling of specific classes is done by calling:
dataset = dataset.flat_map(
    lambda x: tf.data.Dataset.from_tensors(x).repeat(oversample_classes(x))
)
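To see why this duplicates examples, consider a toy version (repeat() accepts a scalar int64 tensor, so each element can be emitted a data-dependent number of times):

import tensorflow as tf

# each element x is repeated x times, yielding 1, 2, 2, 3, 3, 3
toy = tf.data.Dataset.range(1, 4).flat_map(
    lambda x: tf.data.Dataset.from_tensors(x).repeat(x)
)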
Here is the full snippet, which puts everything together:
# sampling parameters
oversampling_coef = 0.9   # if equal to 0 then oversample_classes() always returns 1
undersampling_coef = 0.5  # if equal to 0 then undersampling_filter() always returns True

def oversample_classes(example):
    """
    Returns the number of copies of given example
    """
    class_prob = example['class_prob']
    class_target_prob = example['class_target_prob']
    prob_ratio = tf.cast(class_target_prob/class_prob, dtype=tf.float32)
    # soften the ratio; if oversampling_coef==0 we recover the original distribution
    prob_ratio = prob_ratio ** oversampling_coef
    # for classes with probability higher than class_target_prob we
    # want to return 1
    prob_ratio = tf.maximum(prob_ratio, 1)
    # for low probability classes this number will be very large
    repeat_count = tf.floor(prob_ratio)
    # prob_ratio can be e.g. 1.9, which means that there is still a 90%
    # chance that we should return 2 instead of 1
    repeat_residual = prob_ratio - repeat_count  # a number between 0 and 1
    residual_acceptance = tf.less_equal(
        tf.random_uniform([], dtype=tf.float32), repeat_residual
    )
    residual_acceptance = tf.cast(residual_acceptance, tf.int64)
    repeat_count = tf.cast(repeat_count, dtype=tf.int64)
    return repeat_count + residual_acceptance

def undersampling_filter(example):
    """
    Computes if given example is rejected or not.
    """
    class_prob = example['class_prob']
    class_target_prob = example['class_target_prob']
    prob_ratio = tf.cast(class_target_prob/class_prob, dtype=tf.float32)
    prob_ratio = prob_ratio ** undersampling_coef
    prob_ratio = tf.minimum(prob_ratio, 1.0)
    acceptance = tf.less_equal(tf.random_uniform([], dtype=tf.float32), prob_ratio)
    return acceptance

dataset = dataset.flat_map(
    lambda x: tf.data.Dataset.from_tensors(x).repeat(oversample_classes(x))
)
dataset = dataset.filter(undersampling_filter)

dataset = dataset.repeat(-1)
dataset = dataset.shuffle(2048)
dataset = dataset.batch(32)

sess.run(tf.global_variables_initializer())  # assumes an active TF 1.x Session `sess`

iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()
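Pulling batches then looks like this (a sketch, assuming the TF 1.x Session sess from above; the step count is arbitrary):

for step in range(100):
    batch = sess.run(next_element)
    # batch is a dict of arrays with 'class_prob', 'class_target_prob', and your features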
Here is a simple Jupyter notebook that implements the above oversampling/undersampling on a toy model.