I have trained a text classifier using Google's BERT via tf.keras.
My dataset consists of 50,000 rows of data, spread evenly over 5 labels. This is a subset of a much larger dataset, but I chose these specific labels because they are completely distinct from each other, to try to avoid confusion during training.
I create the data splits as follows:
train, test = train_test_split(df, test_size=0.30, shuffle=True, stratify=df['label'], random_state=10)
train, val = train_test_split(train, test_size=0.1, shuffle=True, stratify=train['label'], random_state=10)
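For reference, here is what those two splits produce on 50,000 evenly balanced rows. The DataFrame below is a synthetic stand-in (hypothetical text, integer labels 0 to 4), not the real data; it only shows the resulting sizes. With stratification, every split keeps exactly one fifth of its rows per label.

```python
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split

# Hypothetical stand-in for the real DataFrame: 50,000 rows, 5 evenly spread labels
df = pd.DataFrame({
    "text": [f"doc {i}" for i in range(50_000)],
    "label": np.repeat(np.arange(5), 10_000),
})

# Same two-stage split as above: 30% test, then 10% of the remainder for validation
train, test = train_test_split(df, test_size=0.30, shuffle=True,
                               stratify=df["label"], random_state=10)
train, val = train_test_split(train, test_size=0.1, shuffle=True,
                              stratify=train["label"], random_state=10)

print(len(train), len(val), len(test))
print(test["label"].value_counts().sort_index().tolist())
```

This yields 31,500 training, 3,500 validation and 15,000 test rows, with 3,000 test rows per label, which matches the row sums of the confusion matrix shown further down.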
The model design is:
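import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import Input, Dense, BatchNormalization, Activation
from tensorflow.keras.models import Model
from transformers import TFBertModel

def compile():
    mirrored_strategy = tf.distribute.MirroredStrategy()
    with mirrored_strategy.scope():
        learn_rate = 4e-5
        bert = 'bert-base-uncased'
        # Load pre-trained BERT; trainable=False is intended to freeze its weights
        model = TFBertModel.from_pretrained(bert, trainable=False)
        input_ids_layer = Input(shape=(512,), dtype=np.int32)
        input_mask_layer = Input(shape=(512,), dtype=np.int32)
        # [0] is BERT's last hidden state: one vector per token position
        bert_layer = model([input_ids_layer, input_mask_layer])[0]
        X = tf.keras.layers.GlobalMaxPool1D()(bert_layer)
        output = Dense(5)(X)
        # With trainable=False, TF 2.x Keras runs BatchNormalization in inference mode
        output = BatchNormalization(trainable=False)(output)
        output = Activation('softmax')(output)
        model_ = Model(inputs=[input_ids_layer, input_mask_layer], outputs=output)
        optimizer = tf.keras.optimizers.Adam(learn_rate)
        # The model already outputs softmax probabilities, so the loss must not
        # treat them as logits
        loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False)
        metric = tf.keras.metrics.SparseCategoricalAccuracy('accuracy')
        model_.compile(optimizer=optimizer, loss=loss, metrics=[metric])
    return model_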
This gives the following results:
loss: 1.2433
accuracy: 0.8024
val_loss: 1.2148
val_accuracy: 0.8300
f1_score: 0.8283
precision: 0.8300
recall: 0.8286
auc: 0.9676
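One thing worth checking alongside a loss of about 1.24 at 80% accuracy: if a softmax output is passed to a loss constructed with from_logits=True, the loss treats those probabilities as logits and applies a second softmax. The NumPy sketch below is independent of TensorFlow (the helper names are hypothetical and only mimic the two from_logits modes in spirit); it shows that, with 5 classes, even a perfectly confident and correct prediction then cannot push the loss below about 0.905.

```python
import numpy as np

def softmax(z):
    # Numerically stable softmax over the last axis
    z = z - z.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def sparse_ce(probs, labels, eps=1e-7):
    # Mean cross-entropy when inputs are already probabilities (like from_logits=False)
    p = np.clip(probs[np.arange(len(labels)), labels], eps, 1.0)
    return float(np.mean(np.log(1.0 / p)))

def sparse_ce_logits(logits, labels):
    # Mean cross-entropy when inputs are logits (like from_logits=True): softmax first
    return sparse_ce(softmax(logits), labels)

# A perfectly confident, correct 5-class prediction, as a softmax layer would output it
probs = np.array([[1.0, 0.0, 0.0, 0.0, 0.0]])
labels = np.array([0])

matched = sparse_ce(probs, labels)            # probabilities treated as probabilities
mismatched = sparse_ce_logits(probs, labels)  # probabilities treated as logits
print(f"matched: {matched:.4f}, mismatched: {mismatched:.4f}")
```

In general the floor is log(1 + (K - 1)/e) for K classes, and gradients are squashed the same way. This is one plausible contributor to the high training loss; on its own it does not explain the gap between validation and test results.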
When I run the test data through the model (after loading the trained weights with model.load_weights()) and convert the label-encoded predictions and true labels back to their original labels...
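# test_dataset[0] and [1] are the model inputs (token IDs, attention mask); [2] holds the encoded labels
test_sample = [test_dataset[0], test_dataset[1], test_dataset[2]]
# Predict on the two input arrays and take the most probable class for each row
predictions = tf.argmax(model.predict(test_sample[:2]), axis=1).numpy()
preds_inv = le.inverse_transform(predictions)
true_inv = le.inverse_transform(test_sample[2])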
... the confusion matrix is all over the place:
from sklearn.metrics import confusion_matrix
confusion_matrix(true_inv, preds_inv)
array([[ 967, 202, 7, 685, 1139],
[ 474, 785, 27, 717, 997],
[ 768, 372, 46, 1024, 790],
[ 463, 426, 27, 1272, 812],
[ 387, 224, 11, 643, 1735]])
Interestingly, the third label is hardly being predicted at all.
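The matrix can be summarised with a few lines of NumPy. Rows are true labels and columns are predicted labels (scikit-learn's confusion_matrix convention); the numbers are exactly as posted above.

```python
import numpy as np

# Confusion matrix as posted: rows = true label, columns = predicted label
cm = np.array([
    [967, 202, 7, 685, 1139],
    [474, 785, 27, 717, 997],
    [768, 372, 46, 1024, 790],
    [463, 426, 27, 1272, 812],
    [387, 224, 11, 643, 1735],
])

accuracy = np.trace(cm) / cm.sum()      # fraction of correct predictions
recall = np.diag(cm) / cm.sum(axis=1)   # recall for each true label
predicted_counts = cm.sum(axis=0)       # how often each label is predicted

print(f"test accuracy: {accuracy:.4f}")
print("per-label recall:", np.round(recall, 3))
print("predictions per label:", predicted_counts.tolist())
```

On these numbers the test accuracy is 4,805 / 15,000, roughly 0.32, far below the 0.83 validation accuracy and not far above the 0.20 chance level for five balanced labels. The third label is predicted only 118 times out of 15,000.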
Note that I set trainable=False on the batch normalization layer in the code above, but during training it was set to True.
The input data consists of two arrays: the token IDs representing each text string, and the attention mask identifying which of the 512 positions in each string are padding.
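To make the two input arrays concrete: they hold token IDs (BERT computes the embeddings internally) and an attention mask with 1 for real tokens and 0 for padding. In practice the Hugging Face tokenizer builds both, for example tokenizer(texts, padding="max_length", truncation=True, max_length=512) returns input_ids and attention_mask. The sketch below only illustrates the layout with hypothetical token IDs; pad_and_mask is a hypothetical helper that truncates naively and assumes padding ID 0 (the [PAD] ID in bert-base-uncased).

```python
import numpy as np

MAX_LEN = 512  # sequence length used by the model's Input layers
PAD_ID = 0     # assumption: [PAD] has ID 0, as in bert-base-uncased

def pad_and_mask(token_id_lists, max_len=MAX_LEN, pad_id=PAD_ID):
    """Pad/truncate each list of token IDs to max_len and build its attention mask."""
    n = len(token_id_lists)
    input_ids = np.full((n, max_len), pad_id, dtype=np.int32)
    attention_mask = np.zeros((n, max_len), dtype=np.int32)
    for i, ids in enumerate(token_id_lists):
        ids = ids[:max_len]                # naive truncation (a tokenizer keeps [SEP])
        input_ids[i, :len(ids)] = ids      # real tokens first, padding after
        attention_mask[i, :len(ids)] = 1   # 1 = real token, 0 = padding
    return input_ids, attention_mask

# Hypothetical token IDs for two short texts (not real tokenizer output)
ids, mask = pad_and_mask([[101, 2326, 3206, 102], [101, 2622, 102]])
print(ids.shape, mask.shape)
print(mask[:, :6])
```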
What could cause reasonable accuracy scores but terrible predictions on an evenly balanced dataset, with a deep pre-trained model (BERT)?
In my case, I resolved this by investigating the content of the two labels that were being confused, using the wordcloud library. Here is my code for one of the labels:
import matplotlib.pyplot as plt
from wordcloud import WordCloud

# Jupyter/IPython only: render plots inline in the notebook
%matplotlib inline

# Concatenate every text belonging to one label into a single string
df1 = df[df['label'] == 48000000]
text = " ".join(review for review in df1.text)

# Build and display the word cloud
wordcloud = WordCloud().generate(text)
plt.imshow(wordcloud, interpolation='bilinear')
plt.axis("off")
plt.show()
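A word cloud is a visual check; a text-only alternative is to list the words that rank highly in both confused labels, since those are candidates for label-agnostic noise. The helpers top_words and shared_top_words are hypothetical, and the toy texts below are made up; with the real data they would be called on df['text'] filtered by label, assuming the 'text' and 'label' columns used above.

```python
import re
from collections import Counter

def top_words(texts, n=20, stop_words=frozenset()):
    """Return the n most frequent lowercase words across texts, ignoring stop_words."""
    counts = Counter(
        w for t in texts for w in re.findall(r"[a-z]+", t.lower()) if w not in stop_words
    )
    return [w for w, _ in counts.most_common(n)]

def shared_top_words(texts_a, texts_b, n=20, stop_words=frozenset()):
    """Words appearing in the top n of both labels."""
    a = set(top_words(texts_a, n, stop_words))
    b = set(top_words(texts_b, n, stop_words))
    return sorted(a & b)

# Made-up texts standing in for two confused labels
label_a = ["Waste collection service for the council", "Council waste service contract"]
label_b = ["IT support service contract", "Software service for the council"]
print(shared_top_words(label_a, label_b, n=10, stop_words={"for", "the"}))
```

With real data this would be, for example, shared_top_words(df.loc[df['label'] == 48000000, 'text'], df.loc[df['label'] == OTHER_LABEL, 'text'], n=30), where OTHER_LABEL is a placeholder for the second confused label.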
Now, my understanding was that BERT should be able to recognise which words are important to a particular label (perhaps using something like TF-IDF? I'm not sure). However, when I removed stopwords using NLTK, and added to that list words I consider generic to my specific dataset (in this case 'system', 'service', etc.), the accuracies went up significantly after retraining the model:
import nltk
nltk.download('stopwords')
from nltk.corpus import stopwords

# Domain-specific words that appear across all labels
new_stopwords = ['service', 'contract', 'solution', 'county', 'supplier',
                 'district', 'council', 'borough', 'management',
                 'provider', 'provision',
                 'project', 'contractor']
# Build the stop-word set once rather than on every call
stop_words = set(stopwords.words('english'))
stop_words.update(new_stopwords)

def preprocess_text(sentence):
    # Convert to lowercase
    sentence = sentence.lower()
    # Split on whitespace and drop stop words
    words = [w for w in sentence.split() if w not in stop_words]
    return ' '.join(words)

df['text'] = df['text'].apply(preprocess_text)
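One detail worth knowing about splitting on whitespace (whether with split(" ") or split()): punctuation stays attached to words, so tokens like 'service,' or 'council.' are not matched by the stop-word list. The sketch below is an alternative, not the method used above: it extracts words with a regular expression, which also strips punctuation that BERT's tokenizer would otherwise see. STOP_WORDS is a small hypothetical subset standing in for the NLTK list plus the domain words.

```python
import re

# Hypothetical small stop-word set standing in for NLTK's list plus domain words
STOP_WORDS = {"the", "for", "and", "of", "service", "contract", "council"}

def preprocess_split(sentence, stop_words=STOP_WORDS):
    """Whitespace split: punctuation stays attached, so 'service,' survives."""
    return " ".join(w for w in sentence.lower().split() if w not in stop_words)

def preprocess_regex(sentence, stop_words=STOP_WORDS):
    """Regex word split: 'service,' becomes 'service' and is removed."""
    words = re.findall(r"\w+", sentence.lower())
    return " ".join(w for w in words if w not in stop_words)

example = "Waste collection service, for the council."
print(preprocess_split(example))
print(preprocess_regex(example))
```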