How to calculate the perplexity of an RNN in TensorFlow

I'm running the word-RNN implementation of TensorFlow (Word RNN).

How do I calculate the perplexity of the RNN?

Following is the training code that prints the training loss and other things in each epoch:

for e in range(model.epoch_pointer.eval(), args.num_epochs):
    sess.run(tf.assign(model.lr, args.learning_rate * (args.decay_rate ** e)))
    data_loader.reset_batch_pointer()
    state = sess.run(model.initial_state)
    speed = 0
    if args.init_from is None:
        assign_op = model.batch_pointer.assign(0)
        sess.run(assign_op)
        assign_op = model.epoch_pointer.assign(e)
        sess.run(assign_op)
    if args.init_from is not None:
        data_loader.pointer = model.batch_pointer.eval()
        args.init_from = None
    for b in range(data_loader.pointer, data_loader.num_batches):
        start = time.time()
        x, y = data_loader.next_batch()
        feed = {model.input_data: x, model.targets: y, model.initial_state: state,
                model.batch_time: speed}
        summary, train_loss, state, _, _ = sess.run([merged, model.cost, model.final_state,
                                                     model.train_op, model.inc_batch_pointer_op], feed)
        train_writer.add_summary(summary, e * data_loader.num_batches + b)
        speed = time.time() - start
        if (e * data_loader.num_batches + b) % args.batch_size == 0:
            print("{}/{} (epoch {}), train_loss = {:.3f}, time/batch = {:.3f}"
                  .format(e * data_loader.num_batches + b,
                          args.num_epochs * data_loader.num_batches,
                          e, train_loss, speed))
        if (e * data_loader.num_batches + b) % args.save_every == 0 \
                or (e == args.num_epochs - 1 and b == data_loader.num_batches - 1):  # save for the last result
            checkpoint_path = os.path.join(args.save_dir, 'model.ckpt')
            saver.save(sess, checkpoint_path, global_step=e * data_loader.num_batches + b)
            print("model saved to {}".format(checkpoint_path))
train_writer.close()
asked Jan 26 '17 by Shan Khan


1 Answer

The project you are referencing uses sequence_loss_by_example, which returns the cross-entropy loss. So to calculate the training perplexity, you just need to exponentiate the loss, as explained here.

train_perplexity = tf.exp(train_loss)

We have to use e instead of 2 as the base, because TensorFlow measures the cross-entropy loss with the natural logarithm (see the TF documentation). Thanks to @Matthias Arro and @Colin Skow for the hint.
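
If you only want to log the number during training, you don't even have to add a new op to the graph; you can exponentiate the fetched loss value in Python. A minimal sketch based on the question's training loop (using numpy's exp on the returned float is my choice here, not part of the original code):

import numpy as np

# train_loss is the mean cross-entropy (in nats) per target token for this batch,
# so its exponential is the perplexity on this batch.
summary, train_loss, state, _, _ = sess.run([merged, model.cost, model.final_state,
                                             model.train_op, model.inc_batch_pointer_op], feed)
batch_perplexity = np.exp(train_loss)
print("train_loss = {:.3f}, perplexity = {:.2f}".format(train_loss, batch_perplexity))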

Detailed Explanation

The cross-entropy of two probability distributions P and Q tells us the minimum average number of bits we need to encode events of P, when we develop a coding scheme based on Q. So, P is the true distribution, which we usually don't know. We want to find a Q as close to P as possible, so that we can develop a nice coding scheme with as few bits per event as possible.

I shouldn't say bits, because we can only use bits as a measure if we use base 2 in the calculation of the cross-entropy. But TensorFlow uses the natural logarithm, so instead let's measure the cross-entropy in nats.
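
For concreteness, converting between the two units is just a change of logarithm base; a short illustration (the value 6.9 is only an example):

import math

cross_entropy_nats = 6.9                                # loss as TensorFlow reports it (natural log)
cross_entropy_bits = cross_entropy_nats / math.log(2)   # ~ 9.96 bits
print(cross_entropy_bits)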

So let's say we have a bad language model that says every token (character / word) in the vocabulary is equally probable to be the next one. For a vocabulary of 1000 tokens, this model will have a cross-entropy of log(1000) = 6.9 nats. When predicting the next token, it has to choose uniformly between 1000 tokens at each step.
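
To check this in code: a uniform model assigns probability 1/1000 to the correct next token, so its cross-entropy is -log(1/1000) = log(1000), and its perplexity comes out as the vocabulary size itself:

import math

vocab_size = 1000
cross_entropy = math.log(vocab_size)   # ~ 6.91 nats
perplexity = math.exp(cross_entropy)   # ~ 1000, i.e. the vocabulary size
print(cross_entropy, perplexity)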

A better language model will determine a probability distribution Q that is closer to P. Thus, the cross-entropy is lower - we might get a cross-entropy of 3.9 nats. If we now want to measure the perplexity, we simply exponentiate the cross-entropy:

exp(3.9) = 49.4

So, on the samples for which we calculated the loss, the good model was as perplexed as if it had to choose uniformly and independently among roughly 50 tokens.
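
If you want a single perplexity figure for a whole data set rather than for one batch, average the per-token cross-entropy over all batches first and exponentiate only once at the end (the exponential of the mean, not the mean of per-batch perplexities). A hedged sketch of such an evaluation pass, reusing the names from the question's code; the loop itself is my assumption, not taken from the project:

import numpy as np

losses = []
state = sess.run(model.initial_state)
data_loader.reset_batch_pointer()
for b in range(data_loader.num_batches):
    x, y = data_loader.next_batch()
    feed = {model.input_data: x, model.targets: y, model.initial_state: state}
    # model.cost is the mean cross-entropy per target token for this batch
    batch_loss, state = sess.run([model.cost, model.final_state], feed)
    losses.append(batch_loss)

dataset_perplexity = np.exp(np.mean(losses))
print("perplexity = {:.2f}".format(dataset_perplexity))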

answered Sep 30 '22 by Kilian Batzner