I am trying to implement a very simple Keras model that learns from another model via knowledge distillation [1].
Roughly, I need to replace the original loss L(y_true, y_pred) with L(y_true, y_pred) + L(y_teacher_pred, y_pred), where y_teacher_pred is the prediction of another (teacher) model.
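For concreteness, here is a minimal NumPy sketch of that combined objective. It assumes mean squared error for both terms (as in the code below) and a single regression output; the helper names `mse` and `distillation_objective` are hypothetical.

```python
import numpy as np


def mse(a: np.ndarray, b: np.ndarray) -> float:
    """Mean squared error averaged over all elements."""
    return float(np.mean((a - b) ** 2))


def distillation_objective(y_true, y_teacher_pred, y_pred) -> float:
    """L(y_true, y_pred) + L(y_teacher_pred, y_pred), with L = MSE."""
    hard_loss = mse(y_true, y_pred)          # fit the ground-truth labels
    soft_loss = mse(y_teacher_pred, y_pred)  # match the teacher's outputs
    return hard_loss + soft_loss


# Toy values, purely for illustration (one output per sample).
y_true = np.array([[1.0], [2.0], [3.0]])
y_teacher = np.array([[1.5], [2.0], [2.5]])
y_student = np.array([[1.0], [2.5], [3.0]])
print(distillation_objective(y_true, y_teacher, y_student))
```

The student is pulled toward both the labels and the teacher; it only reaches zero loss when the labels and the teacher agree with it.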
I've tried the following:
```python
import tensorflow as tf

def create_student_model_with_distillation(teacher_model):
    inp = tf.keras.layers.Input(shape=(21,))
    model = tf.keras.models.Sequential()
    model.add(inp)
    model.add(...)  # other layers omitted
    model.add(tf.keras.layers.Dense(units=1))
    # Symbolic tensor built from `inp`, captured by my_loss below
    teacher_pred = teacher_model(inp)

    def my_loss(y_true, y_pred):
        loss = tf.keras.losses.mean_squared_error(y_true, y_pred)
        loss += tf.keras.losses.mean_squared_error(teacher_pred, y_pred)
        return loss

    model.compile(loss=my_loss, optimizer='adam')
    return model
```
However, when I call `fit` on my model, I get:

```
TypeError: An op outside of the function building code is being passed
a "Graph" tensor. It is possible to have Graph tensors
leak out of the function building context by including a
tf.init_scope in your function building code.
```
How can I solve this issue?
Refs
[1] https://arxiv.org/abs/1503.02531
This blog post answers your question: keras blog
In short: use the TF2 API, i.e. override `train_step` in a custom `tf.keras.Model` that holds both the teacher and the student, and run the teacher's forward pass before the `tf.GradientTape()` block:
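```python
def train_step(self, data):
    # Unpack data
    x, y = data

    # Forward pass of teacher (outside the tape: no gradients needed)
    teacher_predictions = self.teacher(x, training=False)

    with tf.GradientTape() as tape:
        # Forward pass of student
        student_predictions = self.student(x, training=True)

        # Compute losses
        student_loss = self.student_loss_fn(y, student_predictions)
        distillation_loss = self.distillation_loss_fn(
            tf.nn.softmax(teacher_predictions / self.temperature, axis=1),
            tf.nn.softmax(student_predictions / self.temperature, axis=1),
        )
        loss = self.alpha * student_loss + (1 - self.alpha) * distillation_loss

    # Update only the student's weights
    trainable_vars = self.student.trainable_variables
    gradients = tape.gradient(loss, trainable_vars)
    self.optimizer.apply_gradients(zip(gradients, trainable_vars))
    return {"student_loss": student_loss, "distillation_loss": distillation_loss}
```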
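The snippet above comes from a classification setting: it softens logits with `tf.nn.softmax(... / self.temperature, axis=1)`. With the question's single-unit `Dense(units=1)` regression head, a softmax over one unit is always 1, so that distillation term would carry no signal. Below is a minimal sketch adapting the same `train_step` pattern to the question's setup. It assumes TF 2.2 or later (where `train_step` can be overridden), mean squared error for both terms, and a hypothetical `Distiller` wrapper, `make_mlp` helper and toy data. With `alpha=0.5` the loss is half of the question's L(y_true, y_pred) + L(y_teacher_pred, y_pred), so it has the same minimizer.

```python
import numpy as np
import tensorflow as tf


class Distiller(tf.keras.Model):
    """Hypothetical wrapper: trains `student`, keeps `teacher` frozen."""

    def __init__(self, student, teacher, alpha=0.5):
        super().__init__()
        self.student = student
        self.teacher = teacher
        self.teacher.trainable = False  # only the student is updated
        self.alpha = alpha              # weight of the ground-truth term
        self.mse = tf.keras.losses.MeanSquaredError()

    def call(self, x, training=False):
        # Inference uses the student only.
        return self.student(x, training=training)

    def train_step(self, data):
        x, y = data
        # Teacher forward pass happens inside the training step, so no
        # graph tensor leaks in from model-building code.
        teacher_pred = self.teacher(x, training=False)
        with tf.GradientTape() as tape:
            student_pred = self.student(x, training=True)
            student_loss = self.mse(y, student_pred)
            distillation_loss = self.mse(teacher_pred, student_pred)
            loss = self.alpha * student_loss + (1 - self.alpha) * distillation_loss
        variables = self.student.trainable_variables
        grads = tape.gradient(loss, variables)
        self.optimizer.apply_gradients(zip(grads, variables))
        return {"loss": loss,
                "student_loss": student_loss,
                "distillation_loss": distillation_loss}


def make_mlp(hidden_units):
    # Hypothetical architecture: 21 input features, one regression output.
    return tf.keras.Sequential([
        tf.keras.Input(shape=(21,)),
        tf.keras.layers.Dense(hidden_units, activation="relu"),
        tf.keras.layers.Dense(1),
    ])


# Toy regression data, purely for illustration.
rng = np.random.default_rng(0)
x = rng.normal(size=(256, 21)).astype("float32")
y = x.sum(axis=1, keepdims=True)

teacher = make_mlp(64)
teacher.compile(optimizer="adam", loss="mse")
teacher.fit(x, y, epochs=5, verbose=0)

distiller = Distiller(student=make_mlp(8), teacher=teacher, alpha=0.5)
distiller.compile(optimizer="adam")  # loss is computed in train_step
history = distiller.fit(x, y, epochs=5, verbose=0)
```

After training, `distiller.student` (or `distiller.predict`) gives the student model. The teacher's weights stay unchanged because only `self.student.trainable_variables` are passed to `tape.gradient` and the optimizer.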