I want to implement a Keras custom layer without any input, just trainable weights.
Here is the code so far:
class Simple(Layer):
    def __init__(self, output_dim, **kwargs):
        self.output_dim = output_dim
        super(Simple, self).__init__(**kwargs)
    def build(self):
        self.kernel = self.add_weight(name='kernel', shape=self.output_dim, initializer='uniform', trainable=True)
        super(Simple, self).build()
    def call(self):
        return self.kernel
    def compute_output_shape(self):
        return self.output_dim
X = Simple((1, 784))()
I am getting an error message:
__call__() missing 1 required positional argument: 'inputs'
Is there a workaround for building a custom layer without inputs in Keras?
If you are building a new model architecture from existing Keras/TF layers, build a custom model. If you are implementing your own custom tensor operations within a layer, build a custom layer.
Implement a custom layer by subclassing the Layer class and implementing three methods: __init__, where you do all input-independent initialization; build, where you know the shapes of the input tensors and can do the rest of the initialization; and call, where you do the forward computation.
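As a rough illustration of that split, here is a minimal sketch of a dense-style layer, assuming TensorFlow 2.x with tf.keras. The class name `Linear` and its `units` argument are hypothetical, chosen only for this example. The weight shape depends on the last input dimension, which is only known once the layer sees an input, so the weights are created in build rather than in __init__.

```python
import numpy as np
import tensorflow as tf


class Linear(tf.keras.layers.Layer):
    """Hypothetical example layer computing y = x @ w + b."""

    def __init__(self, units, **kwargs):
        # Input-independent setup: only store the configuration.
        super().__init__(**kwargs)
        self.units = units

    def build(self, input_shape):
        # The input's last dimension is known here, so the weights are created now.
        self.w = self.add_weight(
            name="w", shape=(input_shape[-1], self.units),
            initializer="random_normal", trainable=True)
        self.b = self.add_weight(
            name="b", shape=(self.units,),
            initializer="zeros", trainable=True)

    def call(self, inputs):
        # Forward computation.
        return tf.matmul(inputs, self.w) + self.b


layer = Linear(4)
y = layer(np.ones((2, 3), dtype="float32"))  # first call triggers build()
print(y.shape)  # (2, 4)
```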
In model subclassing, the two most important methods are __init__ and call. You define all the trainable tf.keras layers (built-in or custom) in __init__, and then call those layers, in the order your network design requires, inside call, which performs the forward propagation.
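A minimal sketch of that pattern, again assuming TensorFlow 2.x; `TwoLayerNet` and its layer sizes are made up for illustration. The layers are created once in __init__ and only wired together in call.

```python
import numpy as np
import tensorflow as tf


class TwoLayerNet(tf.keras.Model):
    """Hypothetical model illustrating the __init__/call split."""

    def __init__(self, hidden_units=16, output_units=3, **kwargs):
        super().__init__(**kwargs)
        # Layers (built-in or custom) are created here, once.
        self.dense1 = tf.keras.layers.Dense(hidden_units, activation="relu")
        self.dense2 = tf.keras.layers.Dense(output_units)

    def call(self, inputs):
        # Forward pass: apply the layers in the order the design requires.
        x = self.dense1(inputs)
        return self.dense2(x)


model = TwoLayerNet()
y = model(np.zeros((5, 8), dtype="float32"))
print(y.shape)  # (5, 3)
```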
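For your case, one workaround is to give build and call the parameters Keras expects and then call the layer with an empty list as a dummy input: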
from tensorflow.keras.layers import Layer

class Simple(Layer):
    def __init__(self, output_dim, **kwargs):
        super(Simple, self).__init__(**kwargs)
        self.output_dim = output_dim

    def build(self, input_shapes):
        # The weight shape is fixed by output_dim, so input_shapes is not used.
        self.kernel = self.add_weight(name='kernel', shape=self.output_dim, initializer='uniform', trainable=True)
        super(Simple, self).build(input_shapes)

    def call(self, inputs):
        # The dummy input is ignored; the layer just returns its weights.
        return self.kernel

    def compute_output_shape(self, input_shape):
        return self.output_dim

X = Simple((1, 784))([])
print(X.shape)
which prints:
(1, 784)
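If you later need such a layer inside a larger model, a variant that can be easier to wire up is to pass the real model input and use it only for its batch size, tiling the weights so the output has one row per example. This is a sketch under that assumption (TensorFlow 2.x), not the method shown above; `LearnedConstant` is a hypothetical name.

```python
import numpy as np
import tensorflow as tf


class LearnedConstant(tf.keras.layers.Layer):
    """Hypothetical layer: a trainable vector, repeated once per batch row."""

    def __init__(self, output_dim, **kwargs):
        super().__init__(**kwargs)
        self.output_dim = output_dim

    def build(self, input_shape):
        # The weight shape depends only on output_dim, not on input_shape.
        self.kernel = self.add_weight(
            name="kernel", shape=(1, self.output_dim),
            initializer="random_uniform", trainable=True)

    def call(self, inputs):
        # Only the batch size of `inputs` is used; its values are ignored.
        batch_size = tf.shape(inputs)[0]
        return tf.tile(self.kernel, [batch_size, 1])


layer = LearnedConstant(784)
out = layer(np.zeros((3, 10), dtype="float32"))
print(out.shape)  # (3, 784)
```

Every row of `out` is the same trainable vector, so gradients from all rows flow back into the single `kernel` weight.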