
GradientTape convergence much slower than keras.Model.fit

I am currently trying to get the hang of the TF 2.0 API, but when I compared a GradientTape training loop to a regular keras.Model.fit, I noticed that:

  1. It ran slower (probably due to eager execution).

  2. It converged much more slowly (and I am not sure why).

Training loss after each epoch:

+--------+--------------+--------------+------------------+
|  Epoch | GradientTape | GradientTape | keras.Model.fit  |
|        |              | + shuffling  |                  |
+--------+--------------+--------------+------------------+
|    1   |     0.905    |     0.918    |      0.8793      |
+--------+--------------+--------------+------------------+
|    2   |     0.352    |     0.634    |      0.2226      |
+--------+--------------+--------------+------------------+
|    3   |     0.285    |     0.518    |      0.1192      |
+--------+--------------+--------------+------------------+
|    4   |     0.282    |     0.458    |      0.1029      |
+--------+--------------+--------------+------------------+
|    5   |     0.275    |     0.421    |      0.0940      |
+--------+--------------+--------------+------------------+

Here is the training loop I used with the GradientTape:


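# Imports implied by the names used in the snippets below
import numpy as np
import tensorflow as tf
from tensorflow import data, keras
from tensorflow.keras import layers
from tqdm import tqdm

optimizer = keras.optimizers.Adam()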
glove_model = GloveModel(vocab_size=len(labels))
train_loss = keras.metrics.Mean(name='train_loss')

@tf.function
def train_step(examples, labels):
    with tf.GradientTape() as tape:
        predictions = glove_model(examples)
        loss = glove_model.glove_loss(labels, predictions)

    gradients = tape.gradient(loss, glove_model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, glove_model.trainable_variables))

    train_loss(loss)



for epoch in range(epochs_number):

    # drop_remainder=True, so each epoch has len(index_data) // batch_size full batches
    pbar = tqdm(train_ds.enumerate(), total=len(index_data) // batch_size)

    for ix, (examples, labels) in pbar:
        train_step(examples, labels)

    print(f"Epoch {epoch + 1}, Loss {train_loss.result()}")

    # Reset the metrics for the next epoch
    train_loss.reset_states()

And here is the keras.Model.fit training:

glove_model.compile(optimizer, glove_model.glove_loss)
glove_model.fit(train_ds, epochs=epochs_number)

Here is the tf.data.Dataset source:

train_ds = data.Dataset.from_tensor_slices(
    (np.hstack([index_rows.reshape(-1, 1), index_cols.reshape(-1, 1)]), index_data)
).shuffle(100000).batch(batch_size, drop_remainder=True)

And here is the model:

class GloveModel(keras.Model):

    def __init__(self, vocab_size, dim=100, a=3/4, x_max=100):
        super(GloveModel, self).__init__()

        self.vocab_size = vocab_size
        self.dim = dim
        self.a = a
        self.x_max = x_max

        self.target_embedding = layers.Embedding(
            input_dim=self.vocab_size, output_dim=self.dim, input_length=1, name="target_embedding"
        )
        self.target_bias = layers.Embedding(
            input_dim=self.vocab_size, output_dim=1, input_length=1, name="target_bias"
        )

        self.context_embedding = layers.Embedding(
            input_dim=self.vocab_size, output_dim=self.dim, input_length=1, name="context_embedding"
        )
        self.context_bias = layers.Embedding(
            input_dim=self.vocab_size, output_dim=1, input_length=1, name="context_bias"
        )

        self.dot_product = layers.Dot(axes=-1, name="dot")

        self.prediction = layers.Add(name="add")
        self.step = 0

    def call(self, inputs):

        target_ix = inputs[:, 0]
        context_ix = inputs[:, 1]

        target_embedding = self.target_embedding(target_ix)
        target_bias = self.target_bias(target_ix)

        context_embedding = self.context_embedding(context_ix)
        context_bias = self.context_bias(context_ix)

        dot_product = self.dot_product([target_embedding, context_embedding])
        prediction = self.prediction([dot_product, target_bias, context_bias])

        return prediction

    def glove_loss(self, y_true, y_pred):

        weight = tf.math.minimum(
            tf.math.pow(y_true/self.x_max, self.a), 1.0
        )
        loss_value = tf.math.reduce_mean(weight * tf.math.pow(y_pred - tf.math.log(y_true), 2.0))

        return loss_value
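
For reference, the loss that glove_loss appears intended to compute over a batch can be written as below. The notation is an assumption added for illustration: B is the batch size, x_k is the co-occurrence count passed as y_true, w and \tilde{w} are the target and context embeddings, b and \tilde{b} the corresponding biases, and a = 3/4, x_max = 100 are the constructor defaults.

```latex
\[
J = \frac{1}{B} \sum_{k=1}^{B} f(x_k)\,\bigl(\hat{y}_k - \log x_k\bigr)^2,
\qquad
f(x) = \min\!\left(\left(\frac{x}{x_{\max}}\right)^{a},\, 1\right),
\qquad
\hat{y}_k = w_{i_k}^{\top} \tilde{w}_{j_k} + b_{i_k} + \tilde{b}_{j_k}
\]
```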



I tried multiple configurations and optimizers but nothing seems to change the convergence rate.
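
One thing that may be worth checking here, separately from shuffling, is the shapes that reach glove_loss. From the model code, call() appears to return predictions of shape (batch, 1), while the labels taken from index_data would have shape (batch,) if index_data is one-dimensional (an assumption). Subtracting those two broadcasts to a (batch, batch) matrix, so the mean would run over every pair rather than over matching rows. The NumPy sketch below only illustrates that broadcasting rule with made-up values; it does not use the actual model.

```python
import numpy as np

batch_size = 4

# Stand-ins with the shapes assumed above (the values are made up).
y_pred = np.zeros((batch_size, 1))           # like the model output, shape (4, 1)
y_true = np.arange(1.0, batch_size + 1.0)    # like the labels, shape (4,)

# (4, 1) minus (4,) broadcasts to a full (4, 4) matrix.
diff = y_pred - np.log(y_true)
print(diff.shape)

# Dropping the trailing axis first keeps the difference elementwise.
diff_squeezed = np.squeeze(y_pred, axis=-1) - np.log(y_true)
print(diff_squeezed.shape)
```

If this applies, printing predictions.shape and labels.shape inside train_step would confirm it. I believe a loss passed to compile() gets wrapped so that a trailing dimension of size 1 can be squeezed before the loss function runs, which a hand-written loop does not do, but that is an assumption to verify against the installed TF version.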

asked Oct 27 '19 23:10 by Benjamin Breton



1 Answer

Dataset.shuffle() only shuffles each minibatch, so every epoch sees the same order. Keras .fit() uses some magic to shuffle the whole dataset before each epoch. To do this in TF, you need to use the Dataset methods .repeat(epochs_number) and .shuffle(..., reshuffle_each_iteration=True):

train_ds = data.Dataset.from_tensor_slices(
    (np.hstack([index_rows.reshape(-1, 1), index_cols.reshape(-1, 1)]), index_data)
    ).shuffle(100000, reshuffle_each_iteration=True
    ).batch(batch_size, drop_remainder=True
    ).repeat(epochs_number)

for ix, (examples, labels) in train_ds.enumerate():
    train_step(examples, labels)
    current_epoch = ix // (len(index_data) // batch_size)

This workaround is neither elegant nor natural, but for the moment you can use it to reshuffle on every epoch. It's a known issue and is expected to be fixed; in the future you should be able to use for epoch in range(epochs_number) instead of .repeat().
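
With the repeated dataset there is no outer epoch loop, so the per-epoch loss report from the question has to be triggered from the step index. Here is a minimal sketch of that bookkeeping, assuming batch(..., drop_remainder=True) is applied before repeat() so that every epoch has exactly len(index_data) // batch_size steps. The helper name epoch_schedule and the sizes in the example are hypothetical; the comments mark where train_step and the metric reset would go.

```python
def epoch_schedule(num_examples, batch_size, epochs_number):
    """Yield (step, epoch, is_last_step_of_epoch) for a batched-then-repeated dataset."""
    # With drop_remainder=True the partial final batch is discarded.
    steps_per_epoch = num_examples // batch_size
    for ix in range(steps_per_epoch * epochs_number):
        current_epoch = ix // steps_per_epoch
        is_last_step = (ix + 1) % steps_per_epoch == 0
        yield ix, current_epoch, is_last_step


# Hypothetical sizes, only to show where the epoch boundaries fall.
for ix, current_epoch, is_last_step in epoch_schedule(10, 3, 2):
    # train_step(examples, labels) would run here.
    if is_last_step:
        # Report train_loss.result() and call train_loss.reset_states() here.
        print(f"end of epoch {current_epoch + 1} at step {ix}")
```

In the real loop the same two expressions can be computed directly from ix inside the for loop over train_ds.enumerate(), so the helper is only there to make the arithmetic easy to check.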

answered Nov 15 '22 08:11 by THN
