I am trying to implement a simple model subclassing inspired by the VGG network.
So here is the code:
class ConvMax(tf.keras.Model):
def __init__(self, filters=4, kernel_size=3, pool_size=2, activation='relu'):
super(ConvMax, self).__init__()
self.conv = tf.keras.layers.Conv2D(filters, kernel_size, padding='same', activation=activation)
self.maxpool = tf.keras.layers.MaxPool2D((pool_size, pool_size))
def call(self, input_tensor):
x = self.conv(input_tensor)
x = self.maxpool(x)
return x
class RepeatedConvMax(tf.keras.Model):
def __init__(self, repetitions=4, filters=4, kernel_size=3, pool_size=2, activation='relu', **kwargs):
super(RepeatedConvMax, self).__init__(**kwargs)
self.repetitions = repetitions
self.filters = filters
self.kernel_size = kernel_size
self.pool_size = pool_size
self.activation = activation
# Define a repeated ConvMax
for i in range(self.repetitions):
# Define a ConvMax layer, specifying filters, kernel_size, pool_size.
vars(self)[f'convMax_{i}'] = ConvMax(self.filters, self.kernel_size, self.pool_size, self.activation)
def call(self, input_tensor):
# Connect the first layer
x = vars(self)['convMax_0'](input_tensor)
# Connect the existing layers
for i in range(1, self.repetitions):
x = vars(self)[f'convMax_{i}'](x)
# return the last layer
return x
But when I am trying to build the network to see the summaries, here is what I found:
model_input = tf.keras.layers.Input(shape=(64,64,3,), name="input_layer")
x = RepeatedConvMax()(model_input)
model = tf.keras.Model(inputs=model_input, outputs=x)
Model: "model_1"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
input_layer (InputLayer) [(None, 64, 64, 3)] 0
_________________________________________________________________
repeated_conv_max (RepeatedC (None, 4, 4, 4) 0
=================================================================
Total params: 0
Trainable params: 0
Non-trainable params: 0
_________________________________________________________________
The total params are zero
However, when I try:
model_input = tf.keras.layers.Input(shape=(64,64,3,), name="input_layer")
x = ConvMax()(model_input)
x = ConvMax()(x)
x = ConvMax()(x)
x = ConvMax()(x)
model = tf.keras.Model(inputs=model_input, outputs=x)
model.summary()
Model: "model"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
input_layer (InputLayer) [(None, 64, 64, 3)] 0
_________________________________________________________________
conv_max (ConvMax) (None, 32, 32, 4) 112
_________________________________________________________________
conv_max_1 (ConvMax) (None, 16, 16, 4) 148
_________________________________________________________________
conv_max_2 (ConvMax) (None, 8, 8, 4) 148
_________________________________________________________________
conv_max_3 (ConvMax) (None, 4, 4, 4) 148
=================================================================
Total params: 556
Trainable params: 556
Non-trainable params: 0
_________________________________________________________________
It shows correct the total params.
Do you know what is the problem? Why on the two-level subclassing, the parameter is 0? Will it affect the training?
Thank you...
The problem is not with keras but in the way you are initializing the layers in RepeatedConvMax.
TLDR: don't use vars to dinamically instantiate and retrieve attributes, instead use setattr and getattr
To solve the problem, you simply have to replace vars[] with setattr and getattr. From my (very limited, I actually found this out right now while looking for a solution) understanding, when you call vars you are working on a copy of the dictionary representing your object. When you create attributes dynamically in this way, Keras is not able to add the weights to the model (why is that, I don't yet know, but I will find out and update the answer when I do).
If you define your class like this, everything works as expected:
class RepeatedConvMax(tf.keras.Model):
def __init__(self, repetitions=4, filters=4, kernel_size=3, pool_size=2, activation='relu', **kwargs):
super(RepeatedConvMax, self).__init__(**kwargs)
self.repetitions = repetitions
self.filters = filters
self.kernel_size = kernel_size
self.pool_size = pool_size
self.activation = activation
# Define a repeated ConvMax
for i in range(self.repetitions):
# Define a ConvMax layer, specifying filters, kernel_size, pool_size.
setattr(self, f'convMax_{i}', ConvMax(self.filters, self.kernel_size, self.pool_size, self.activation))
def call(self, input_tensor, training=None, mask=None):
# Connect the first layer
x = getattr(self, 'convMax_0')(input_tensor)
# Connect the existing layers
for i in range(1, self.repetitions):
print(f"Layer {i}")
x = getattr(self, f'convMax_{i}')(x)
print(x)
# return the last layer
return x
Don't add your layers using vars, loop over the amount of layers you want and add them to a tf.keras.Sequential object, and do your forward pass through that.
Refactored Class:
class RefactoredRepeatedConvMax(tf.keras.models.Model):
def __init__(self,
repetitions=4,
filters=4,
kernel_size=3,
pool_size=2,
activation="relu"):
super().__init__()
self.repetitions = repetitions
self.filters = filters
self.kernel_size = kernel_size
self.pool_size = pool_size
self.activation = activation
self.conv_layers = tf.keras.Sequential()
for _ in tf.range(self.repetitions):
self.conv_layers.add(ConvMax(
self.filters,
self.kernel_size,
self.pool_size,
self.activation))
def call(self, x):
return self.conv_layers(x)
Model:
model_input = tf.keras.layers.Input(shape=(64, 64, 3), name="input_layer")
x = RefactoredRepeatedConvMax()(model_input)
model = tf.keras.Model(inputs=model_input, outputs=x)
model.summary()
# Model: "model"
# _________________________________________________________________
# Layer (type) Output Shape Param #
# =================================================================
# input_layer (InputLayer) [(None, 64, 64, 3)] 0
# _________________________________________________________________
# refactored_repeated_conv_max (None, 4, 4, 4) 556
# =================================================================
# Total params: 556
# Trainable params: 556
# Non-trainable params: 0
# _________________________________________________________________
If you love us? You can donate to us via Paypal or buy me a coffee so we can maintain and grow! Thank you!
Donate Us With