Keras LSTM for Text Generation keeps repeating a line or a sequence

I roughly followed this tutorial:

https://machinelearningmastery.com/text-generation-lstm-recurrent-neural-networks-python-keras/

A notable difference is that I use two LSTM layers with dropout, and my dataset is different (a music dataset in ABC notation). I do get some songs generated, but after a certain number of steps (anywhere from 30 to a few hundred) in the generation process, the LSTM keeps generating the exact same sequence over and over. For example, it once got stuck generating URLs for songs:

F: http://www.youtube.com/watch?v=JPtqU6pipQI

and so on ...

It also once got stuck generating the same two songs (together a sequence of about 300 characters). In the beginning it generated 3-4 good pieces, but afterwards it kept regenerating the two songs almost indefinitely.

Does anyone have some insight into what could be happening?

I want to clarify that any generated sequence, whether repeating or not, appears to be new (the model is not memorising). The validation loss and training loss decrease as expected. Andrej Karpathy is able to generate documents thousands of characters long, and I couldn't find this pattern of getting stuck indefinitely in his results:

http://karpathy.github.io/2015/05/21/rnn-effectiveness/

asked Nov 05 '17 by oneThousandHertz

2 Answers

Instead of taking the argmax of the prediction output, try introducing some randomness by changing something like this:

np.argmax(prediction_output)

to

np.random.choice(len(prediction_output), p=prediction_output)

I had been struggling with this repeating-sequences issue for a while, until I discovered this Colab notebook and figured out why their model was able to generate some really good samples: https://colab.research.google.com/github/tensorflow/tpu/blob/master/tools/colab/shakespeare_with_tpu_and_keras.ipynb#scrollTo=tU7M-EGGxR3E

After I changed this single line, my model went from generating a few words over and over to something actually interesting!
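For context, here is a minimal self-contained sketch of the difference (the probability vector here is made up; in practice it would come from model.predict):

import numpy as np

# Stand-in for a model's predicted next-character distribution over a 5-char vocabulary
preds = np.array([0.70, 0.15, 0.08, 0.05, 0.02])

# Greedy decoding: always returns the mode, so once the model enters a
# high-confidence loop it never leaves it
greedy_index = np.argmax(preds)

# Stochastic decoding: samples in proportion to the probabilities, so
# lower-probability characters still get picked and can break the loop
sampled_index = np.random.choice(len(preds), p=preds)

print(greedy_index, sampled_index)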

answered Sep 28 '22 by Shane Smiskol


To generate text with a trained language model, repeat the following steps:

  1. Draw from the model a probability distribution over the next character, given the text available so far (these are the prediction scores).
  2. Reweight the distribution to a certain "temperature" (see the code below).
  3. Sample the next character at random according to the reweighted distribution (see the code below).
  4. Append the new character to the end of the available text.

See the sample function:

import numpy as np

def sample(preds, temperature=1.0):
    # Reweight the predicted distribution by temperature: scale the
    # log-probabilities by 1/temperature, then re-normalise with a softmax
    # (low temperature sharpens the distribution, high temperature flattens it)
    preds = np.asarray(preds).astype('float64')
    preds = np.log(preds) / temperature
    exp_preds = np.exp(preds)
    preds = exp_preds / np.sum(exp_preds)
    # Draw a single sample from the reweighted distribution
    probas = np.random.multinomial(1, preds, 1)
    return np.argmax(probas)

You can then use the sample function to monitor generation during training, as follows:

import random
import sys

# Assumes model, x, y, text, maxlen, chars and char_indices are already defined
for epoch in range(1, 60):
    print('epoch', epoch)
    # Fit the model for 1 epoch on the available training data
    model.fit(x, y,
              batch_size=128,
              epochs=1)

    # Select a text seed at random
    start_index = random.randint(0, len(text) - maxlen - 1)
    generated_text = text[start_index: start_index + maxlen]
    print('--- Generating with seed: "' + generated_text + '"')

    for temperature in [0.2, 0.5, 1.0, 1.2]:
        print('------ temperature:', temperature)
        sys.stdout.write(generated_text)

        # We generate 400 characters
        for i in range(400):
            sampled = np.zeros((1, maxlen, len(chars)))
            for t, char in enumerate(generated_text):
                sampled[0, t, char_indices[char]] = 1.

            preds = model.predict(sampled, verbose=0)[0]
            next_index = sample(preds, temperature)
            next_char = chars[next_index]

            generated_text += next_char
            # Slide the seed window forward by one character
            generated_text = generated_text[1:]

            sys.stdout.write(next_char)
            sys.stdout.flush()
        print()

A low temperature results in extremely repetitive and predictable text, but its local structure is highly realistic: in particular, all words (a word being a local pattern of characters) are real English words. With higher temperatures, the generated text becomes more interesting, surprising, even creative.
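To make the temperature effect concrete, here is a small self-contained example (the three-way distribution is made up) applying the same reweighting that sample uses:

import numpy as np

preds = np.array([0.6, 0.3, 0.1])   # made-up next-character distribution

for t in [0.2, 1.0, 1.2]:
    reweighted = np.exp(np.log(preds) / t)
    reweighted /= reweighted.sum()
    print(t, np.round(reweighted, 3))

# 0.2 -> [0.97  0.03  0.   ]   near-greedy: almost always the top character
# 1.0 -> [0.6   0.3   0.1  ]   the original distribution
# 1.2 -> [0.56  0.314 0.126]   flatter: rarer characters picked more often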

See this notebook

answered Sep 28 '22 by Guillem