Autoencoder not learning identity function

I'm somewhat new to machine learning in general, and I wanted to run a simple experiment to get more familiar with neural network autoencoders: to build an extremely basic autoencoder that learns the identity function.

I'm using Keras to make life easier, so I did this first to make sure it works:

import numpy as np
from keras.models import Sequential
from keras.layers import Dense

# Weights are given as [weights, biases], so we pass the identity
# matrix for the weights and a vector of zeros for the biases
weights = [np.eye(84), np.zeros(84)]
model = Sequential([Dense(84, input_dim=84, weights=weights)])
model.compile(optimizer='sgd', loss='mean_squared_error')
model.fit(X, X, nb_epoch=10, batch_size=8, validation_split=0.3)

As expected, the loss is zero, on both the training and validation data:

Epoch 1/10
97535/97535 [==============================] - 27s - loss: 0.0000e+00 - val_loss: 0.0000e+00
Epoch 2/10
97535/97535 [==============================] - 28s - loss: 0.0000e+00 - val_loss: 0.0000e+00

Then I tried the same thing without initializing the weights to the identity function, expecting that after a while of training the network would learn it. It didn't. I've let it run for 200 epochs several times in different configurations, playing with different optimizers and loss functions and adding L1 and L2 activity regularizers. The results vary, but the best I've gotten is still really bad: the output looks nothing like the original data, it's just in roughly the same numeric range. The data is simply numbers oscillating around 1.1. I don't know if an activation layer makes sense for this problem; should I be using one?

If this "neural network" of one layer can't learn something as simple as the identity function, how can I expect it to learn anything more complex? What am I doing wrong?

EDIT

For better context, here's a way to generate a dataset very similar to the one I'm using:

X = np.random.normal(1.1090579, 0.0012380764, (139336, 84))

I suspect the variation between the values might be too small. The loss ends up at decent-looking values (around 1e-6), but that's not enough precision for the result to have a shape similar to the original data. Maybe I should scale/normalize it somehow? Thanks for any advice!
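As a quick sanity check on those numbers: a constant predictor that always outputs the dataset mean already achieves an MSE equal to the variance of the data, so a loss of about 1e-6 is barely better than predicting the mean.

# MSE of always predicting the global mean equals the variance,
# (0.0012380764)**2, which is roughly 1.5e-6
baseline_mse = X.var()
print(baseline_mse)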

UPDATE

In the end, as suggested, the issue was that the dataset had too-small variations between its 84 values, so the resulting prediction was actually pretty good in absolute terms (as measured by the loss function), but compared to the original data, the variations were far off. I solved it by normalizing the 84 values in each sample around the sample's mean and dividing by the sample's standard deviation, then using the original mean and standard deviation to denormalize the predictions at the other end. This could be done in a few different ways, but I did it by adding the normalization/denormalization into the model itself, using some Lambda layers that operate on the tensors. That way all the data processing is incorporated into the model, which makes it nicer to work with. Let me know if you would like to see the actual code.
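As a rough sketch of the idea (not my exact code, and written in Keras 2 style, where a Lambda layer can be called on a list of tensors):

from keras.models import Model
from keras.layers import Input, Dense, Lambda
from keras import backend as K

inp = Input(shape=(84,))

# Per-sample statistics, computed inside the graph so they can be
# reused to denormalize the output
mean = Lambda(lambda x: K.mean(x, axis=-1, keepdims=True))(inp)
std = Lambda(lambda x: K.std(x, axis=-1, keepdims=True))(inp)

# Normalize, run through the dense layer, then denormalize
norm = Lambda(lambda t: (t[0] - t[1]) / (t[2] + K.epsilon()))([inp, mean, std])
hidden = Dense(84)(norm)
out = Lambda(lambda t: t[0] * (t[2] + K.epsilon()) + t[1])([hidden, mean, std])

model = Model(inp, out)
model.compile(optimizer='sgd', loss='mean_squared_error')
model.fit(X, X, epochs=10, batch_size=8, validation_split=0.3)

Since the per-sample statistics never leave the graph, the trained model normalizes its input and denormalizes its output on its own.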

Asked Jan 10 '17 by Andres Berrios

1 Answer

I believe the problem is either the number of epochs or the way you initialize X. I ran your code with an X of my own for 100 epochs, then printed the argmax() and max() values of the weights; they get really close to the identity function.

Here's the code snippet that I used:

from keras.models import Sequential
from keras.layers import Dense
import numpy as np

# 99,999 samples of 84 uniform random values in [0, 1)
X = np.random.random((99999, 84))

model = Sequential([Dense(84, input_dim=84, name="layer1")])
model.compile(optimizer='sgd', loss='mean_squared_error')
model.fit(X, X, nb_epoch=100, batch_size=80, validation_split=0.3)

# For an identity mapping, the largest weight in each column should
# sit on the diagonal and be close to 1
l_weights = np.round(model.layers[0].get_weights()[0], 3)

print(l_weights.argmax(axis=0))
print(l_weights.max(axis=0))

And I'm getting:

Train on 69999 samples, validate on 30000 samples
Epoch 1/100
69999/69999 [==============================] - 1s - loss: 0.2092 - val_loss: 0.1564
Epoch 2/100
69999/69999 [==============================] - 1s - loss: 0.1536 - val_loss: 0.1510
Epoch 3/100
69999/69999 [==============================] - 1s - loss: 0.1484 - val_loss: 0.1459
.
.
.
Epoch 98/100
69999/69999 [==============================] - 1s - loss: 0.0055 - val_loss: 0.0054
Epoch 99/100
69999/69999 [==============================] - 1s - loss: 0.0053 - val_loss: 0.0053
Epoch 100/100
69999/69999 [==============================] - 1s - loss: 0.0051 - val_loss: 0.0051
[ 0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83]
[ 0.85000002  0.85100001  0.79799998  0.80500001  0.82700002  0.81900001
  0.792       0.829       0.81099999  0.80800003  0.84899998  0.829       0.852
  0.79500002  0.84100002  0.81099999  0.792       0.80800003  0.85399997
  0.82999998  0.85100001  0.84500003  0.847       0.79699999  0.81400001
  0.84100002  0.81        0.85100001  0.80599999  0.84500003  0.824
  0.81999999  0.82999998  0.79100001  0.81199998  0.829       0.85600001
  0.84100002  0.792       0.847       0.82499999  0.84500003  0.796
  0.82099998  0.81900001  0.84200001  0.83999997  0.815       0.79500002
  0.85100001  0.83700001  0.85000002  0.79900002  0.84100002  0.79699999
  0.838       0.847       0.84899998  0.83700001  0.80299997  0.85399997
  0.84500003  0.83399999  0.83200002  0.80900002  0.85500002  0.83899999
  0.79900002  0.83399999  0.81        0.79100001  0.81800002  0.82200003
  0.79100001  0.83700001  0.83600003  0.824       0.829       0.82800001
  0.83700001  0.85799998  0.81999999  0.84299999  0.83999997]

When I used only 5 numbers as input and printed the actual weights, I got this:

array([[ 1.,  0., -0.,  0.,  0.],
       [ 0.,  1.,  0., -0., -0.],
       [-0.,  0.,  1.,  0.,  0.],
       [ 0., -0.,  0.,  1., -0.],
       [ 0., -0.,  0., -0.,  1.]], dtype=float32)
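To make the comparison explicit, you can also measure how far the learned weights are from the identity matrix (my addition, not part of the run above):

# Largest absolute deviation of the learned weights from the identity
w = model.layers[0].get_weights()[0]
print(np.abs(w - np.eye(w.shape[0])).max())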
Answered by Ohad Zadok