
TensorFlow cross entropy loss for sequences with different lengths

Tags:

tensorflow

I'm building a seq2seq model with LSTMs using TensorFlow. The loss function I'm using is the softmax cross entropy loss. The problem is that my input sequences have different lengths, so I padded them. The output of the model has the shape [max_length, batch_size, vocab_size]. How can I calculate the loss so that the 0-padded values don't affect it? tf.nn.softmax_cross_entropy_with_logits provides an axis parameter, so the loss can be computed on 3-dimensional input, but it doesn't accept weights. tf.losses.softmax_cross_entropy provides a weights parameter, but it expects input of shape [batch_size, nclass(vocab_size)]. Please help!
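For concreteness, a minimal sketch of the shapes involved (the names and sizes here are only illustrative, not from the actual model):

import tensorflow as tf

max_length, batch_size, vocab_size = 50, 32, 10000

# decoder output, time-major as described above
logits = tf.placeholder(tf.float32, [max_length, batch_size, vocab_size])
# target token ids, 0-padded up to max_length
targets = tf.placeholder(tf.int32, [max_length, batch_size])
# true (unpadded) length of each sequence in the batch
lengths = tf.placeholder(tf.int32, [batch_size])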

asked Jan 28 '26 by Khoa Ngo


2 Answers

I think you'd have to write your own loss function. Check out https://danijar.com/variable-sequence-lengths-in-tensorflow/.
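Along the lines of that article, here is a minimal sketch of such a loss: per-time-step cross entropy masked with the true sequence lengths. The function name and the combination of tf.nn.sparse_softmax_cross_entropy_with_logits with tf.sequence_mask are illustrative assumptions, not code taken from the article:

import tensorflow as tf

def masked_sequence_loss(logits, targets, lengths):
    # logits:  [batch_size, max_length, vocab_size]
    # targets: [batch_size, max_length] integer token ids (0 = padding)
    # lengths: [batch_size] true sequence lengths before padding
    xentropy = tf.nn.sparse_softmax_cross_entropy_with_logits(
        labels=targets, logits=logits)  # [batch_size, max_length]
    # 1.0 at real time steps, 0.0 at padded ones
    mask = tf.sequence_mask(lengths, maxlen=tf.shape(targets)[1], dtype=tf.float32)
    xentropy *= mask
    # average over the real steps of each sequence, then over the batch
    per_sequence = tf.reduce_sum(xentropy, axis=1) / tf.cast(lengths, tf.float32)
    return tf.reduce_mean(per_sequence)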

answered Jan 31 '26 by Abhai Kollara


In this case you need to pad the logits and labels so that they have the same length along the time dimension. Suppose you have a logits tensor of shape (batch_size, length, vocab_size) and a labels tensor of shape (batch_size, length), where length is the length of your sequences. First, pad them to the same length:

import tensorflow as tf

def _pad_tensors_to_same_length(logits, labels):
    """Pad logits and labels so that they have the same length (second dimension)."""
    with tf.name_scope("pad_to_same_length"):
        logits_length = tf.shape(logits)[1]
        labels_length = tf.shape(labels)[1]

        max_length = tf.maximum(logits_length, labels_length)

        # Pad the time dimension (dim 1) of each tensor up to max_length with zeros.
        logits = tf.pad(logits, [[0, 0], [0, max_length - logits_length], [0, 0]])
        labels = tf.pad(labels, [[0, 0], [0, max_length - labels_length]])
        return logits, labels
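Note that padding the labels on the right with zeros keeps them consistent with the 0-padded inputs: the extra positions carry the padding id 0, which is exactly what the weight mask in the loss below assigns zero weight.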

Then you can compute the padded cross entropy:

def padded_cross_entropy_loss(logits, labels, vocab_size):
  """Calculate cross entropy loss while ignoring padding.

  Args:
    logits: Tensor of size [batch_size, length_logits, vocab_size]
    labels: Tensor of size [batch_size, length_labels]
    vocab_size: int size of the vocabulary
  Returns:
    The per-token cross entropy loss, with padded positions zeroed out
  """
  with tf.name_scope("loss", values=[logits, labels]):
    logits, labels = _pad_tensors_to_same_length(logits, labels)

    # Calculate cross entropy
    with tf.name_scope("cross_entropy", values=[logits, labels]):
      # softmax_cross_entropy_with_logits_v2 expects a distribution over the
      # vocabulary at every position, so turn the integer labels into one-hot vectors.
      targets = tf.one_hot(tf.cast(labels, tf.int32), depth=vocab_size)
      xentropy = tf.nn.softmax_cross_entropy_with_logits_v2(
          logits=logits, labels=targets)

    # Padded positions have label id 0, so give them zero weight.
    weights = tf.to_float(tf.not_equal(labels, 0))
    return xentropy * weights
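A short usage sketch (assuming, as above, that padding positions use token id 0; normalizing by the number of real tokens is only one reasonable choice):

per_token_loss = padded_cross_entropy_loss(logits, labels, vocab_size)
# count only the non-padded target tokens when averaging
weights = tf.to_float(tf.not_equal(labels, 0))
loss = tf.reduce_sum(per_token_loss) / tf.reduce_sum(weights)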
answered Jan 31 '26 by Rohola Zandie