Knowledge Distillation loss with Tensorflow 2 + Keras

I am trying to implement a very simple Keras model that uses Knowledge Distillation [1] from another model. Roughly, I need to replace the original loss L(y_true, y_pred) with L(y_true, y_pred) + L(y_teacher_pred, y_pred), where y_teacher_pred is the prediction of another model.

I've tried the following:

def create_student_model_with_distillation(teacher_model):

  inp = tf.keras.layers.Input(shape=(21,))

  model = tf.keras.models.Sequential()
  model.add(inp)

  model.add(...) 
  model.add(tf.keras.layers.Dense(units=1))

  teacher_pred = teacher_model(inp)

  def my_loss(y_true,y_pred):
      loss = tf.keras.losses.mean_squared_error(y_true, y_pred)
      loss += tf.keras.losses.mean_squared_error(teacher_pred, y_pred)
      return loss

  model.compile(loss=my_loss, optimizer='adam')

  return model

However, when I try to call fit on my model, I am getting

TypeError: An op outside of the function building code is being passed
a "Graph" tensor. It is possible to have Graph tensors
leak out of the function building context by including a
tf.init_scope in your function building code.

How can I solve this issue ?

Refs

[1] Hinton, Vinyals, Dean, "Distilling the Knowledge in a Neural Network", https://arxiv.org/abs/1503.02531

asked Nov 07 '22 by ThR37

1 Answer

Actually, this blog post answers your question: the knowledge-distillation example on the Keras blog.

In short, you should use the TF2 API (a custom train_step on a tf.keras.Model subclass) and run the teacher's forward pass before the tf.GradientTape() block:

def train_step(self, data):
    # Unpack data
    x, y = data

    # Forward pass of teacher
    teacher_predictions = self.teacher(x, training=False)

    with tf.GradientTape() as tape:
        # Forward pass of student
        student_predictions = self.student(x, training=True)

        # Compute losses
        student_loss = self.student_loss_fn(y, student_predictions)
        distillation_loss = self.distillation_loss_fn(
            tf.nn.softmax(teacher_predictions / self.temperature, axis=1),
            tf.nn.softmax(student_predictions / self.temperature, axis=1),
        )
        loss = self.alpha * student_loss + (1 - self.alpha) * distillation_loss

    # Compute gradients and update only the student's weights
    trainable_vars = self.student.trainable_variables
    gradients = tape.gradient(loss, trainable_vars)
    self.optimizer.apply_gradients(zip(gradients, trainable_vars))

    # Update the metrics configured in `compile()`
    self.compiled_metrics.update_state(y, student_predictions)
    return {m.name: m.result() for m in self.metrics}
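
For completeness, here is a minimal sketch (not part of the original answer) of the surrounding Distiller model that a train_step like the one above assumes, adapted to the regression setup in the question: both loss functions are plain MSE and the softmax/temperature scaling is dropped, since the student has a single linear output. The class and attribute names (Distiller, student_loss_fn, distillation_loss_fn, alpha) follow the Keras blog example; the values in the usage comments at the bottom are illustrative.

import tensorflow as tf

class Distiller(tf.keras.Model):
    # Wraps a trainable student and a frozen teacher; only the student's
    # weights are updated during training.
    def __init__(self, student, teacher):
        super().__init__()
        self.student = student
        self.teacher = teacher

    def compile(self, optimizer, metrics, student_loss_fn,
                distillation_loss_fn, alpha=0.5):
        super().compile(optimizer=optimizer, metrics=metrics)
        self.student_loss_fn = student_loss_fn
        self.distillation_loss_fn = distillation_loss_fn
        self.alpha = alpha

    def train_step(self, data):
        x, y = data

        # The teacher runs on the concrete batch inside train_step, not on a
        # symbolic Keras Input, which avoids the "Graph tensor" error.
        teacher_predictions = self.teacher(x, training=False)

        with tf.GradientTape() as tape:
            student_predictions = self.student(x, training=True)
            student_loss = self.student_loss_fn(y, student_predictions)
            # Regression: distill directly on the raw outputs, no softmax needed.
            distillation_loss = self.distillation_loss_fn(
                teacher_predictions, student_predictions)
            loss = self.alpha * student_loss + (1 - self.alpha) * distillation_loss

        gradients = tape.gradient(loss, self.student.trainable_variables)
        self.optimizer.apply_gradients(
            zip(gradients, self.student.trainable_variables))
        self.compiled_metrics.update_state(y, student_predictions)
        return {m.name: m.result() for m in self.metrics}

    def call(self, x):
        # Inference uses the student only.
        return self.student(x)

# Illustrative usage (x_train, y_train and teacher_model are placeholders):
# student = tf.keras.Sequential([
#     tf.keras.layers.Input(shape=(21,)),
#     tf.keras.layers.Dense(64, activation="relu"),
#     tf.keras.layers.Dense(units=1),
# ])
# distiller = Distiller(student=student, teacher=teacher_model)
# distiller.compile(
#     optimizer="adam",
#     metrics=["mae"],
#     student_loss_fn=tf.keras.losses.MeanSquaredError(),
#     distillation_loss_fn=tf.keras.losses.MeanSquaredError(),
#     alpha=0.5,
# )
# distiller.fit(x_train, y_train, epochs=10)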
answered Nov 14 '22 by Alex Glinsky