
CNN training accuracy stagnates with BatchNorm, quickly overfits without

I have 2 types of grayscale images, let's say a car and a plane. In my training set, I have 1000 images (about a 50/50 split). In this training set, all of my plane examples are on a white background, whereas all of my car examples are on a black background (this is done on purpose, to test whether the model ultimately learns to differentiate between a car and a plane, not their backgrounds).

As a simple proof that a model will quickly overfit to the backgrounds, I created a CNN. However, I'm running into 2 weird scenarios:

  1. If I add BatchNorm anywhere between a conv layer and another layer, my training accuracy seems to hover around 50% and can't improve.

  2. If I remove BatchNorm, my training accuracy quickly skyrockets to around 98%. Even though my validation set is split off from the training data (and thus also has the black/white background issue), my validation accuracy hovers around 50%. I would expect the training overfit to be driven by the black and white backgrounds, which the validation set shares, so the model should be able to exploit them there too.

I've attached my code. I get the data as a 1x4096 vector, so I reshape it into a 64x64 image. When I uncomment any of the BatchNorm steps in my code below, the training accuracy seems to hover around 50%.

# Module-level imports assumed by this snippet (exact packages may differ
# depending on your Keras/TensorFlow version):
# from keras.models import Sequential
# from keras.layers import Conv2D, MaxPooling2D, Dropout, Flatten, Dense, BatchNormalization
# from keras.utils import to_categorical
# from keras.callbacks import EarlyStopping
# from keras import regularizers
# from sklearn.model_selection import train_test_split

# Normalize training data to [0, 1]
self.x = self.x.astype('float32')
self.x /= 255

numSamples = self.x.shape[0]
# Reconstruct the 1x4096 vectors into 64x64 single-channel images
width = 64
height = 64
xInput = self.x.reshape(numSamples, 1, height, width)

yCat = to_categorical(labels, 2)

# Split off a validation set (drawn from the training data, so it
# shares the black/white background artifact)
X_train, X_test, y_train, y_test = train_test_split(
    xInput, yCat, test_size=0.3, random_state=0)

# Construct model
# Note: data_format='channels_first' is only set on the first layer;
# the later pooling/conv layers fall back to the backend default.
self.model = Sequential()
self.model.add(Conv2D(64, kernel_size=(3, 3), strides=(1, 1),
                      activation='relu',
                      input_shape=(1, 64, 64), data_format='channels_first',
                      activity_regularizer=regularizers.l1(0.01)))
#self.model.add(BatchNormalization())
self.model.add(MaxPooling2D((2, 2)))
self.model.add(Dropout(0.5, noise_shape=None))
self.model.add(Conv2D(128, kernel_size=(3, 3), strides=(1, 1), activation='relu'))
#self.model.add(BatchNormalization())
self.model.add(MaxPooling2D((2, 2)))
self.model.add(Dropout(0.5, noise_shape=None))
self.model.add(Conv2D(256, kernel_size=(3, 3), strides=(1, 1), activation='relu'))
#self.model.add(BatchNormalization())
self.model.add(MaxPooling2D((2, 2)))
self.model.add(Dropout(0.5, noise_shape=None))
self.model.add(Flatten())
self.model.add(Dense(1000, activation='relu', activity_regularizer=regularizers.l2(0.01)))
self.model.add(BatchNormalization())
self.model.add(Dropout(0.5, noise_shape=None))
self.model.add(Dense(units=2, activation='softmax', kernel_initializer='lecun_normal'))

self.model.compile(loss='categorical_crossentropy',
                   optimizer='adam',
                   metrics=['accuracy'])

self.model.fit(X_train, y_train,
               batch_size=32,
               epochs=25,
               verbose=2,
               validation_data=(X_test, y_test),
               callbacks=[EarlyStopping(monitor='val_acc', patience=5)])
asked Nov 08 '22 by Kevin


1 Answer

I think there are a number of potential improvements to the architecture of your ANN, as well as one fundamental problem.

The fundamental challenge is the way your training set has been built: black & white backgrounds. If the intention is that the background should not play a role, why not make all of them white, or all of them black? Mind that an ANN, like almost any machine learning algorithm, will attempt to find whatever differentiates your classes, and in this case that will simply be the background. Why look at the tiny details of a car vs. an airplane when the background provides such a clear and rewarding differentiation?

Solution: make the background uniform for both classes. Then your ANN will be oblivious to it.
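For instance, if the backgrounds really are near-pure white or black, a minimal preprocessing sketch might look like the following. The helper name and the threshold are assumptions for illustration only; fixing the data at collection time is cleaner:

import numpy as np

def make_background_uniform(images, threshold=0.9):
    # Hypothetical helper: assumes `images` is a float array in [0, 1]
    # and that background pixels are the near-white ones. Forces them
    # to black so both classes end up on the same (black) background.
    out = images.copy()
    out[out > threshold] = 0.0
    return out

# e.g. xInput = make_background_uniform(xInput) before training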

Why was Batch Norm messing up training accuracy? As you noted yourself, test accuracy was still poor either way. Batch Norm was fixing the covariate shift problem, and that "problem" was what manifested as seemingly great training accuracy together with poor test accuracy. There is a great video on Batch Normalisation, with a piece on covariate shift, from Andrew Ng here.

Fixing the training set should fix the issue. Some other things:

  • At the very end you have 2 dense units, but your classification is binary. Make it a single unit with a sigmoid activation.
  • As pointed out by @Upasana Mittal, replace categorical_crossentropy with binary_crossentropy.
  • Consider using smaller dropout rates. Mind that you don't have that much data, so always discarding half of it is costly. Increase dropout only once you have evidence of overfitting.
  • Using Conv2D with strides can be better than simple max pooling.
  • You have a lot of filters for a problem that does not seem to be all that complicated. Consider a severe reduction in the number of filters, and increase it only when you see that the ANN does not have enough capacity to learn. You have only 2 classes here, and the features differentiating a car from a plane are not that subtle.
  • Consider using a smaller number of layers, for the same reason.
  • Using at least 2 stacked 3x3 Conv2D layers can yield better results (a sketch combining these points follows below).
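Putting these points together, here is a minimal sketch of a revised model. The filter counts, dense width, and dropout rate are illustrative assumptions, not tuned values; note that with a single sigmoid output, the labels should be a plain 0/1 vector rather than one-hot encoded:

from keras.models import Sequential
from keras.layers import Conv2D, Flatten, Dense, Dropout

model = Sequential()
# Two stacked 3x3 convs with far fewer filters; strided convs
# downsample in place of separate max-pooling layers.
model.add(Conv2D(16, (3, 3), activation='relu',
                 input_shape=(1, 64, 64), data_format='channels_first'))
model.add(Conv2D(16, (3, 3), strides=(2, 2), activation='relu',
                 data_format='channels_first'))
model.add(Conv2D(32, (3, 3), strides=(2, 2), activation='relu',
                 data_format='channels_first'))
model.add(Flatten())
model.add(Dense(64, activation='relu'))
model.add(Dropout(0.2))  # modest dropout; raise only if overfitting shows up
model.add(Dense(1, activation='sigmoid'))  # single unit for binary output

model.compile(loss='binary_crossentropy',
              optimizer='adam',
              metrics=['accuracy'])

# y_train / y_test should be 0/1 vectors here (no to_categorical needed).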
answered Nov 13 '22 by Lukasz Tracewski