 

Poor performance of seq-to-seq LSTM on a simple sine wave with low frequency

I am trying to train a seq-to-seq model on a simple sine wave. The goal is to take Nin data points and predict the next Nout data points. The task seems simple, and the model predicts well for a large frequency freq (y = sin(freq * x)). For example, for freq=4 the loss is very low and the prediction is very close to the target. However, for low frequencies the prediction is bad. Any thoughts on why the model fails?

import numpy as np
import matplotlib.pyplot as plt
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, LSTM, RepeatVector, TimeDistributed, Dense

freq = 0.25
Nin, Nout = 14, 14

# Helper function to convert 1-D data to (input, target) windows.
# Note: arrays are built time-major (window, samples, features) and
# transposed to (samples, window, features) before training.
def windowed_dataset(y, input_window = 5, output_window = 1, stride = 1, num_features = 1):
    L = y.shape[0]
    num_samples = (L - input_window - output_window) // stride + 1
    X = np.zeros([input_window, num_samples, num_features])
    Y = np.zeros([output_window, num_samples, num_features])    
    for ff in np.arange(num_features):
        for ii in np.arange(num_samples):
            start_x = stride * ii
            end_x = start_x + input_window
            X[:, ii, ff] = y[start_x:end_x, ff]
            start_y = stride * ii + input_window
            end_y = start_y + output_window 
            Y[:, ii, ff] = y[start_y:end_y, ff]
    return X, Y

# The input shape is (sequence length, number of features)
inputs = Input(shape=(Nin, 1))
# Build a RNN encoder
encoder = LSTM(128, return_sequences=False)(inputs)
# Repeat the encoding for every input to the decoder
encoding_repeat = RepeatVector(Nout)(encoder)
# Pass the repeated (Nout, 128) encoding to the decoder
decoder = LSTM(128, return_sequences=True)(encoding_repeat)
# Output each timestep into a fully connected layer
sequence_prediction = TimeDistributed(Dense(1, activation='linear'))(decoder)
model = Model(inputs, sequence_prediction)
model.compile('adam', 'mse')
y = np.sin(freq * np.linspace(0, 10, 1000))[:, None]
Ntr = int(0.8 * y.shape[0])
y_train, y_test = y[:Ntr], y[Ntr:]
stride = 1
N_features = 1
Xtrain, Ytrain = windowed_dataset(y_train, input_window=Nin, output_window=Nout, stride=stride,
                                  num_features=N_features)
model.summary()
Xtrain, Ytrain = Xtrain.transpose(1, 0, 2), Ytrain.transpose(1, 0, 2)
print("Xtrain", Xtrain.shape)
model.fit(Xtrain, Ytrain, epochs=30)
plt.figure(); plt.plot(y, 'ro')
for Ns in np.array([10, 50, 200, 400, 800, 1500, 3000]) // 10:
    ypred = model.predict(Xtrain[[Ns]])
    print("ypred", ypred.shape)
    ypred = ypred[-1]
    plt.figure()
    plt.plot(ypred, 'ro')
    plt.plot(Xtrain[Ns], 'm--')
    plt.plot(Ytrain[Ns], 'k.')
    plt.show()
exit()
asked Nov 07 '22 by Roy

1 Answer

I think it is because the lower the frequency, the less pattern there is to learn. Think of it as having a pattern of Nin inputs from which to predict the next Nout outputs: although each input x(n) rises slightly in value, there is hardly a pattern in that, and a similar slight rise already happened earlier, so nothing new gets learned, no new patterns. It would take longer training (or longer windows) before the network sees whole waves passing by and can count that as a pattern.
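To put a rough number on that (using the same Nin=14, x in [0, 10] with 1000 samples, and the two frequencies from the question), one input window covers only a tiny fraction of a period at freq=0.25:

import numpy as np

dx = 10 / 999                    # spacing of np.linspace(0, 10, 1000)
Nin = 14

for freq in (4.0, 0.25):
    period = 2 * np.pi / freq    # length of one full sine period in x-units
    window_span = Nin * dx       # x-range covered by one input window
    print(f"freq={freq}: a window covers {window_span / period:.2%} of a period")
# freq=4.0:  about 8.9% of a period
# freq=0.25: about 0.56% of a period

In fact, at freq=0.25 the whole x-range 0..10 spans only 2.5 radians, less than half a period, so the training data never contains even one full oscillation.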

It would be interesting, though, to take the same amount of training but skip forward on the sine line, or, far easier, take your well-working model and test it with scaled-down inputs. E.g. if you trained it with degrees 5, 10, 15, 20, 25, etc., give that trained network 0.05, 0.10 degrees (i.e. alter the inputs only, but keep the network), as sketched below.
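A minimal sketch of that experiment, assuming model has been trained on the well-working freq=4 data and that the question's windowed_dataset, Nin, Nout, np and plt are still in scope (the window index is arbitrary):

# Generate a low-frequency sine with the same sampling as before and feed
# windows of it to the network that was trained on the high-frequency data.
low_freq = 0.25
y_low = np.sin(low_freq * np.linspace(0, 10, 1000))[:, None]

X_low, Y_low = windowed_dataset(y_low, input_window=Nin, output_window=Nout)
X_low, Y_low = X_low.transpose(1, 0, 2), Y_low.transpose(1, 0, 2)

Ns = 100                                # arbitrary window index
pred_low = model.predict(X_low[[Ns]])   # shape (1, Nout, 1)
plt.plot(Y_low[Ns].ravel(), 'k.', label='target')
plt.plot(pred_low[0].ravel(), 'ro', label='prediction of freq=4 model')
plt.legend(); plt.show()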

In summary: sequence-to-sequence networks work well on data that has clear patterns (like language/text prediction), but not on data that has little pattern.

--- edit --- (too long for a comment reply)
Yes, it's hard to debug neural nets, but I think you have to go back to basic principles: is a slowly rising signal a pattern? It can only be detected if it rises and falls (enough) within the training data. RNNs and LSTMs are good at serial patterns, e.g. ASCII strings; slowly gliding numbers, as in your test case, are hardly a pattern to refer back to in memory. Perhaps you can improve here by altering the training sample order, i.e. taking windows from random positions on the sine wave, because the internal error correction might get over-convinced in a certain direction: if the last 70 samples went up, why would sample 71 go down? A sketch of that shuffling is below.
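A minimal sketch of that shuffling idea, reusing Xtrain, Ytrain and model from the question (note that Keras fit already shuffles between epochs by default, so the explicit permutation mainly makes the intent visible):

# Present the windows in random order so the network does not see a long,
# ordered run of near-identical "slowly rising" samples.
perm = np.random.permutation(Xtrain.shape[0])
model.fit(Xtrain[perm], Ytrain[perm], epochs=30, shuffle=True)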

answered Nov 14 '22 by Peter