deeplearning4j - using an RNN/LSTM for audio signal processing

I'm trying to train an RNN for digital (audio) signal processing using deeplearning4j. The idea is to have two .wav files: one is an audio recording, the second is the same audio recording but processed (for example with a low-pass filter). The RNN's input is the first (unprocessed) recording, and the output is the second (processed) recording.
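(Just to illustrate the kind of input/target relationship I mean: the processed target could be produced with something as simple as a moving-average low-pass over the 8-bit samples. The sketch below is only an illustration, not necessarily the exact filter I used.)

// Illustration only: a naive moving-average low-pass over 8-bit samples
// (values 0-255), producing a "processed" target signal from the input.
static int[] simpleLowPass(int[] samples, int windowSize) {
    int[] out = new int[samples.length];
    long sum = 0;
    for (int i = 0; i < samples.length; i++) {
        sum += samples[i];
        if (i >= windowSize)
            sum -= samples[i - windowSize];
        out[i] = (int) (sum / Math.min(i + 1, windowSize));
    }
    return out;
}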

I've used the GravesLSTMCharModellingExample from the dl4j examples, and mostly adapted the CharacterIterator class to accept audio data instead of text.

My first project working with audio in dl4j was basically to do the same thing as GravesLSTMCharModellingExample but generating audio instead of text, working with 11025 Hz 8-bit mono audio. That works (with some quite amusing results), so the basics of working with audio in this context seem to be fine.

So step 2 was to adapt this for audio processing instead of audio generation.

Unfortunately, I'm not having much success. The best it seems able to do is output a very noisy version of the input.

As a 'sanity check', I've tested using the same audio file for both the input and the output, which I expected to converge quickly to a model that simply copies the input. But it doesn't. Again, after a long time of training, all it seemed able to do was produce a noisier version of the input.

The most relevant piece of code, I guess, is the DataSetIterator.next() method (adapted from the example's CharacterIterator class), which now looks like this:

public DataSet next(int num) {
    if (exampleStartOffsets.size() == 0)
        throw new NoSuchElementException();

    int currMinibatchSize = Math.min(num, exampleStartOffsets.size());
    // Allocate space:
    // Note the order here:
    // dimension 0 = number of examples in minibatch
    // dimension 1 = size of each vector (i.e., number of possible sample values; 256 for 8-bit audio)
    // dimension 2 = length of each time series/example
    // Why 'f' order here? See http://deeplearning4j.org/usingrnns.html#data
    // section "Alternative: Implementing a custom DataSetIterator"
    INDArray input = Nd4j.create(new int[] { currMinibatchSize, columns, exampleLength }, 'f');
    INDArray labels = Nd4j.create(new int[] { currMinibatchSize, columns, exampleLength }, 'f');

    for (int i = 0; i < currMinibatchSize; i++) {
        int startIdx = exampleStartOffsets.removeFirst();
        int endIdx = startIdx + exampleLength;

        for (int j = startIdx, c = 0; j < endIdx; j++, c++) {
            // inputIndices/idealIndices are audio samples converted to indices.
            // With 8-bit audio, this translates to values between 0-255.
            input.putScalar(new int[] { i, inputIndices[j], c }, 1.0);
            labels.putScalar(new int[] { i, idealIndices[j], c }, 1.0);
        }
    }

    return new DataSet(input, labels);
}
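
For context, inputIndices and idealIndices hold the 8-bit samples of the input and target wav files mapped to values 0-255, roughly along these lines (a simplified sketch; readWavAsIndices is only illustrative, not the exact code from my project):

import javax.sound.sampled.AudioInputStream;
import javax.sound.sampled.AudioSystem;
import java.io.ByteArrayOutputStream;
import java.io.File;

// Illustrative helper: reads an 8-bit mono PCM wav and maps each
// (unsigned) sample byte to an index in [0, 255].
static int[] readWavAsIndices(File wavFile) throws Exception {
    try (AudioInputStream in = AudioSystem.getAudioInputStream(wavFile)) {
        ByteArrayOutputStream out = new ByteArrayOutputStream();
        byte[] buf = new byte[4096];
        int n;
        while ((n = in.read(buf)) > 0) {
            out.write(buf, 0, n);
        }
        byte[] bytes = out.toByteArray();
        int[] indices = new int[bytes.length];
        for (int i = 0; i < bytes.length; i++) {
            indices[i] = bytes[i] & 0xFF; // 8-bit PCM wav samples are unsigned
        }
        return indices;
    }
}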

So maybe I'm having a fundamental misunderstanding of what LSTMs are supposed to do. Is there anything obviously wrong in the posted code that I'm missing? Is there an obvious reason why training on the same file doesn't necessarily converge quickly to a model that just copies the input? (let alone even trying to train it on signal processing that actually does something?)

I've seen "Using RNN to recover sine wave from noisy signal", which seems to be about a similar problem (but using a different ML framework), but that question didn't get an answer.

Any feedback is appreciated!

asked May 06 '17 by erikd71


1 Answer

If you hear a distorted version of the input, you are on the right track.

The problem might be that the network's free parameters cannot generalize well from a small number of examples. Make sure you have more samples, at least 50,000 that do not overlap each other (i.e. not taken from the same wav file), and try playing with the network parameters: for example, reduce the number of nodes on each layer by 10-15% and try a lower learning rate.
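As a starting point, something roughly like this (a sketch based on the GravesLSTMCharModellingExample setup for dl4j 0.x; the layer sizes and learning rate are only values to tune, not proven settings):

import org.deeplearning4j.nn.conf.BackpropType;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.Updater;
import org.deeplearning4j.nn.conf.layers.GravesLSTM;
import org.deeplearning4j.nn.conf.layers.RnnOutputLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.lossfunctions.LossFunctions;

int nIn = 256;           // one-hot size for 8-bit samples
int lstmLayerSize = 180; // try reducing this by 10-15% if the net does not generalize

MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
    .seed(12345)
    .learningRate(0.001)      // lower than in the character-modelling example
    .updater(Updater.RMSPROP)
    .weightInit(WeightInit.XAVIER)
    .list()
    .layer(0, new GravesLSTM.Builder().nIn(nIn).nOut(lstmLayerSize)
            .activation(Activation.TANH).build())
    .layer(1, new GravesLSTM.Builder().nIn(lstmLayerSize).nOut(lstmLayerSize)
            .activation(Activation.TANH).build())
    .layer(2, new RnnOutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
            .activation(Activation.SOFTMAX)
            .nIn(lstmLayerSize).nOut(nIn).build())
    .backpropType(BackpropType.TruncatedBPTT)
    .tBPTTForwardLength(50).tBPTTBackwardLength(50)
    .build();

MultiLayerNetwork net = new MultiLayerNetwork(conf);
net.init();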

answered Nov 15 '22 by Borislav Markov