I want to see my model's summary with Keras's `model.summary()`, but it does not list all the layers I expect. My code is below:
class MyModel(Model):
    """Conv2D -> BatchNormalization -> Flatten -> Dense(128) -> Dense(10) model.

    Only the layers assigned to attributes in ``__init__`` (conv1, flatten,
    d1, d2) are listed by ``model.summary()``; see ``trythis`` for why the
    BatchNormalization layer is not.
    """

    def __init__(self):
        super(MyModel, self).__init__()
        self.conv1 = Conv2D(32, 3, activation='relu')
        self.flatten = Flatten()
        self.d1 = Dense(128, activation='relu')
        self.d2 = Dense(10, activation='relu')

    def trythis(self, x):
        """Apply a freshly created BatchNormalization layer to ``x``.

        NOTE: the layer is a local variable created on every call and is never
        assigned to an attribute of the model, so it is not part of the layers
        the model reports -- this is why it is missing from ``model.summary()``.
        """
        a = BatchNormalization()
        b = a(x)
        return b

    def call(self, x):
        """Run the forward pass on ``x``."""
        x = self.conv1(x)
        # ``trythis`` is a method, so it must be called through ``self``.
        x = self.trythis(x)
        x = self.flatten(x)
        x = self.d1(x)
        return self.d2(x)
model = MyModel()
# Create the weights for inputs of shape 32x32x3 (batch dimension left as None).
model.build(input_shape=(None, 32, 32, 3))
model.summary()
I expected to see a BatchNormalization layer, but the summary is:
Model: "my_model_30"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
conv2d_31 (Conv2D) multiple 896
_________________________________________________________________
flatten_30 (Flatten) multiple 0
_________________________________________________________________
dense_60 (Dense) multiple 3686528
_________________________________________________________________
dense_61 (Dense) multiple 1290
=================================================================
Total params: 3,688,714
Trainable params: 3,688,714
Non-trainable params: 0
It does not contain the BatchNormalization layer created in the `trythis` method.
How can I solve this problem?
Thank you for reading.
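Shape inference for a subclassed model is not automatic as it is with the Functional API, which is why the first summary shows `multiple` instead of concrete output shapes. So I added a method that calls the subclassed model on a symbolic `Input` and wraps the result in a functional `Model`, as shown below. Separately, the BatchNormalization layer is missing from the first summary because it is created as a local object inside `trythis` rather than assigned to an attribute of the model, so Keras does not track it. Defining it once in `__init__` (for example `self.bn = BatchNormalization()`) gets it tracked and listed in the summary, and the same weights are reused on every call. Please note that there are a couple of ways to do this, and what I am showing is one of them. For more details, see a similar question that I answered here.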
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Conv2D, Dense, Flatten, BatchNormalization
class MyModel(Model):
    """Conv2D -> BatchNormalization -> Flatten -> Dense(128) -> Dense(10) model.

    Every layer, including the BatchNormalization used by ``trythis``, is
    created once in ``__init__`` and assigned to an attribute. Keras therefore
    tracks it: it is listed by ``summary()`` and its weights are reused on
    every call instead of being recreated each time.
    """

    def __init__(self):
        super(MyModel, self).__init__()
        self.conv1 = Conv2D(32, 3, activation='relu')
        # Created here (not inside ``trythis``) so the model tracks it.
        self.bn = BatchNormalization()
        self.flatten = Flatten()
        self.d1 = Dense(128, activation='relu')
        self.d2 = Dense(10, activation='relu')

    def trythis(self, x):
        """Return ``x`` passed through the model's BatchNormalization layer."""
        return self.bn(x)

    def call(self, x):
        """Run the forward pass on ``x``."""
        x = self.conv1(x)
        x = self.trythis(x)
        x = self.flatten(x)
        x = self.d1(x)
        return self.d2(x)

    def model(self, input_shape=(32, 32, 3)):
        """Wrap this model's forward pass in a functional ``Model``.

        Calling the layers on a symbolic ``Input`` lets the functional model
        report concrete per-layer output shapes in ``summary()``.

        Args:
            input_shape: Shape of one input sample, excluding the batch
                dimension. Defaults to ``(32, 32, 3)``.

        Returns:
            A functional ``Model`` that shares this model's layer objects
            (and therefore their weights).
        """
        x = tf.keras.layers.Input(shape=input_shape)
        return Model(inputs=[x], outputs=self.call(x))
model = MyModel()
# No explicit ``model.build(...)`` is needed: ``model.model()`` calls every
# layer on a symbolic Input, which builds them with concrete shapes.
model_functional = model.model()
model_functional.summary()
The summary is as follows:
Model: "model"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
input_5 (InputLayer) [(None, 32, 32, 3)] 0
_________________________________________________________________
conv2d_5 (Conv2D) (None, 30, 30, 32) 896
_________________________________________________________________
batch_normalization (BatchNo (None, 30, 30, 32) 128
_________________________________________________________________
flatten_4 (Flatten) (None, 28800) 0
_________________________________________________________________
dense_8 (Dense) (None, 128) 3686528
_________________________________________________________________
dense_9 (Dense) (None, 10) 1290
=================================================================
Total params: 3,688,842
Trainable params: 3,688,778
Non-trainable params: 64
_________________________________________________________________
If you love us? You can donate to us via Paypal or buy me a coffee so we can maintain and grow! Thank you!
Donate Us With