 

How does TensorFlow's build() work in tf.keras.layers.Layer?

I was wondering if anyone knows how the build() method of the tf.keras.layers.Layer class works under the hood. According to the documentation:

build is called when you know the shapes of the input tensors and can do the rest of the initialization

So it seems to me that the class behaves something like this:

import tensorflow as tf

# My mental model of what happens when a layer is called
class MyDenseLayer(tf.keras.layers.Layer):
  def __init__(self, num_outputs):
    super().__init__()
    self.num_outputs = num_outputs

  def build(self, input_shape):
    self.kernel = self.add_weight(name="kernel",
                                  shape=[int(input_shape[-1]), self.num_outputs])

  def __call__(self, inputs):
    self.build(inputs.shape)  # build is called here, when the input shape is known
    return tf.matmul(inputs, self.kernel)

I can't imagine build() would be called on every __call__, but __call__ is the only place where the input is passed in. Does anyone know exactly how this works under the hood?

asked Aug 12 '20 at 19:08 by Jamie Dimon



1 Answer

The Layer.build() method is typically used to instantiate the weights of the layer. See the source code for tf.keras.layers.Dense for an example, and note that the kernel (weight) and bias tensors are created in that method. Layer.build() takes an input_shape argument, and the shapes of the weights and biases often depend on the shape of the input: for Dense, the kernel has shape (input_shape[-1], units), so it cannot be created until the size of the last input dimension is known.
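A quick way to see this with the built-in layer, assuming TensorFlow 2.x with tf.keras available, is to inspect a Dense layer before and after its first call. It has no weights until an input arrives, and the kernel's first dimension then matches the input's last dimension.

```python
import tensorflow as tf

# Dense knows its output size (units) up front, but not its input size.
layer = tf.keras.layers.Dense(units=3)
print(layer.built)         # False: build() has not run yet
print(len(layer.weights))  # 0: no kernel or bias exists yet

# The first call passes 2 samples with 5 features each, so build()
# receives an input shape of (2, 5) and creates the weights from it.
x = tf.ones((2, 5))
y = layer(x)

print(layer.built)         # True
print(layer.kernel.shape)  # (5, 3): first dimension comes from the input
print(layer.bias.shape)    # (3,)
print(y.shape)             # (2, 3)
```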

The Layer.call() method, on the other hand, implements the forward pass of the layer. You should not override __call__, because it is implemented in the base class tf.keras.layers.Layer; in a custom layer, implement call() instead.
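For comparison, here is a sketch of the question's MyDenseLayer written the conventional way, assuming TensorFlow 2.x: build() creates the kernel, call() holds the forward pass, and __call__ is inherited unchanged from tf.keras.layers.Layer. The "glorot_uniform" initializer is an illustrative choice, not something the question specifies.

```python
import tensorflow as tf


class MyDenseLayer(tf.keras.layers.Layer):
    """Dense layer without bias: weights in build(), forward pass in call()."""

    def __init__(self, num_outputs, **kwargs):
        super().__init__(**kwargs)
        self.num_outputs = num_outputs

    def build(self, input_shape):
        # Invoked by the inherited __call__ the first time the layer is used,
        # once the size of the last input dimension is known.
        self.kernel = self.add_weight(
            name="kernel",
            shape=(int(input_shape[-1]), self.num_outputs),
            initializer="glorot_uniform",
            trainable=True,
        )

    def call(self, inputs):
        # Forward pass only: no shape bookkeeping, no weight creation.
        return tf.matmul(inputs, self.kernel)


layer = MyDenseLayer(4)
y = layer(tf.ones((2, 3)))  # inherited __call__ builds the layer, then runs call()
print(y.shape)              # (2, 4)
print(layer.kernel.shape)   # (3, 4)
```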

Layer.call() does not call Layer.build(). However, Layer.__call__() does call it if the layer has not been built yet (source), and then sets self.built = True so that Layer.build() is not called again. In other words, Layer.__call__() calls Layer.build() only the first time the layer is called.
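To make the "build once, then call" dispatch concrete without depending on Keras internals, here is a hypothetical, heavily simplified model in plain Python. SketchLayer and ToyDense are made-up names, and the real Layer.__call__ does considerably more bookkeeping; only the lazy-build check on self.built is modeled here.

```python
class SketchLayer:
    """Hypothetical, simplified stand-in for tf.keras.layers.Layer (not the real source)."""

    def __init__(self):
        self.built = False

    def build(self, input_shape):
        # Subclasses create weights here; the base version only marks the layer as built.
        self.built = True

    def call(self, inputs):
        # Subclasses implement the forward pass here.
        raise NotImplementedError

    def __call__(self, inputs):
        if not self.built:
            # Only reached on the first call: the input shape is known now,
            # so weights can be created. Assumes a non-empty 2-D nested list.
            input_shape = (len(inputs), len(inputs[0]))
            self.build(input_shape)
            self.built = True  # guard in case a subclass skipped super().build()
        return self.call(inputs)


class ToyDense(SketchLayer):
    """Hypothetical dense layer without bias, using plain lists instead of tensors."""

    def __init__(self, units):
        super().__init__()
        self.units = units
        self.build_calls = 0  # counts how often build() runs

    def build(self, input_shape):
        self.build_calls += 1
        n_features = input_shape[-1]
        # All-ones kernel of shape (n_features, units) so the output is easy to check.
        self.kernel = [[1.0] * self.units for _ in range(n_features)]
        super().build(input_shape)

    def call(self, inputs):
        # (batch, n_features) x (n_features, units) -> (batch, units)
        return [
            [sum(row[i] * self.kernel[i][j] for i in range(len(row)))
             for j in range(self.units)]
            for row in inputs
        ]


layer = ToyDense(units=2)
x = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]
print(layer(x))           # [[6.0, 6.0], [15.0, 15.0]]  (build runs, then call)
print(layer(x))           # [[6.0, 6.0], [15.0, 15.0]]  (call only)
print(layer.build_calls)  # 1
```

The point of this lazy pattern is that a layer can be constructed without knowing its input size: the weights are sized to fit the first input the layer actually sees, and every later call skips straight to call().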

answered Nov 18 '22 at 13:11 by jakub
