
Why is the call method never called directly in the Keras subclassing API, with the input passed to the object of the class instead?

When creating a model with the Keras subclassing API, we write a custom model class and define a method named call(self, x), mostly to implement the forward pass, which expects an input. However, this method is never called directly; instead of passing the input to call, we pass it to the object of the class, as in model(images).

How are we able to call this model object and pass it values when we haven't implemented the Python special method __call__ in the class?

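import tensorflow as tf
from tensorflow.keras import Model
from tensorflow.keras.layers import Conv2D, Dense, Flatten

class MyModel(Model):
  def __init__(self):
    super().__init__()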
    self.conv1 = Conv2D(32, 3, activation='relu')
    self.flatten = Flatten()
    self.d1 = Dense(128, activation='relu')
    self.d2 = Dense(10, activation='softmax')

  def call(self, x):
    x = self.conv1(x)
    x = self.flatten(x)
    x = self.d1(x)
    return self.d2(x)

# Create an instance of the model
model = MyModel()

Use tf.GradientTape to train the model:

@tf.function
def train_step(images, labels):
  with tf.GradientTape() as tape:
    predictions = model(images)
    loss = loss_object(labels, predictions)
  gradients = tape.gradient(loss, model.trainable_variables)
  optimizer.apply_gradients(zip(gradients, model.trainable_variables))

  train_loss(loss)
  train_accuracy(labels, predictions)

Shouldn't the input be passed like this instead?

model = MyModel()
model.call(images)
Asked by Atul on Dec 24 '19 at 20:12



2 Answers

The __call__ method is actually implemented in the Layer class. Network inherits from Layer, and Model inherits from Network:

class Layer(module.Module):
    def __call__(self, inputs, *args, **kwargs):
        ...

class Network(base_layer.Layer):
    ...

class Model(network.Network):
    ...

So MyModel inherits this __call__ method.

Additional info:

What we actually do is override the inherited call method. The inherited __call__ method then invokes our new call method, which is why we don't need to write model.call(). When we call the model instance, its inherited __call__ method runs automatically and in turn calls our own call method.
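
The same dispatch can be sketched in plain Python, without TensorFlow. The names BaseLayer and Doubler are hypothetical stand-ins for illustration only, and the real Keras __call__ does considerably more than this one-line delegation:

```python
class BaseLayer:
    """Hypothetical stand-in for the Keras Layer class: defines __call__ once."""

    def __call__(self, inputs, *args, **kwargs):
        # The real Keras __call__ does extra work here before delegating.
        return self.call(inputs, *args, **kwargs)

    def call(self, inputs):
        raise NotImplementedError("subclasses override call()")


class Doubler(BaseLayer):
    """Overrides only call(); __call__ is inherited from BaseLayer."""

    def call(self, inputs):
        return [2 * v for v in inputs]


layer = Doubler()
result = layer([1, 2, 3])  # layer(...) resolves to BaseLayer.__call__
print(result)
```

Doubler never defines __call__, yet its instances are callable, because Python looks the special method up on the class hierarchy and finds it on the base class.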

Answered by Geeocode on Oct 03 '22 at 22:10



Occam's razor says that the __call__ method is implemented in the Model class, so your subclass will inherit this method, which is why the call works. The __call__ in the Model class just forwards parameters to your call method and does some bookkeeping.
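
That bookkeeping is the practical reason to prefer model(images) over model.call(images). As a rough illustration, here is a simplified sketch in plain Python that assumes a layer which creates its weights lazily on first use. LazyLayer and Scale are hypothetical names, and the bookkeeping Keras actually performs covers more than this single build step:

```python
class LazyLayer:
    """Hypothetical base class: __call__ builds the layer once, then delegates."""

    def __init__(self):
        self.built = False

    def __call__(self, inputs):
        if not self.built:
            self.build(len(inputs))  # bookkeeping: create state on first use
            self.built = True
        return self.call(inputs)

    def build(self, input_size):
        pass


class Scale(LazyLayer):
    def build(self, input_size):
        self.weights = [0.5] * input_size

    def call(self, inputs):
        return [w * v for w, v in zip(self.weights, inputs)]


layer = Scale()

# Calling call() directly skips build(), so the weights do not exist yet.
try:
    layer.call([2.0, 4.0])
    direct_call_worked = True
except AttributeError:
    direct_call_worked = False

# Going through __call__ builds the layer first, then runs call().
output = layer([2.0, 4.0])
print(direct_call_worked, output)
```

In this sketch the direct call to call() fails because nothing has created the weights, while calling the object itself triggers the setup and then the forward pass.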

Answered by Dr. Snoopy on Oct 03 '22 at 23:10
