I need to design a neural network layer that takes two inputs, X_1 and X_2, projects each of them to a fixed-size (10-dimensional) vector, and then sums the two results, as follows:
class my_lyr(tf.keras.layers.Layer):
    """Project two inputs to ``units``-dimensional vectors and sum them.

    Computes ``X_1 @ w1 + X_2 @ w2``. Call the layer with both inputs in a
    list, ``layer([X_1, X_2])``, so that Keras hands ``build`` one shape per
    input and each weight matrix can be sized from its own input.
    """

    def __init__(self, units=10, **kwargs):
        # A Keras layer must initialise the base Layer before weights can be
        # created; forward standard kwargs such as name/dtype.
        super().__init__(**kwargs)
        self.units = units

    def build(self, input_shape):
        """Create ``w1`` and ``w2`` from the shapes of ``[X_1, X_2]``.

        Raises:
            ValueError: if ``input_shape`` is not a pair of shapes, i.e. the
                layer was not called with ``[X_1, X_2]``.
        """
        if not (
            isinstance(input_shape, (list, tuple))
            and len(input_shape) == 2
            and all(isinstance(s, (list, tuple, tf.TensorShape)) for s in input_shape)
        ):
            raise ValueError(
                "my_lyr expects two inputs passed as a list: layer([X_1, X_2]); "
                f"got input_shape={input_shape!r}"
            )
        shape_1, shape_2 = input_shape
        # Each weight maps its own input's last (feature) axis to `units`.
        self.w1 = self.add_weight(name="w1", shape=(shape_1[-1], self.units))
        self.w2 = self.add_weight(name="w2", shape=(shape_2[-1], self.units))

    def call(self, X_1, X_2=None):
        """Return ``X_1 @ w1 + X_2 @ w2``.

        ``X_1`` may also be the pair ``[X_1, X_2]``; that is the form the
        layer receives when it is invoked as ``layer([X_1, X_2])``.
        """
        if X_2 is None:
            X_1, X_2 = X_1
        return X_1 @ self.w1 + X_2 @ self.w2
However, I need to know the input shapes of X_1 and X_2 before I can initialize w1 and w2.
I'm not sure how to declare w2 in `build`, since `build` receives only a single `input_shape` argument:
def build(self, input_shape):
    """Create one 10-unit projection weight per input.

    When a Keras layer is called with a list of inputs, ``build`` receives a
    list of shapes (one per input). Call the layer as ``layer([X_1, X_2])``
    and unpack the two shapes here, so each weight is sized from the last
    (feature) dimension of its own input.
    """
    shape_1, shape_2 = input_shape
    self.w1 = self.add_weight('w1', shape=[shape_1[-1], 10])
    self.w2 = self.add_weight('w2', shape=[shape_2[-1], 10])
How is `build` usually written for a layer that takes more than one input?
If your layer takes two inputs, pass them as a list when calling it. Keras then gives `build` a list of input shapes (one per input), so each weight can be sized from its own input, as follows:
import tensorflow as tf
from tensorflow import keras
class Linear(keras.layers.Layer):
    """Sum of two independent linear projections to ``units`` dimensions.

    Takes its input as a pair ``[a, b]`` and returns ``a @ wa + b @ wb``.
    Each weight is sized from its own input's last dimension, so ``a`` and
    ``b`` may have different feature sizes.
    """

    def __init__(self, units=32, **kwargs):
        # Forward standard Layer kwargs (name, dtype, ...) to the base class.
        super().__init__(**kwargs)
        self.units = units

    def build(self, input_shape):
        # For a list input, Keras passes one shape per list element.
        shape_a, shape_b = input_shape
        self.wa = self.add_weight(
            name="wa",
            shape=(shape_a[-1], self.units),
            initializer="random_normal",
            trainable=True,
        )
        self.wb = self.add_weight(
            name="wb",
            shape=(shape_b[-1], self.units),
            initializer="random_normal",
            trainable=True,
        )

    def call(self, inputs):
        """Return ``inputs[0] @ wa + inputs[1] @ wb`` for a two-element input."""
        a, b = inputs
        return tf.matmul(a, self.wa) + tf.matmul(b, self.wb)

    def get_config(self):
        """Include ``units`` so the layer can be re-created from its config."""
        config = super().get_config()
        config.update({"units": self.units})
        return config
Pass the two inputs to the layer as a list:
# A batch of two 2-feature samples; the same tensor is fed as both inputs.
x = tf.random.normal((2, 2))
linear_layer = Linear(units=32)
linear_layer([x, x])
<tf.Tensor: shape=(2, 32), dtype=float32, numpy=
array([[-0.08829461, -0.01605312, -0.04368614, -0.08116315, -0.01521384,
0.01132785, 0.10704445, -0.10873697, -0.0525714 , 0.07684848,
0.04586978, 0.01315852, 0.01369547, 0.07404792, 0.10313608,
-0.10851607, 0.04091477, -0.01723676, -0.0326797 , 0.03598418,
-0.11335816, -0.10044714, 0.13555384, 0.01689356, 0.02631954,
0.08226107, -0.08765724, -0.05981663, 0.00531629, 0.02930426,
0.04155847, 0.05339598],
[ 0.20617458, -0.05936547, 0.01735754, -0.06575315, 0.10090968,
-0.07796012, -0.1956767 , -0.03406558, 0.18604615, -0.03547171,
0.02784208, 0.0471364 , -0.10712875, -0.07869454, -0.19457275,
0.13593757, -0.14659101, 0.0384632 , 0.02344182, -0.03861775,
0.08948556, 0.09225713, -0.17395493, 0.10021958, -0.09210777,
-0.09865301, 0.2536609 , -0.02547608, 0.02885125, -0.01271547,
-0.10340843, -0.0338558 ]], dtype=float32)>
If you find this helpful, you can donate to us via PayPal or buy us a coffee so we can keep maintaining and growing the site. Thank you!
Donate Us With