I'm trying to use multiple inputs in custom layers in TensorFlow Keras. The usage could be anything; right now it is defined as multiplying a mask with an image. I've searched SO, and the only answer I could find was for TF 1.x, so it didn't help.
class mul(layers.Layer):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # I've added pass because this is the simplest form I can come up with.
        pass

    def call(self, inputs):
        # magic happens here and multiplications occur
        return Z
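Since the use case here is multiplying a mask with an image, this is a minimal sketch of that case using a list input. It assumes a channels-last image and a single-channel mask that broadcasts across the image channels; the layer name ApplyMask and the toy arrays are hypothetical, for illustration only.

```python
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers


class ApplyMask(layers.Layer):
    """Hypothetical example layer: multiplies an image by a mask."""

    def call(self, inputs):
        image, mask = inputs  # expects a list/tuple: [image, mask]
        # The mask has a single channel, so it broadcasts over the image channels.
        return image * mask


image = np.ones((1, 4, 4, 3), dtype="float32")  # batch of one 4x4 RGB image
mask = np.zeros((1, 4, 4, 1), dtype="float32")  # single-channel mask
mask[:, :2, :, :] = 1.0                         # keep only the top two rows

masked = np.asarray(ApplyMask()([tf.constant(image), tf.constant(mask)]))
print(masked.shape)  # (1, 4, 4, 3)
```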
EDIT: Since TensorFlow v2.3/2.4, the contract is to pass a list of inputs to the call method. For standalone keras (not tf.keras), I think the answer below still applies.
Multiple inputs are implemented in the call method of your class. There are two alternatives:
List input: the inputs parameter is expected to be a list containing all the inputs. The advantage is that the number of inputs can vary. You can index the list, or unpack it with the = operator:
def call(self, inputs):
    Z = inputs[0] * inputs[1]
    # Alternate
    input1, input2 = inputs
    Z = input1 * input2
    return Z
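As a sketch of the variable-size advantage of list input, the same call method can multiply however many tensors are in the list, here with functools.reduce. The layer name MultiplyAll is hypothetical.

```python
import functools
import operator

import numpy as np
import tensorflow as tf
from tensorflow.keras import layers


class MultiplyAll(layers.Layer):
    """Sketch: element-wise product of however many tensors are in the list."""

    def call(self, inputs):
        # inputs is a list (or tuple) of tensors with compatible shapes
        return functools.reduce(operator.mul, inputs)


a = tf.constant([[1.0, 2.0]])
b = tf.constant([[3.0, 4.0]])
c = tf.constant([[5.0, 6.0]])

two = np.asarray(MultiplyAll()([a, b]))       # [[3., 8.]]
three = np.asarray(MultiplyAll()([a, b, c]))  # [[15., 48.]]
```

The layer is the same object in both calls; only the length of the list changes.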
Multiple input parameters in the call method: this works, but the number of parameters is fixed when the layer is defined:
def call(self, input1, input2):
    Z = input1 * input2
    return Z
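A minimal eager-mode sketch of the second alternative, assuming TF 2.x; the layer name Mul2 is hypothetical. The tensors are passed as separate arguments rather than as a list. This only exercises eager calls; how each TF version's functional API tracks tensors passed as extra positional arguments is not covered here.

```python
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers


class Mul2(layers.Layer):
    """Sketch: fixed number of inputs, one call() parameter per input."""

    def call(self, input1, input2):
        return input1 * input2


x = tf.constant([[1.0, 2.0, 3.0]])
y = tf.constant([[2.0, 2.0, 2.0]])

# Inputs are passed one by one instead of as a list.
out = np.asarray(Mul2()(x, y))  # [[2., 4., 6.]]
```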
Which method you choose depends on whether you need a fixed or a variable number of arguments. Of course, each method changes how the layer has to be called: either by passing a list of arguments, or by passing the arguments one by one in the function call.
You can also use *args in the call signature to accept a variable number of arguments, but overall, Keras' own layers that take multiple inputs (like Concatenate and Add) are implemented using lists.
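For the element-wise case in the question, the built-in merge layers already take a list of tensors, so a custom layer may not be needed at all. A small sketch comparing Multiply, Add and Concatenate on toy inputs (the values are made up for illustration):

```python
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers

a = tf.constant([[1.0, 2.0]])
b = tf.constant([[3.0, 4.0]])

# Each built-in merge layer is called with a list of tensors.
product = np.asarray(layers.Multiply()([a, b]))    # [[3., 8.]]
total = np.asarray(layers.Add()([a, b]))           # [[4., 6.]]
joined = np.asarray(layers.Concatenate()([a, b]))  # [[1., 2., 3., 4.]]
```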
Try it this way:
from tensorflow.keras import layers
from tensorflow.keras.layers import Dense, Input
from tensorflow.keras.models import Model


class mul(layers.Layer):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def call(self, inputs):
        # inputs is a list: [inp1, inp2]
        inp1, inp2 = inputs
        Z = inp1 * inp2
        return Z


inp1 = Input((10,))  # shape must be a tuple, hence the trailing comma
inp2 = Input((10,))
x = mul()([inp1, inp2])
x = Dense(1)(x)
model = Model([inp1, inp2], x)
model.summary()
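To check that the model really consumes two inputs, you can fit and predict on a list of two arrays. The random data and the adam/mse settings below are arbitrary placeholders, only meant to exercise the model:

```python
import numpy as np
from tensorflow.keras import layers
from tensorflow.keras.layers import Dense, Input
from tensorflow.keras.models import Model


class mul(layers.Layer):
    def call(self, inputs):
        inp1, inp2 = inputs  # list input: [inp1, inp2]
        return inp1 * inp2


inp1 = Input((10,))
inp2 = Input((10,))
x = mul()([inp1, inp2])
x = Dense(1)(x)
model = Model([inp1, inp2], x)

model.compile(optimizer="adam", loss="mse")
rng = np.random.default_rng(0)
a = rng.random((32, 10)).astype("float32")
b = rng.random((32, 10)).astype("float32")
y = rng.random((32, 1)).astype("float32")

# The two inputs are passed as a list, in the same order as Model([inp1, inp2], ...).
model.fit([a, b], y, epochs=1, verbose=0)
preds = model.predict([a, b], verbose=0)
print(preds.shape)  # (32, 1)
```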