 

Keras model.summary does not correctly display my model?

I want to see my model's summary through Keras's model.summary, but it doesn't list all of my layers. My code is as below:

import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Conv2D, Dense, Flatten, BatchNormalization

class MyModel(Model):
    def __init__(self):
        super(MyModel, self).__init__()
        self.conv1 = Conv2D(32, 3, activation='relu')
        self.flatten = Flatten()
        self.d1 = Dense(128, activation='relu')
        self.d2 = Dense(10, activation='relu')

    def trythis(self,x):
        a = BatchNormalization()
        b = a(x)
        return b

    def call(self, x):
        x = self.conv1(x)
        x = self.trythis(x)
        x = self.flatten(x)
        x = self.d1(x)
        return self.d2(x)

model = MyModel()
model.build((None, 32,32,3))
model.summary()

I expected to see the BatchNorm layer, but the summary is as below:

Model: "my_model_30"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
conv2d_31 (Conv2D)           multiple                  896       
_________________________________________________________________
flatten_30 (Flatten)         multiple                  0         
_________________________________________________________________
dense_60 (Dense)             multiple                  3686528   
_________________________________________________________________
dense_61 (Dense)             multiple                  1290      
=================================================================
Total params: 3,688,714
Trainable params: 3,688,714
Non-trainable params: 0

It does not contain the BatchNorm layer created in the 'trythis' method.

How can I solve this problem?

Thank you for reading.

Asked Oct 16 '22 by Geonsu Kim


1 Answer

Shape inference for a subclassed model is not automatic as it is in the Functional API. So I added a model() method within the subclassed model that defines a functional model, as shown below. Please note that there are a couple of ways to do this, and what I am showing is one of them. Please check a similar question that I answered here for more details.

import tensorflow as tf
from tensorflow import keras

from tensorflow.keras.models import Model
from tensorflow.keras.layers import Conv2D, Dense, Flatten, BatchNormalization

class MyModel(Model):
    def __init__(self):
        super(MyModel, self).__init__()
        self.conv1 = Conv2D(32, 3, activation='relu')
        self.flatten = Flatten()
        self.d1 = Dense(128, activation='relu')
        self.d2 = Dense(10, activation='relu')

    def trythis(self,x):
        a = BatchNormalization()
        b = a(x)
        return b

    def call(self, x):
        x = self.conv1(x)
        x = self.trythis(x)
        x = self.flatten(x)
        x = self.d1(x)
        return self.d2(x)

    def model(self):
        x = tf.keras.layers.Input(shape=(32, 32, 3))
        return Model(inputs=[x], outputs=self.call(x))

model = MyModel()
model_functional = model.model()
#model.build((None, 32,32,3))
model_functional.summary()
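
Note that the model() method does not create any new layers: it wraps the layers already built in __init__ (plus the BatchNormalization created during call) into a functional Model, which lets Keras trace the graph and report concrete per-layer output shapes. The original subclassed instance remains usable for training as before.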

The summary is as follows:

Model: "model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_5 (InputLayer)         [(None, 32, 32, 3)]       0         
_________________________________________________________________
conv2d_5 (Conv2D)            (None, 30, 30, 32)        896       
_________________________________________________________________
batch_normalization (BatchNo (None, 30, 30, 32)        128       
_________________________________________________________________
flatten_4 (Flatten)          (None, 28800)             0         
_________________________________________________________________
dense_8 (Dense)              (None, 128)               3686528   
_________________________________________________________________
dense_9 (Dense)              (None, 10)                1290      
=================================================================
Total params: 3,688,842
Trainable params: 3,688,778
Non-trainable params: 64
_________________________________________________________________
Answered Nov 15 '22 by Vishnuvardhan Janapati
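
As a side note that is not part of the original answer: Keras only tracks layers that are assigned as attributes of the model (typically in __init__), so the BatchNormalization constructed inside trythis is never registered as a sublayer, which is why it does not appear in the subclassed model's summary (and a fresh layer would be created on every call). Below is a minimal sketch of that alternative, with an illustrative class name, where the layer is created once in __init__ so the subclassed model's own summary lists it (output shapes are still reported as "multiple"):

import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Conv2D, Dense, Flatten, BatchNormalization

# Illustrative variant (not from the answer above): the BatchNormalization
# layer is created once in __init__, so Keras tracks it as a sublayer.
class TrackedModel(Model):
    def __init__(self):
        super().__init__()
        self.conv1 = Conv2D(32, 3, activation='relu')
        self.bn = BatchNormalization()   # created once, tracked by Keras
        self.flatten = Flatten()
        self.d1 = Dense(128, activation='relu')
        self.d2 = Dense(10, activation='relu')

    def call(self, x):
        x = self.conv1(x)
        x = self.bn(x)                   # reuses the same layer on every call
        x = self.flatten(x)
        x = self.d1(x)
        return self.d2(x)

model = TrackedModel()
model.build((None, 32, 32, 3))
model.summary()   # now includes the batch_normalization layer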