When creating a model with the Keras subclassing API, we write a custom model class and define a method named call(self, x) (mostly to write the forward pass), which expects an input. However, this method is never called directly: instead of passing the input to call, it is passed to the instance of this class, as in model(images).
How are we able to call this model object and pass values when we haven't implemented the Python special method __call__ in the class?
from tensorflow.keras import Model
from tensorflow.keras.layers import Conv2D, Dense, Flatten

class MyModel(Model):
    def __init__(self):
        super(MyModel, self).__init__()
        self.conv1 = Conv2D(32, 3, activation='relu')
        self.flatten = Flatten()
        self.d1 = Dense(128, activation='relu')
        self.d2 = Dense(10, activation='softmax')

    def call(self, x):
        # Forward pass: convolution, flatten, then two dense layers
        x = self.conv1(x)
        x = self.flatten(x)
        x = self.d1(x)
        return self.d2(x)

# Create an instance of the model
model = MyModel()
Use tf.GradientTape to train the model:
# loss_object, optimizer, train_loss and train_accuracy are assumed to be
# defined elsewhere (a Keras loss, an optimizer and two metric objects).
@tf.function
def train_step(images, labels):
    with tf.GradientTape() as tape:
        predictions = model(images)
        loss = loss_object(labels, predictions)
    gradients = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))
    train_loss(loss)
    train_accuracy(labels, predictions)
Shouldn't the input be passed like this instead?
model = MyModel()
model.call(images)
Actually, the __call__ method is implemented in the Layer class, which is inherited by the Network class, which is in turn inherited by the Model class:
class Layer(module.Module):
    def __call__(self, inputs, *args, **kwargs):
        ...
class Network(base_layer.Layer):
    ...
class Model(network.Network):
    ...
So MyModel will inherit this __call__ method.
Additional info:
What we are actually doing is overriding the inherited call method; the inherited __call__ method then invokes our new call method. That is why we don't need to call model.call() ourselves.
So when we call our model instance, its inherited __call__ method is executed automatically, which in turn calls our own call method.
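The same dispatch can be reproduced in plain Python without TensorFlow. This is only a minimal sketch: the Layer and Model classes below are toy stand-ins written for illustration, not the real Keras classes, and the doubling logic in call is made up.

```python
class Layer:
    """Toy stand-in for the Keras base class (not the real implementation)."""

    def __call__(self, inputs, *args, **kwargs):
        # The base class owns __call__ and delegates to call().
        return self.call(inputs, *args, **kwargs)

    def call(self, inputs, *args, **kwargs):
        raise NotImplementedError("Subclasses must override call().")


class Model(Layer):
    """Toy stand-in: adds nothing, just sits in the inheritance chain."""


class MyModel(Model):
    def call(self, x):
        # Hypothetical forward pass: double every value.
        return [2 * v for v in x]


model = MyModel()
result = model([1, 2, 3])  # goes through Layer.__call__, which runs MyModel.call

print(result)                              # [2, 4, 6]
print("__call__" in vars(MyModel))         # False: MyModel does not define it
print(MyModel.__call__ is Layer.__call__)  # True: it is inherited from Layer
```

MyModel never defines __call__, yet its instances are callable because attribute lookup walks up the class hierarchy until it finds the method on Layer.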
Occam's razor says that the __call__ method is implemented in the Model class (or one of its base classes), so your subclass inherits it, which is why the call works. The __call__ in the Model class just forwards its arguments to your call method and does some bookkeeping.
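To give a feel for what "bookkeeping" might mean, here is a rough sketch in plain Python. ToyLayer, Scale and the build step are hypothetical names invented for this example; the real Keras __call__ does considerably more, so treat this as an analogy rather than its actual logic.

```python
class ToyLayer:
    """Toy base class whose __call__ does one-time setup before call()."""

    def __init__(self):
        self.built = False

    def build(self, inputs):
        # Hypothetical setup step: derive some state from the first input.
        self.scale = len(inputs)
        self.built = True

    def __call__(self, inputs):
        # Bookkeeping: make sure setup has happened, then forward to call().
        if not self.built:
            self.build(inputs)
        return self.call(inputs)

    def call(self, inputs):
        raise NotImplementedError("Subclasses must override call().")


class Scale(ToyLayer):
    def call(self, inputs):
        # Relies on state created by build().
        return [self.scale * v for v in inputs]


layer = Scale()
via_dunder = layer([1, 2, 3])  # build() runs first, then call()

fresh = Scale()
try:
    fresh.call([1, 2, 3])      # skips __call__, so build() never runs
    direct_call_failed = False
except AttributeError:
    direct_call_failed = True

print(via_dunder)          # [3, 6, 9]
print(direct_call_failed)  # True
```

In this toy version, calling call directly bypasses the setup that __call__ performs, which suggests why model(images) is the usual way to invoke a model rather than model.call(images).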