Deep multi-task learning with missing labels

I have a deep multi-task network that is designed to work on three separate classification tasks. Every training example has labels for two of the classification tasks, but only about 10% to 15% of the training examples have a label for the third task. The network has multiple shared layers followed by a separate head for each task, composed of one or more fully connected layers and a softmax/sigmoid output layer.
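For concreteness, here is a minimal sketch of that layout in TF 1.x style (all layer sizes and identifiers here are hypothetical; the question does not give the real dimensions):

    import tensorflow as tf

    # hypothetical input size; the real dimensions are not given in the question
    inputs = tf.placeholder(tf.float32, [None, 128], name='inputs')

    # shared layers
    shared = tf.layers.dense(inputs, 64, activation=tf.nn.relu, name='shared_1')
    shared = tf.layers.dense(shared, 64, activation=tf.nn.relu, name='shared_2')

    # one head per task: fully connected layer(s) plus an output layer
    task1_logits = tf.layers.dense(tf.layers.dense(shared, 32, tf.nn.relu), 5)  # softmax head
    task2_logits = tf.layers.dense(tf.layers.dense(shared, 32, tf.nn.relu), 1)  # sigmoid head
    task3_logits = tf.layers.dense(tf.layers.dense(shared, 32, tf.nn.relu), 7)  # sparsely labelled task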

To handle the missing labels for the third task, I use tf.boolean_mask to drop the examples in each batch that do not have a label. This works great except on the rare occasion when a batch contains no labelled examples at all, i.e. there are no labels for task 3 in the entire batch. In that case the boolean mask (correctly) returns an empty tensor, the per-example cross-entropy is empty as well, and tf.reduce_mean over an empty tensor evaluates to nan, which causes an error during training.
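The failure mode is easy to reproduce in isolation (a minimal TF 1.x sketch with made-up shapes, not the actual model): an all-False mask yields a zero-row tensor, the per-example cross-entropy is then empty, and tf.reduce_mean of an empty tensor is nan.

    import tensorflow as tf

    logits = tf.constant([[1.0, 2.0], [3.0, 4.0]])
    labels = tf.constant([[1.0, 0.0], [0.0, 1.0]])
    mask = tf.constant([False, False])              # no labelled examples in this batch

    masked_logits = tf.boolean_mask(logits, mask)   # shape (0, 2)
    masked_labels = tf.boolean_mask(labels, mask)   # shape (0, 2)
    ce = tf.nn.softmax_cross_entropy_with_logits_v2(labels=masked_labels,
                                                    logits=masked_logits)  # shape (0,)
    loss = tf.reduce_mean(ce)                       # mean of an empty tensor -> nan

    with tf.Session() as sess:
        print(sess.run(loss))                       # prints nan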

My current solution is simply to check whether a batch has any labels for the third task and, if it does not, to skip that batch during training. While this avoids the error, I was wondering whether I can edit the computational graph to handle this relatively rare occurrence so that I do not have to skip batches.
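For reference, that batch-skipping workaround amounts to a host-side check before each training step (the training-loop variables below are hypothetical, and it assumes natural_mask and the task-3 labels are fed as placeholders):

    # hypothetical training loop; skip any batch with no task-3 labels
    for batch_x, batch_natural_labels, batch_natural_mask in batches:
        if not batch_natural_mask.any():    # no examples labelled for task 3
            continue                        # skip the whole batch
        sess.run(train_op, feed_dict={inputs: batch_x,
                                      binarized_natural_category: batch_natural_labels,
                                      natural_mask: batch_natural_mask})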

Here is a snippet of the code for the third task's output layer and the total loss function. There are fully connected layers for this task, and multiple shared layers, before this output layer.

    # softmax output layer for natural categories
    with tf.variable_scope('Natural_Category_Output'):
        W = tf.get_variable('W', shape = [natural_layer_size, no_natural_categories], 
                            initializer = tf.glorot_uniform_initializer())
        b = tf.get_variable('b', shape = [no_natural_categories],
                            initializer = tf.glorot_uniform_initializer())

        natural_logits = tf.add(tf.matmul(natural_output, W), b, name = 'logits')
        masked_logits = tf.boolean_mask(natural_logits, natural_mask, axis = 0, name = 'masked_logits')

        natural_probabilities = tf.nn.softmax(natural_logits, name = 'probabilities')
        natural_predictions = tf.argmax(natural_logits, axis = 1, name = 'predictions')
        masked_predictions = tf.boolean_mask(natural_predictions, natural_mask, axis = 0, name = 'masked_predictions')

    # loss for the natural categories
    with tf.variable_scope('Natural_Category_Loss_Function'):
        masked_natural_category = tf.boolean_mask(binarized_natural_category, natural_mask, axis = 0, name = 'masked_natural_categories')
        natural_loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(labels = masked_natural_category, 
                                                                                 logits = masked_logits), name = 'cross_entropy_loss')

        if uncertainty_weighting:
            # initialize weight variables
            natural_weight = tf.get_variable('natural_weight', shape = [], initializer = tf.constant_initializer(1.0))

            # augment the loss function for the task
            natural_loss = tf.add(tf.divide(natural_loss, tf.multiply(tf.constant(2.0), tf.square(natural_weight))),
                                  tf.log(tf.square(natural_weight)), name = 'weighted_loss')

    # total loss function 
    with tf.variable_scope('Total_Loss'):
        loss = fs_loss + expense_loss + natural_loss
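As an aside, the uncertainty_weighting branch above follows the familiar learned-task-weighting pattern (homoscedastic-uncertainty weighting): with natural_weight playing the role of a per-task scale σ, the task loss L is augmented as

    L' = L / (2 * σ^2) + log(σ^2)

so the optimizer can learn to down-weight the sparsely labelled task, while the log term keeps σ from growing without bound.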

Does anyone know of a way to change the graph to handle batches with no labels?

asked Mar 11 '18 by phil510
1 Answer

Basically, you're doing it right. Another way of doing this is to use tf.gather before calculating the loss. Let's assume that samples without a label have -1 as their label:

    # indices of the samples that actually have a label (label > -1)
    valid_idxs = tf.where(your_label > -1)[:, 0]
    # keep only the labelled rows
    valid_logits = tf.gather(your_logits, valid_idxs)
    valid_labels = tf.gather(your_label, valid_idxs)
    # per-example cross-entropy over the labelled rows only
    loss = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=valid_labels, logits=valid_logits)
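One caveat, which applies to the boolean_mask version as well: tf.nn.sparse_softmax_cross_entropy_with_logits returns per-example losses, and if a batch contains no valid labels the gathered tensors are empty, so a plain tf.reduce_mean would again give nan. A nan-safe reduction inside the graph (a sketch reusing the names above) is:

    # loss above holds per-example values; it is empty when no labels are valid.
    # Dividing by max(count, 1) makes an empty batch contribute 0 instead of nan.
    n_valid = tf.cast(tf.size(valid_idxs), tf.float32)
    safe_loss = tf.reduce_sum(loss) / tf.maximum(n_valid, 1.0)

This keeps every batch usable and directly addresses the empty-batch case from the question.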
answered Nov 15 '22 by Minz