 

Keras BERT: high accuracy, validation accuracy, F1 and AUC, but poor predictions

I have trained a text classifier using Google's BERT via tf.keras.

My dataset consists of 50,000 rows of data, spread evenly over 5 labels. This is a subset of a much larger dataset, but I have chosen these specific labels as they are totally different from each other to try to avoid confusion during training.

I create the data splits as follows:

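from sklearn.model_selection import train_test_split

# 70% train+val / 30% test, then 10% of the remaining data held out for validation
train, test = train_test_split(df, test_size=0.30, shuffle=True, stratify=df['label'], random_state=10)
train, val = train_test_split(train, test_size=0.1, shuffle=True, stratify=train['label'], random_state=10)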
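A quick way to check what these two calls produce is to run them on a hypothetical stand-in DataFrame with the shape described above (50,000 rows, 10,000 per label; the text column is dummy data). The splits come out at 31,500 / 3,500 / 15,000 rows, and stratification keeps every label at exactly one fifth of each split, which matches each row of the confusion matrix further down summing to 3,000.

```python
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split

# Hypothetical stand-in for df: 50,000 rows spread evenly over 5 labels
df = pd.DataFrame({
    "text": [f"doc {i}" for i in range(50_000)],
    "label": np.repeat(np.arange(5), 10_000),
})

train, test = train_test_split(df, test_size=0.30, shuffle=True,
                               stratify=df["label"], random_state=10)
train, val = train_test_split(train, test_size=0.1, shuffle=True,
                              stratify=train["label"], random_state=10)

print(len(train), len(val), len(test))  # 31500 3500 15000

# Stratification: each label gets 6,300 / 700 / 3,000 rows respectively
for name, part in [("train", train), ("val", val), ("test", test)]:
    print(name, part["label"].value_counts().to_dict())
```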

The model design is:

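import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import Input, Dense, BatchNormalization, Activation
from tensorflow.keras.models import Model
from transformers import TFBertModel

def compile():
    mirrored_strategy = tf.distribute.MirroredStrategy()
    with mirrored_strategy.scope():
        learn_rate = 4e-5
        bert = 'bert-base-uncased'
        # Pre-trained BERT encoder, frozen (its weights are not updated)
        model = TFBertModel.from_pretrained(bert, trainable=False)

        # Token IDs and attention mask, each padded to 512 tokens
        input_ids_layer = Input(shape=(512,), dtype=np.int32)
        input_mask_layer = Input(shape=(512,), dtype=np.int32)

        # [0] is the last hidden state: (batch, 512, 768)
        bert_layer = model([input_ids_layer, input_mask_layer])[0]

        # Max-pool over the sequence dimension: (batch, 768)
        X = tf.keras.layers.GlobalMaxPool1D()(bert_layer)

        output = Dense(5)(X)
        output = BatchNormalization(trainable=False)(output)
        output = Activation('softmax')(output)

        model_ = Model(inputs=[input_ids_layer, input_mask_layer], outputs=output)

        optimizer = tf.keras.optimizers.Adam(learn_rate)
        # NOTE: the output layer already applies softmax, so from_logits=True
        # makes the loss apply softmax a second time; from_logits=False is
        # what matches a softmax output.
        loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
        metric = tf.keras.metrics.SparseCategoricalAccuracy('accuracy')

        model_.compile(optimizer=optimizer, loss=loss, metrics=[metric])
        return model_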

Which gives the following results:

loss: 1.2433
accuracy: 0.8024
val_loss: 1.2148
val_accuracy: 0.8300
f1_score: 0.8283
precision: 0.8300
recall: 0.8286
auc: 0.9676
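One detail of the model above that affects the loss figure: the network ends in a softmax, but the loss is built with from_logits=True, so Keras treats those probabilities as logits and applies softmax again. A minimal NumPy sketch of what that does to a 5-class cross-entropy (the helpers softmax and sparse_ce are hypothetical, written to mirror the idea of the from_logits flag, not Keras code): even a perfectly confident correct prediction cannot push the loss below about 0.905, and a uniform output gives ln 5 ≈ 1.609. That may help explain why the reported loss sits around 1.2 despite 80% accuracy.

```python
import numpy as np

def softmax(z):
    # Numerically stable softmax over the last axis
    z = z - np.max(z, axis=-1, keepdims=True)
    e = np.exp(z)
    return e / np.sum(e, axis=-1, keepdims=True)

def sparse_ce(outputs, label, from_logits):
    # Same idea as the from_logits flag: if True, softmax is applied to
    # `outputs` before taking the negative log-likelihood of `label`.
    probs = softmax(outputs) if from_logits else outputs
    return float(np.log(1.0 / probs[label]))

# Best case for a softmax output layer: all probability on the true class 0
perfect = np.array([1.0, 0.0, 0.0, 0.0, 0.0])
# Uninformative case: uniform probabilities over the 5 classes
uniform = np.full(5, 0.2)

print(f"{sparse_ce(perfect, 0, from_logits=False):.4f}")  # 0.0000
print(f"{sparse_ce(perfect, 0, from_logits=True):.4f}")   # 0.9048
print(f"{sparse_ce(uniform, 0, from_logits=True):.4f}")   # 1.6094
```

With double softmax the best achievable loss is ln(e + 4) - 1, so loss values cannot be compared directly with a correctly configured model.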

When I run the test data through the model (after restoring the weights with model.load_weights()) and convert the integer-encoded labels back to their original labels...

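import numpy as np

# test_dataset holds [input_ids, attention_mask, integer labels]
test_sample = [test_dataset[0], test_dataset[1], test_dataset[2]]
# Class with the highest predicted probability for each test row
predictions = np.argmax(model.predict(test_sample[:2]), axis=1)
# le is the fitted label encoder: map integer classes back to original labels
preds_inv = le.inverse_transform(predictions)
true_inv = le.inverse_transform(test_sample[2])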

... the confusion matrix is all over the place:

from sklearn.metrics import confusion_matrix

confusion_matrix(true_inv, preds_inv)

array([[ 967,  202,    7,  685, 1139],
       [ 474,  785,   27,  717,  997],
       [ 768,  372,   46, 1024,  790],
       [ 463,  426,   27, 1272,  812],
       [ 387,  224,   11,  643, 1735]])

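Interestingly, the third label is hardly ever predicted: its column sums to only 118 of the 15,000 test predictions. Overall, the diagonal holds 4,805 of 15,000 rows, so test accuracy is only about 32%, compared with 83% on the validation set.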
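The figures above can be re-derived from the matrix itself. A short NumPy sketch, assuming scikit-learn's confusion_matrix convention (rows are true labels, columns are predicted labels):

```python
import numpy as np

# Confusion matrix as reported above (rows: true labels, columns: predictions)
cm = np.array([
    [967, 202, 7, 685, 1139],
    [474, 785, 27, 717, 997],
    [768, 372, 46, 1024, 790],
    [463, 426, 27, 1272, 812],
    [387, 224, 11, 643, 1735],
])

total = cm.sum()
accuracy = np.trace(cm) / total   # correct predictions lie on the diagonal
support = cm.sum(axis=1)          # true examples per label
predicted = cm.sum(axis=0)        # how often each label was predicted
recall = np.diag(cm) / support    # per-label recall

print(total, round(accuracy, 4))  # 15000 0.3203
print(support)                    # [3000 3000 3000 3000 3000]
print(predicted)                  # [3059 2009  118 4341 5473]
print(np.round(recall, 3))        # [0.322 0.262 0.015 0.424 0.578]
```

Because every label has the same support, macro-averaged recall here equals plain accuracy, so the gap to the 0.83 validation figures is not a class-imbalance artefact.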

Note that trainable=False is set on the batch normalization layer in the code above, but it was set to True during training.

The input data consists of two arrays: the token IDs for each text string (the numerical representation that BERT maps to embeddings) and the attention mask, which identifies which of the 512 positions in each string are padding.
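To make the two input arrays concrete, here is a NumPy sketch of padding one tokenized text to 512 positions and building its attention mask. The token IDs are hypothetical; pad ID 0 is what bert-base-uncased uses for [PAD]. In practice the Hugging Face tokenizer produces both arrays (and truncates without dropping the final [SEP]), so the helper pad_and_mask is only an illustration of their layout.

```python
import numpy as np

MAX_LEN = 512
PAD_ID = 0  # [PAD] token ID in the bert-base-uncased vocabulary

def pad_and_mask(token_ids, max_len=MAX_LEN, pad_id=PAD_ID):
    # Hypothetical helper: naive truncation, then right-padding to max_len
    ids = list(token_ids)[:max_len]
    mask = [1] * len(ids)          # 1 = real token
    padding = max_len - len(ids)
    ids += [pad_id] * padding
    mask += [0] * padding          # 0 = padding, ignored by attention
    return np.array(ids, dtype=np.int32), np.array(mask, dtype=np.int32)

# Hypothetical token IDs for one short text (101 = [CLS], 102 = [SEP])
ids, mask = pad_and_mask([101, 2023, 2003, 1037, 3231, 102])
print(ids.shape, mask.shape)  # (512,) (512,)
print(mask[:8])               # [1 1 1 1 1 1 0 0]
```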

What could cause reasonable accuracy scores but terrible predictions on an evenly balanced dataset, with a model built on a deep pre-trained network (BERT)?

asked Sep 08 '25 at 09:09 by ML_Engine

1 Answer

In my specific case, I resolved this by investigating the content of the two labels causing the confusion, using the wordcloud library. The example below shows my code for one of the labels:

import matplotlib.pyplot as plt
from wordcloud import WordCloud

# Jupyter/IPython only: render plots inline in the notebook
%matplotlib inline

# All texts belonging to one of the two confused labels
df1 = df[df['label'] == 48000000]
text = " ".join(review for review in df1.text)

wordcloud = WordCloud().generate(text)
plt.imshow(wordcloud, interpolation='bilinear')
plt.axis("off")
plt.show()
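A text-only complement to the word cloud, sketched with collections.Counter on hypothetical sample texts (the label names and texts are made up): listing the most frequent words per label and intersecting them surfaces words that rank highly for several labels, which are the candidates for a domain-specific stopword list.

```python
from collections import Counter

# Hypothetical sample texts for two labels that get confused with each other
texts_by_label = {
    "label_a": ["cleaning service for schools", "cleaning service", "office cleaning"],
    "label_b": ["software service", "software support service", "service desk software"],
}

def top_words(texts, n=2):
    # Count lowercase whitespace-separated tokens across all texts of one label
    counts = Counter(w for t in texts for w in t.lower().split())
    return {w for w, _ in counts.most_common(n)}

top = {label: top_words(texts) for label, texts in texts_by_label.items()}
# Words that rank highly for every label do little to separate them
shared = set.intersection(*top.values())
print(top)
print(shared)  # {'service'}
```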


My understanding was that BERT should be able to recognise which words are important to a particular label (using something like TF-IDF? I'm not sure). However, when I removed stopwords using NLTK, and extended that list with words I consider generic to my specific dataset (in this case 'system', 'service', etc.), the accuracies went up significantly after retraining the model:

import nltk
nltk.download('stopwords')
from nltk.corpus import stopwords

# Generic words that appear across many labels in this dataset
new_stopwords = ['service', 'contract', 'solution', 'county', 'supplier',
                 'district', 'council', 'borough', 'management',
                 'provider', 'provision',
                 'project', 'contractor']

# Build the stopword set once instead of on every call
stop_words = set(stopwords.words('english'))
stop_words.update(new_stopwords)

def preprocess_text(sentence):
    # Convert to lowercase so tokens match the lowercase stopword list
    sentence = sentence.lower()
    # Drop stopwords and dataset-generic words, then rejoin
    words = [w for w in sentence.split() if w not in stop_words]
    return ' '.join(words)

df['text'] = df['text'].apply(preprocess_text)
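A stopword list like this one is easy to break without any error: Python joins adjacent string literals, so a missing comma silently merges two entries. A minimal sketch, using a tiny hypothetical stand-in for NLTK's English stopwords (so it runs without downloading anything) and a hypothetical helper remove_stopwords that mirrors the filtering step:

```python
# Adjacent string literals are concatenated, so a missing comma merges entries
words = ['provider', 'provision'
         'project', 'contractor']
print(words)  # ['provider', 'provisionproject', 'contractor']

# Hypothetical, tiny stand-in for NLTK's English stopwords plus domain words
basic = {"of", "the", "for", "a"}
stop_words = basic | {"service", "contract", "provision", "project"}

def remove_stopwords(sentence, stop_words):
    # Lowercase, split on whitespace, drop stopwords, rejoin
    return " ".join(w for w in sentence.lower().split() if w not in stop_words)

sample = "Provision of the cleaning service for a school project"
print(remove_stopwords(sample, stop_words))    # cleaning school
# With the merged entry, neither "provision" nor "project" is filtered out
print(remove_stopwords(sample, set(words) | basic))
# provision cleaning service school project
```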


answered Sep 10 '25 at 04:09 by ML_Engine