Keras network can never classify the last class

I have been working on my project, Deep Learning Language Detection, a network with the following layers that recognises 16 programming languages:

[Image: diagram of the network's layers]

And this is the code to produce the network:

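# Imports (Keras 2.x API). Hyperparameters such as sequence_length, filter_sizes,
# pooling_sizes, num_filters, dropout_prob, hidden_dims and number_of_classes are
# defined elsewhere in the project (not shown here).
from keras.layers import (Input, Conv1D, MaxPooling1D, Flatten, Concatenate,
                          Dropout, Dense, Activation)
from keras.models import Model, Sequential

# Setting up the model: one Conv1D + MaxPooling1D branch per filter size
graph_in = Input(shape=(sequence_length, number_of_quantised_characters))
convs = []
for kernel_size, pool_size in zip(filter_sizes, pooling_sizes):
    conv = Conv1D(filters=num_filters,
                  kernel_size=kernel_size,
                  padding='valid',
                  activation='relu',
                  strides=1)(graph_in)
    pool = MaxPooling1D(pool_size=pool_size)(conv)
    flatten = Flatten()(pool)
    convs.append(flatten)

# Merge the branches (only needed when there is more than one)
if len(filter_sizes) > 1:
    out = Concatenate()(convs)
else:
    out = convs[0]

graph = Model(inputs=graph_in, outputs=out)

# Main sequential model: input dropout -> conv graph -> dense -> dropout -> softmax
model = Sequential()
model.add(Dropout(dropout_prob[0], input_shape=(sequence_length, number_of_quantised_characters)))
model.add(graph)
model.add(Dense(hidden_dims))
model.add(Dropout(dropout_prob[1]))
model.add(Dense(number_of_classes))
model.add(Activation('softmax'))
model.compile(loss='categorical_crossentropy', optimizer='adadelta', metrics=['accuracy'])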

My last language class is SQL, and in the test phase the network never predicts SQL correctly: it scores 0% on it. I thought this was due to the poor quality of the SQL samples (and indeed they were poor), so I removed this class and started training on 15 classes. To my surprise, F# files now had 0% detection, and F# was the last class after removing SQL (i.e. the one whose one-hot vector has a 1 in the last position and 0 everywhere else). Conversely, if a network trained on 16 classes is tested against only 15 (leaving SQL out), it achieves a very high success rate of 98.5%.

The code I am using is fairly simple and lives mainly in defs.py and data_helper.py.

Here is the result of the network trained on 16 classes and tested against 16 classes:

Final result: 14827/16016 (0.925761738262)
xml:        995/1001 (0.994005994006)
fsharp:     974/1001 (0.973026973027)
clojure:    993/1001 (0.992007992008)
java:       996/1001 (0.995004995005)
scala:      990/1001 (0.989010989011)
python:     983/1001 (0.982017982018)
sql:        0/1001 (0.0)
js:         991/1001 (0.99000999001)
cpp:        988/1001 (0.987012987013)
css:        987/1001 (0.986013986014)
csharp:     994/1001 (0.993006993007)
go:         989/1001 (0.988011988012)
php:        998/1001 (0.997002997003)
ruby:       995/1001 (0.994005994006)
powershell: 992/1001 (0.991008991009)
bash:       962/1001 (0.961038961039)

And this is the result of the same network (trained on 16 classes) run against 15 classes:

Final result: 14827/15015 (0.987479187479)
xml:        995/1001 (0.994005994006)
fsharp:     974/1001 (0.973026973027)
clojure:    993/1001 (0.992007992008)
java:       996/1001 (0.995004995005)
scala:      990/1001 (0.989010989011)
python:     983/1001 (0.982017982018)
js:         991/1001 (0.99000999001)
cpp:        988/1001 (0.987012987013)
css:        987/1001 (0.986013986014)
csharp:     994/1001 (0.993006993007)
go:         989/1001 (0.988011988012)
php:        998/1001 (0.997002997003)
ruby:       995/1001 (0.994005994006)
powershell: 992/1001 (0.991008991009)
bash:       962/1001 (0.961038961039)

Has anyone else seen this? How can I get around it?

asked Oct 30 '17 22:10 by Aliostad


1 Answer

TL;DR: The problem is that your data are not shuffled before being split into training and validation sets. Therefore, during training, all samples belonging to class "sql" are in the validation set. Your model can't learn to predict the last class if it is never given samples of that class during training.


In get_input_and_labels(), the files for class 0 are loaded first, then those for class 1, and so on. Since you set n_max_files = 2000, this means that:

  • The first 2000 (or so, depending on how many files you actually have) entries in Y will be of class 0 ("go")
  • The next 2000 entries will be of class 1 ("csharp")
  • ...
  • and finally the last 2000 entries will be of the last class ("sql").

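Unfortunately, Keras does not shuffle the data before splitting them into training and validation sets: with validation_split, the validation data are taken from the last samples of the x and y you pass to fit(), before any shuffling (the shuffle argument of fit() only shuffles the training portion between epochs). Because validation_split is set to 0.1 in your code, roughly the last 3000 samples (which contain all the "sql" samples) end up in the validation set.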
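To make the effect concrete, here is a minimal NumPy sketch on synthetic data. It assumes exactly 2000 samples per class loaded in class order (the real counts may differ), and `tail_split` is a hypothetical helper that mimics, but is not, Keras's internal splitting: it keeps the first `1 - validation_split` of the samples for training and holds out the rest.

```python
import numpy as np


def tail_split(x, y, validation_split):
    """Hypothetical helper mimicking validation_split: the LAST fraction of the
    samples is held out for validation, with no shuffling beforehand."""
    split_at = int(len(x) * (1.0 - validation_split))
    return (x[:split_at], y[:split_at]), (x[split_at:], y[split_at:])


# Synthetic data laid out the way get_input_and_labels() loads it (assumed
# 2000 samples per class): all of class 0, then class 1, ..., then class 15.
num_classes, per_class = 16, 2000
labels = np.repeat(np.arange(num_classes), per_class)  # integer class ids
y = np.eye(num_classes)[labels]                        # one-hot, shape (32000, 16)
x = np.zeros((len(y), 1))                              # placeholder features

(x_train, y_train), (x_val, y_val) = tail_split(x, y, validation_split=0.1)
seen_in_training = set(y_train.argmax(axis=1).tolist())
missing = sorted(set(range(num_classes)) - seen_in_training)
print(missing)  # → [15]: the last class never reaches the training set
```

Raising `validation_split` to 0.2 in this sketch leaves classes 13, 14 and 15 out of training entirely.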

If you set validation_split to a higher value (e.g., 0.2), you'll see more classes scoring 0%:

Final result: 12426/16016 (0.7758491508491508)
go:         926/1001 (0.9250749250749251)
csharp:     966/1001 (0.965034965034965)
java:       973/1001 (0.972027972027972)
js:         929/1001 (0.9280719280719281)
cpp:        986/1001 (0.985014985014985)
ruby:       942/1001 (0.9410589410589411)
powershell: 981/1001 (0.98001998001998)
bash:       882/1001 (0.8811188811188811)
php:        977/1001 (0.9760239760239761)
css:        988/1001 (0.987012987012987)
xml:        994/1001 (0.993006993006993)
python:     986/1001 (0.985014985014985)
scala:      896/1001 (0.8951048951048951)
clojure:    0/1001 (0.0)
fsharp:     0/1001 (0.0)
sql:        0/1001 (0.0)
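The three 0% classes follow directly from the arithmetic. Assuming the classes are loaded in the order listed in the table above with 2000 samples each (an assumption; the real counts may differ), the hypothetical helper below computes what fraction of each class falls into the held-out tail. With validation_split=0.2 that is all of "clojure", "fsharp" and "sql", plus 20% of "scala".

```python
# Load order taken from the results table; 2000 samples per class is an
# assumption for illustration (the real per-class counts may differ).
CLASSES = ["go", "csharp", "java", "js", "cpp", "ruby", "powershell", "bash",
           "php", "css", "xml", "python", "scala", "clojure", "fsharp", "sql"]


def validation_share(classes, per_class, validation_split):
    """Hypothetical helper: fraction of each class held out when the validation
    set is the last `validation_split` of a class-ordered, unshuffled dataset."""
    total = per_class * len(classes)
    split_at = int(total * (1.0 - validation_split))  # index where validation starts
    shares = {}
    for i, name in enumerate(classes):
        start, end = i * per_class, (i + 1) * per_class  # this class's index range
        held_out = max(0, end - max(start, split_at))      # samples at or after split_at
        shares[name] = held_out / per_class
    return shares


shares = validation_share(CLASSES, per_class=2000, validation_split=0.2)
fully_held_out = [name for name, share in shares.items() if share == 1.0]
print(fully_held_out)   # → ['clojure', 'fsharp', 'sql']
print(shares["scala"])  # → 0.2
```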

The problem can be solved by shuffling the data after loading. Your code already seems to have lines that do this:

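import numpy as np

# Shuffle data: apply the same random permutation to x and y so they stay aligned
shuffle_indices = np.random.permutation(np.arange(len(y)))
x_shuffled = x[shuffle_indices]
# Keep y one-hot: categorical_crossentropy expects targets of shape
# (samples, classes), so .argmax(axis=1), which yields integer labels, is dropped
y_shuffled = y[shuffle_indices]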

However, when you fit the model, you pass the original x and y to fit() instead of x_shuffled and y_shuffled. If you change that line to:

model.fit(x_shuffled, y_shuffled, batch_size=batch_size,
          epochs=num_epochs, validation_split=val_split, verbose=1)
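As a sanity check, here is a minimal NumPy sketch on synthetic data (the same assumed layout of 16 classes × 2000 samples, loaded in class order; not the project's real data). It shows that once x and y are shuffled with one shared permutation, every class appears in the training portion of a tail split. It keeps y one-hot because the model is compiled with categorical_crossentropy, which expects targets of shape (samples, classes).

```python
import numpy as np

num_classes, per_class, validation_split = 16, 2000, 0.1
labels = np.repeat(np.arange(num_classes), per_class)  # class-ordered, as loaded
x = np.arange(len(labels)).reshape(-1, 1)  # placeholder: each row stores its original index
y = np.eye(num_classes)[labels]            # one-hot labels

rng = np.random.default_rng(0)  # fixed seed only so the sketch is reproducible
perm = rng.permutation(len(labels))
x_shuffled, y_shuffled = x[perm], y[perm]  # SAME permutation for x and y

# x and y stay aligned: row i of x_shuffled still belongs with row i of y_shuffled
aligned = bool((labels[x_shuffled[:, 0]] == y_shuffled.argmax(axis=1)).all())

# Hold out the last fraction, as validation_split does, and count training samples per class
split_at = int(len(x_shuffled) * (1.0 - validation_split))
train_counts = np.bincount(y_shuffled[:split_at].argmax(axis=1), minlength=num_classes)
print(aligned, train_counts.min() > 0)  # every class now appears in the training portion
```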

then the test output becomes much more reasonable, with no class stuck at 0%:

Final result: 15248/16016 (0.952047952047952)
go:         865/1001 (0.8641358641358642)
csharp:     986/1001 (0.985014985014985)
java:       977/1001 (0.9760239760239761)
js:         953/1001 (0.952047952047952)
cpp:        974/1001 (0.973026973026973)
ruby:       985/1001 (0.984015984015984)
powershell: 974/1001 (0.973026973026973)
bash:       942/1001 (0.9410589410589411)
php:        979/1001 (0.978021978021978)
css:        965/1001 (0.964035964035964)
xml:        988/1001 (0.987012987012987)
python:     857/1001 (0.8561438561438561)
scala:      955/1001 (0.954045954045954)
clojure:    985/1001 (0.984015984015984)
fsharp:     950/1001 (0.949050949050949)
sql:        913/1001 (0.9120879120879121)

answered Oct 20 '22 15:10 by Yu-Yang