 

Unexpected results with CuDNNLSTM (instead of LSTM) layer

I have posted this question as an issue in Keras' Github but figured it might reach a broader audience here.


System information

  • Have I written custom code (as opposed to using example directory): Minimal change to official Keras tutorial
  • OS Platform and Distribution (e.g., Linux Ubuntu 16.04): Ubuntu 18.04.2 LTS
  • TensorFlow backend (yes / no): yes
  • TensorFlow version: 1.13.1
  • Keras version: 2.2.4
  • Python version: 3.6.5
  • CUDA/cuDNN version: 10.1
  • GPU model and memory: Tesla K80 11G

Describe the current behavior
I am executing the code from the Seq2Seq tutorial. The one and only change I made was to swap the LSTM layers for CuDNNLSTM. What happens is that the model predicts a fixed output for any input I give it. When I run the original code, I get sensible results.

Describe the expected behavior
See preceding section.

Code to reproduce the issue
Taken from here. Simply replace LSTM with CuDNNLSTM.
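For reference, the core of the tutorial's model with the layer class made a parameter, so the two variants can be swapped in one place (a sketch; `build_model` and its arguments are illustrative names, not from the tutorial):

```python
def build_model(num_encoder_tokens, num_decoder_tokens,
                latent_dim=256, use_cudnn=True):
    # Imports kept inside the function: CuDNNLSTM needs a GPU-enabled
    # TensorFlow build to instantiate.
    from keras.models import Model
    from keras.layers import Input, Dense, LSTM, CuDNNLSTM

    RNN = CuDNNLSTM if use_cudnn else LSTM

    # Encoder: discard the output sequence, keep the final h/c states.
    encoder_inputs = Input(shape=(None, num_encoder_tokens))
    _, state_h, state_c = RNN(latent_dim, return_state=True)(encoder_inputs)

    # Decoder: initialised with the encoder states, returns full sequences.
    decoder_inputs = Input(shape=(None, num_decoder_tokens))
    decoder_outputs, _, _ = RNN(
        latent_dim, return_sequences=True, return_state=True)(
        decoder_inputs, initial_state=[state_h, state_c])
    decoder_outputs = Dense(num_decoder_tokens,
                            activation='softmax')(decoder_outputs)

    return Model([encoder_inputs, decoder_inputs], decoder_outputs)
```

Both layers accept the same `return_sequences` / `return_state` / `initial_state` arguments here, so the swap really is a one-line change.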


Any insights are greatly appreciated.

asked Jun 28 '19 by Orest Xherija


1 Answer

There are two issues here: the use of CuDNNLSTM, and parameter tuning.

Basically, the network overfits on your dataset, which is why it emits the same output sentence for every input. This is the fault of neither CuDNNLSTM nor LSTM.

Firstly, CuDNN uses slightly different maths from the regular LSTM to make it CUDA-compatible and faster: with the same code you used, on the eng-hindi file, LSTM takes 11 sec per epoch while CuDNNLSTM takes 1 sec.
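One concrete, known difference in Keras 2.2.4: the plain `LSTM` defaults to `recurrent_activation='hard_sigmoid'`, whereas `CuDNNLSTM` is hard-wired to the standard sigmoid that cuDNN implements. A sketch of a fairer comparison (the `matched_lstm` helper is illustrative, not part of Keras):

```python
def matched_lstm(latent_dim, **kwargs):
    # Match the plain LSTM's recurrent activation to CuDNNLSTM's
    # fixed sigmoid, removing one source of divergence between the
    # two layers. (Keras 2.2.4 LSTM defaults to 'hard_sigmoid'.)
    from keras.layers import LSTM
    return LSTM(latent_dim, recurrent_activation='sigmoid', **kwargs)
```

Even with matched activations the two will not be bit-identical, since cuDNN fuses its kernels, but the gap narrows.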

In CuDNNLSTM the time_major param is set to False, and for this reason the network overfits. You can check it here.

You can clearly see that for small datasets like eng-hin or eng-marathi, val_loss starts increasing after about 30 epochs. There is no point training further once your training loss is still decreasing while val_loss is increasing. The same happens with the plain LSTM.

For small datasets like these, you need parameter tuning.
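The simplest fix for "val_loss rises after epoch 30" is early stopping. A minimal, framework-free sketch of the logic (Keras ships the same idea as the `keras.callbacks.EarlyStopping` callback):

```python
def early_stop_epoch(val_losses, patience=3):
    """Return the epoch index at which training should stop, i.e. the
    epoch where val_loss has failed to improve for `patience`
    consecutive epochs; returns len(val_losses) if never triggered."""
    best = float("inf")
    wait = 0
    for epoch, loss in enumerate(val_losses):
        if loss < best:
            best = loss   # new best: reset the patience counter
            wait = 0
        else:
            wait += 1     # no improvement this epoch
            if wait >= patience:
                return epoch
    return len(val_losses)

# Example: val_loss bottoms out at epoch 2, then climbs (overfitting).
history = [1.9, 1.2, 0.9, 1.0, 1.1, 1.3, 1.6]
print(early_stop_epoch(history, patience=3))  # prints 5
```

With the Keras callback this is `EarlyStopping(monitor='val_loss', patience=3)` passed to `model.fit(..., callbacks=[...])`.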

Here are a few links which can help:

  1. Eng-Mar
  2. Pytorch translation tutorial
  3. Similar Question 1 and Similar Question 2
  4. NMT-keras
answered Oct 13 '22 by ASHu2