I am trying to apply feature-wise scaling and shifting (also called an affine transformation - the idea is described in the Nomenclature section of this distill article) to a Keras tensor (with TF backend).
The tensor I would like to transform, call it `X`, is the output of a convolutional layer and has shape `(B, H, W, F)`, representing (batch size, height, width, number of feature maps).
The parameters of my transformation are two `(B, F)`-dimensional tensors, `beta` and `gamma`.
I want `X * gamma + beta`, or, more specifically,
# Intended semantics written as an explicit loop (illustration only).
# On Keras/TF tensors this raises
# "TypeError: 'Tensor' object does not support item assignment",
# and a Python-level loop over batch and channels would be slow anyway.
for b in range(B):
    for f in range(F):
        X[b,:,:,f] = X[b,:,:,f] * gamma[b,f] + beta[b,f]
However, neither of these two ways of doing it works in Keras. The second, with element-wise assignment, fails with `TypeError: 'Tensor' object does not support item assignment`, and would be fairly inefficient anyway.
How the first fails is more cryptic to me, but my guess is that it is a broadcasting issue. You can see my attempt in the full code and traceback below.
Two things to note: the error only happens at training time (not when compiling), and the `transform_vars` input is seemingly never used, at least according to the model summary.
Any ideas on how to implement this?
import numpy as np
import keras as ks
import keras.backend as K
print(ks.__version__)
# Load example data (here MNIST)
from keras.datasets import mnist
(x_img_train, y_train), _ = mnist.load_data()
# Add a trailing channel axis so the images can be fed to Conv2D
x_img_train = np.expand_dims(x_img_train,-1)
# Generate some random data to use for the transformations
n_transform_vars = 10
x_transform_train = np.random.randn(y_train.shape[0], n_transform_vars)
# Inputs
input_transform = ks.layers.Input(x_transform_train.shape[1:], name='transform_vars')
input_img = ks.layers.Input(x_img_train.shape[1:], name='imgs')
# Number of feature maps
n_features = 32
# Create network that calculates the transformations:
# (batch, 2 * n_features) reshaped to (batch, 2, n_features).
# NOTE(review): 32 is hard-coded here and must equal n_features.
tns_transform = ks.layers.Dense(2 * n_features)(input_transform)
tns_transform = ks.layers.Reshape((2, 32))(tns_transform)
# Do a convolution
tns_conv = ks.layers.Conv2D(filters=n_features, kernel_size=3, padding='same')(input_img)
# Apply batch norm
bn = ks.layers.BatchNormalization()
# Freeze the weights of the batch norm, as they are going to be overwritten
bn.trainable = False
# Apply
tns_conv = bn(tns_conv)
# Attempt to apply the affine transformation.
# NOTE(review): this is the failing version described in the question:
#  * tns_transform[:, 0] and tns_transform[:, 1] have shape
#    (batch, n_features). Broadcasting aligns trailing axes, so against x of
#    shape (batch, 28, 28, n_features) the batch axis lines up with the
#    width axis. This is the "[8,28,28,32] vs. [8,32]" error in the
#    traceback. Gamma/beta would need shape (batch, 1, 1, n_features)
#    to broadcast per sample and per channel.
#  * tns_transform is captured from the enclosing scope instead of being
#    passed to the Lambda as an input, so this layer's only input is
#    tns_conv. Accordingly, the model summary shows no transform_vars branch.
def scale_and_shift(x):
    return x * tns_transform[:,0] + tns_transform[:,1]
tns_conv = ks.layers.Lambda(scale_and_shift, name='affine_transform')(tns_conv)
tns_conv = ks.layers.Flatten()(tns_conv)
output = ks.layers.Dense(1)(tns_conv)
model = ks.models.Model(inputs=[input_img, input_transform], outputs=output)
model.compile(loss='mse', optimizer='Adam')
model.summary()
model.fit([x_img_train, x_transform_train], y_train, batch_size=8)
This results in
2.2.4
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
imgs (InputLayer) (None, 28, 28, 1) 0
_________________________________________________________________
conv2d_25 (Conv2D) (None, 28, 28, 32) 320
_________________________________________________________________
batch_normalization_22 (Batc (None, 28, 28, 32) 128
_________________________________________________________________
affine_transform (Lambda) (None, 28, 28, 32) 0
_________________________________________________________________
flatten_6 (Flatten) (None, 25088) 0
_________________________________________________________________
dense_33 (Dense) (None, 1) 25089
=================================================================
Total params: 25,537
Trainable params: 25,409
Non-trainable params: 128
_________________________________________________________________
Epoch 1/1
---------------------------------------------------------------------------
InvalidArgumentError Traceback (most recent call last)
<ipython-input-35-14724d9432ef> in <module>
49 model.summary()
50
---> 51 model.fit([x_img_train, x_transform_train], y_train, batch_size=8)
~/miniconda3/envs/py3/lib/python3.6/site-packages/keras/engine/training.py in fit(self, x, y, batch_size, epochs, verbose, callbacks, validation_split, validation_data, shuffle, class_weight, sample_weight, initial_epoch, steps_per_epoch, validation_steps, **kwargs)
1037 initial_epoch=initial_epoch,
1038 steps_per_epoch=steps_per_epoch,
-> 1039 validation_steps=validation_steps)
1040
1041 def evaluate(self, x=None, y=None,
~/miniconda3/envs/py3/lib/python3.6/site-packages/keras/engine/training_arrays.py in fit_loop(model, f, ins, out_labels, batch_size, epochs, verbose, callbacks, val_f, val_ins, shuffle, callback_metrics, initial_epoch, steps_per_epoch, validation_steps)
197 ins_batch[i] = ins_batch[i].toarray()
198
--> 199 outs = f(ins_batch)
200 outs = to_list(outs)
201 for l, o in zip(out_labels, outs):
~/miniconda3/envs/py3/lib/python3.6/site-packages/keras/backend/tensorflow_backend.py in __call__(self, inputs)
2713 return self._legacy_call(inputs)
2714
-> 2715 return self._call(inputs)
2716 else:
2717 if py_any(is_tensor(x) for x in inputs):
~/miniconda3/envs/py3/lib/python3.6/site-packages/keras/backend/tensorflow_backend.py in _call(self, inputs)
2673 fetched = self._callable_fn(*array_vals, run_metadata=self.run_metadata)
2674 else:
-> 2675 fetched = self._callable_fn(*array_vals)
2676 return fetched[:len(self.outputs)]
2677
~/miniconda3/envs/py3/lib/python3.6/site-packages/tensorflow/python/client/session.py in __call__(self, *args, **kwargs)
1437 ret = tf_session.TF_SessionRunCallable(
1438 self._session._session, self._handle, args, status,
-> 1439 run_metadata_ptr)
1440 if run_metadata:
1441 proto_data = tf_session.TF_GetBuffer(run_metadata_ptr)
~/miniconda3/envs/py3/lib/python3.6/site-packages/tensorflow/python/framework/errors_impl.py in __exit__(self, type_arg, value_arg, traceback_arg)
526 None, None,
527 compat.as_text(c_api.TF_Message(self.status.status)),
--> 528 c_api.TF_GetCode(self.status.status))
529 # Delete the underlying status object from memory otherwise it stays alive
530 # as there is a reference to status from this from the traceback due to
InvalidArgumentError: Incompatible shapes: [8,28,28,32] vs. [8,32]
[[{{node training_5/Adam/gradients/affine_transform_18/mul_grad/BroadcastGradientArgs}} = BroadcastGradientArgs[T=DT_INT32, _class=["loc:@training_5/Adam/gradients/batch_normalization_22/cond/Merge_grad/cond_grad"], _device="/job:localhost/replica:0/task:0/device:CPU:0"](training_5/Adam/gradients/affine_transform_18/mul_grad/Shape, training_5/Adam/gradients/affine_transform_18/mul_grad/Shape_1)]]
I managed to implement the affine transformation as a custom layer (here called a FiLM layer, as in the literature):
class FiLM(ks.layers.Layer):
    '''
    Feature-wise Linear Modulation (FiLM) layer with a built-in generator.

    Called on ``[feature_maps, film_vars]``. ``film_vars`` is fed through a
    small MLP (one Dense + activation per entry of ``widths``) followed by a
    linear ``Dense(2 * F)``, where ``F`` is the last axis of
    ``feature_maps``. The first ``F`` outputs are the gammas and the last
    ``F`` are the betas. The layer returns

        (1 + gammas) * feature_maps + betas

    per sample and per channel, broadcast over height and width.

    Args :
        widths :
            *sequence of int, optional (default=(64, 64))*
            Widths of the hidden Dense layers. May be empty, in which case
            the gammas and betas are a linear function of ``film_vars``.
        activation :
            *str, optional (default='leakyrelu')*
            Hidden-layer activation, resolved with ``get_activation``.
        initialization :
            *str, optional (default='glorot_uniform')*
            Kernel initializer used by every Dense sub-layer.
        **kwargs :
            Standard Keras ``Layer`` arguments (e.g. ``name``).

    Assumes channels-last feature maps of rank 4, i.e.
    (batch, height, width, F), and a rank-2 ``film_vars`` tensor
    (batch, n_vars).
    '''

    def __init__(self, widths=(64, 64), activation='leakyrelu',
                 initialization='glorot_uniform', **kwargs):
        # Copy into a fresh list. This avoids the shared-mutable-default
        # pitfall and decouples the layer from later changes to the caller's list.
        self.widths = list(widths)
        self.activation = activation
        self.initialization = initialization
        super(FiLM, self).__init__(**kwargs)

    def build(self, input_shape):
        if not isinstance(input_shape, list) or len(input_shape) != 2:
            raise ValueError('FiLM expects a list of two inputs: '
                             '[feature_maps, film_vars].')
        feature_map_shape, FiLM_vars_shape = input_shape
        self.n_feature_maps = feature_map_shape[-1]  # assumes channels_last
        # The spatial size is recorded for reference only. call() broadcasts
        # the gammas/betas, so height and width may be None.
        self.height = feature_map_shape[1]
        self.width = feature_map_shape[2]

        trainable_weights = []
        self.hidden_dense_layers = []
        # Activation layers are created and built once, here, rather than on
        # every call(). A parametric activation such as PReLU then owns a
        # single set of weights, which is registered as trainable below.
        self.hidden_activation_layers = []
        in_dim = FiLM_vars_shape[-1]
        for i, width in enumerate(self.widths):
            dense = ks.layers.Dense(width,
                                    kernel_initializer=self.initialization,
                                    name=f'FiLM_dense_{i}')
            dense.build((None, in_dim))
            act = get_activation(activation=self.activation)
            act.build((None, width))
            trainable_weights += dense.trainable_weights
            trainable_weights += act.trainable_weights
            self.hidden_dense_layers.append(dense)
            self.hidden_activation_layers.append(act)
            in_dim = width

        # Linear output layer producing one gamma and one beta per channel.
        self.output_dense = ks.layers.Dense(2 * self.n_feature_maps,
                                            kernel_initializer=self.initialization,
                                            name='FiLM_dense_output')
        self.output_dense.build((None, in_dim))
        trainable_weights += self.output_dense.trainable_weights

        # Hand the sub-layers' weights to this layer explicitly so they are
        # trained as part of it.
        self._trainable_weights = trainable_weights
        super(FiLM, self).build(input_shape)

    def call(self, x):
        if not isinstance(x, list) or len(x) != 2:
            raise ValueError('FiLM expects a list of two inputs: '
                             '[feature_maps, film_vars].')
        conv_output, FiLM_vars = x

        # FiLM generator: hidden Dense + activation pairs, then linear output.
        tns = FiLM_vars
        for dense, act in zip(self.hidden_dense_layers,
                              self.hidden_activation_layers):
            tns = act(dense(tns))
        FiLM_output = self.output_dense(tns)

        # (batch, 2F) -> (batch, 1, 1, 2F). The singleton spatial axes
        # broadcast over height and width, so no explicit tiling is needed.
        FiLM_output = K.expand_dims(FiLM_output, axis=1)
        FiLM_output = K.expand_dims(FiLM_output, axis=1)

        # Split into gammas and betas, F channels each.
        gammas = FiLM_output[:, :, :, :self.n_feature_maps]
        betas = FiLM_output[:, :, :, self.n_feature_maps:]

        # With the "1 +", gammas near zero leave the feature maps unscaled.
        return (1 + gammas) * conv_output + betas

    def compute_output_shape(self, input_shape):
        if not isinstance(input_shape, list) or len(input_shape) != 2:
            raise ValueError('FiLM expects a list of two inputs: '
                             '[feature_maps, film_vars].')
        # The modulation is element-wise, so the feature-map shape is kept.
        return input_shape[0]

    def get_config(self):
        # Lets the layer be re-created from a saved model (via custom_objects).
        config = super(FiLM, self).get_config()
        config.update({
            'widths': list(self.widths),
            'activation': self.activation,
            'initialization': self.initialization,
        })
        return config
It depends on the function `get_activation`, which essentially just returns a Keras activation layer instance. You can see the full working example below.
Note that this layer processes the `transform_vars` itself. If you want to process these variables in a separate network, see the edit below.
import numpy as np
import keras as ks
import keras.backend as K
def get_activation(tns=None, activation='relu'):
    '''
    Create an activation layer and optionally connect it to a tensor.

    Args :
        tns :
            *Keras tensor or None*
            Input tensor. If not None, the new activation layer is applied
            to it and the resulting tensor is returned. If None, the
            activation layer itself is returned.
        activation :
            *str, optional (default='relu')*
            The name of an activation function. One of 'relu',
            'leakyrelu', 'prelu', 'elu', 'swish', 'mrelu', 'gaussian' or
            'flipped_gaussian', or anything that Keras will recognize as an
            activation function name (passed on to ``ks.layers.Activation``).

    Returns :
        *Keras tensor or layer instance* (see tns argument)

    A new layer instance is created on every call, so a parametric
    activation ('prelu') gets its own fresh weights each time.
    '''
    def swish(x):
        return K.sigmoid(x) * x

    def mrelu(x):
        # Triangular bump: 1 - |x| on [-1, 1], 0 elsewhere.
        return K.minimum(K.maximum(1-x, 0), K.maximum(1+x, 0))

    def gaussian(x):
        return K.exp(-x**2)

    def flipped_gaussian(x):
        return 1 - K.exp(-x**2)

    # Lazy factories: only the requested Keras layer class is ever touched.
    factories = {
        'relu': lambda: ks.layers.ReLU(),
        'leakyrelu': lambda: ks.layers.LeakyReLU(),
        'prelu': lambda: ks.layers.PReLU(),
        'elu': lambda: ks.layers.ELU(),
        'swish': lambda: ks.layers.Activation(swish),
        'mrelu': lambda: ks.layers.Activation(mrelu),
        'gaussian': lambda: ks.layers.Activation(gaussian),
        'flipped_gaussian': lambda: ks.layers.Activation(flipped_gaussian),
    }
    # Non-string specs (which may be unhashable) go straight to Keras.
    factory = factories.get(activation) if isinstance(activation, str) else None
    act = factory() if factory is not None else ks.layers.Activation(activation)

    if tns is not None:
        return act(tns)
    return act
class FiLM(ks.layers.Layer):
    '''
    Feature-wise Linear Modulation (FiLM) layer with a built-in generator.

    Called on ``[feature_maps, film_vars]``. ``film_vars`` is fed through a
    small MLP (one Dense + activation per entry of ``widths``) followed by a
    linear ``Dense(2 * F)``, where ``F`` is the last axis of
    ``feature_maps``. The first ``F`` outputs are the gammas and the last
    ``F`` are the betas. The layer returns

        (1 + gammas) * feature_maps + betas

    per sample and per channel, broadcast over height and width.

    Args :
        widths :
            *sequence of int, optional (default=(64, 64))*
            Widths of the hidden Dense layers. May be empty, in which case
            the gammas and betas are a linear function of ``film_vars``.
        activation :
            *str, optional (default='leakyrelu')*
            Hidden-layer activation, resolved with ``get_activation``.
        initialization :
            *str, optional (default='glorot_uniform')*
            Kernel initializer used by every Dense sub-layer.
        **kwargs :
            Standard Keras ``Layer`` arguments (e.g. ``name``).

    Assumes channels-last feature maps of rank 4, i.e.
    (batch, height, width, F), and a rank-2 ``film_vars`` tensor
    (batch, n_vars).
    '''

    def __init__(self, widths=(64, 64), activation='leakyrelu',
                 initialization='glorot_uniform', **kwargs):
        # Copy into a fresh list. This avoids the shared-mutable-default
        # pitfall and decouples the layer from later changes to the caller's list.
        self.widths = list(widths)
        self.activation = activation
        self.initialization = initialization
        super(FiLM, self).__init__(**kwargs)

    def build(self, input_shape):
        if not isinstance(input_shape, list) or len(input_shape) != 2:
            raise ValueError('FiLM expects a list of two inputs: '
                             '[feature_maps, film_vars].')
        feature_map_shape, FiLM_vars_shape = input_shape
        self.n_feature_maps = feature_map_shape[-1]  # assumes channels_last
        # The spatial size is recorded for reference only. call() broadcasts
        # the gammas/betas, so height and width may be None.
        self.height = feature_map_shape[1]
        self.width = feature_map_shape[2]

        trainable_weights = []
        self.hidden_dense_layers = []
        # Activation layers are created and built once, here, rather than on
        # every call(). A parametric activation such as PReLU then owns a
        # single set of weights, which is registered as trainable below.
        self.hidden_activation_layers = []
        in_dim = FiLM_vars_shape[-1]
        for i, width in enumerate(self.widths):
            dense = ks.layers.Dense(width,
                                    kernel_initializer=self.initialization,
                                    name=f'FiLM_dense_{i}')
            dense.build((None, in_dim))
            act = get_activation(activation=self.activation)
            act.build((None, width))
            trainable_weights += dense.trainable_weights
            trainable_weights += act.trainable_weights
            self.hidden_dense_layers.append(dense)
            self.hidden_activation_layers.append(act)
            in_dim = width

        # Linear output layer producing one gamma and one beta per channel.
        self.output_dense = ks.layers.Dense(2 * self.n_feature_maps,
                                            kernel_initializer=self.initialization,
                                            name='FiLM_dense_output')
        self.output_dense.build((None, in_dim))
        trainable_weights += self.output_dense.trainable_weights

        # Hand the sub-layers' weights to this layer explicitly so they are
        # trained as part of it.
        self._trainable_weights = trainable_weights
        super(FiLM, self).build(input_shape)

    def call(self, x):
        if not isinstance(x, list) or len(x) != 2:
            raise ValueError('FiLM expects a list of two inputs: '
                             '[feature_maps, film_vars].')
        conv_output, FiLM_vars = x

        # FiLM generator: hidden Dense + activation pairs, then linear output.
        tns = FiLM_vars
        for dense, act in zip(self.hidden_dense_layers,
                              self.hidden_activation_layers):
            tns = act(dense(tns))
        FiLM_output = self.output_dense(tns)

        # (batch, 2F) -> (batch, 1, 1, 2F). The singleton spatial axes
        # broadcast over height and width, so no explicit tiling is needed.
        FiLM_output = K.expand_dims(FiLM_output, axis=1)
        FiLM_output = K.expand_dims(FiLM_output, axis=1)

        # Split into gammas and betas, F channels each.
        gammas = FiLM_output[:, :, :, :self.n_feature_maps]
        betas = FiLM_output[:, :, :, self.n_feature_maps:]

        # With the "1 +", gammas near zero leave the feature maps unscaled.
        return (1 + gammas) * conv_output + betas

    def compute_output_shape(self, input_shape):
        if not isinstance(input_shape, list) or len(input_shape) != 2:
            raise ValueError('FiLM expects a list of two inputs: '
                             '[feature_maps, film_vars].')
        # The modulation is element-wise, so the feature-map shape is kept.
        return input_shape[0]

    def get_config(self):
        # Lets the layer be re-created from a saved model (via custom_objects).
        config = super(FiLM, self).get_config()
        config.update({
            'widths': list(self.widths),
            'activation': self.activation,
            'initialization': self.initialization,
        })
        return config
print(ks.__version__)

# Load example data (here MNIST)
from keras.datasets import mnist
(x_img_train, y_train), _ = mnist.load_data()
# Add a trailing channel axis so the images can be fed to Conv2D
x_img_train = np.expand_dims(x_img_train, -1)

# Generate some random data to use for the transformations
n_transform_vars = 10
x_transform_train = np.random.randn(y_train.shape[0], n_transform_vars)

# Inputs
input_transform = ks.layers.Input(x_transform_train.shape[1:], name='transform_vars')
input_img = ks.layers.Input(x_img_train.shape[1:], name='imgs')

# Number of feature maps
n_features = 32

# Do a convolution
tns = ks.layers.Conv2D(filters=n_features, kernel_size=3, padding='same')(input_img)

# Batch norm with frozen weights: the FiLM layer below supplies the
# per-channel scale and shift instead.
bn = ks.layers.BatchNormalization()
bn.trainable = False
tns = bn(tns)

# Apply FiLM layer, conditioned on the transform variables
tns = FiLM(widths=[12, 24], name='FiLM_layer')([tns, input_transform])

# Make 1D output
tns = ks.layers.Flatten()(tns)
output = ks.layers.Dense(1)(tns)

# Compile and plot
model = ks.models.Model(inputs=[input_img, input_transform], outputs=output)
model.compile(loss='mse', optimizer='Adam')
model.summary()
# plot_model needs pydot and graphviz. If they are missing, skip the plot
# instead of aborting the script before training.
try:
    ks.utils.plot_model(model, './model_with_FiLM.png')
except (ImportError, OSError) as err:
    print('Skipping model plot: {}'.format(err))

# Train
model.fit([x_img_train, x_transform_train], y_train, batch_size=8)
EDIT: Here is the "non-active" FiLM layer, which takes in the predictions of another network (the FiLM generator) and uses them as gammas and betas.
This approach is equivalent but simpler: all the trainable weights live in the FiLM generator, which also makes weight sharing straightforward.
class FiLM(ks.layers.Layer):
    '''
    "Non-active" FiLM layer that applies externally generated gammas/betas.

    Called on ``[feature_maps, film_tns]``, where ``film_tns`` has shape
    (batch, 2 * F) and ``F`` is the last (channel) axis of
    ``feature_maps``. The first ``F`` entries of ``film_tns`` are the
    gammas and the last ``F`` are the betas. The layer returns

        (1 + gammas) * feature_maps + betas

    per sample and per channel, broadcast over height and width. The layer
    creates no weights of its own. All trainable parameters live in the
    network that produces ``film_tns``.

    Assumes channels-last feature maps of rank 4, i.e.
    (batch, height, width, F).
    '''

    def __init__(self, **kwargs):
        super(FiLM, self).__init__(**kwargs)

    def build(self, input_shape):
        if not isinstance(input_shape, list) or len(input_shape) != 2:
            raise ValueError('FiLM expects a list of two inputs: '
                             '[feature_maps, film_tns].')
        feature_map_shape, FiLM_tns_shape = input_shape
        # The spatial size is recorded for reference only. call() broadcasts,
        # so height and width may be None.
        self.height = feature_map_shape[1]
        self.width = feature_map_shape[2]
        self.n_feature_maps = feature_map_shape[-1]
        expected = 2 * int(self.n_feature_maps)
        if FiLM_tns_shape[-1] != expected:
            raise ValueError(
                'FiLM tensor must have 2 * {} = {} features, got {}.'.format(
                    self.n_feature_maps, expected, FiLM_tns_shape[-1]))
        super(FiLM, self).build(input_shape)

    def call(self, x):
        if not isinstance(x, list) or len(x) != 2:
            raise ValueError('FiLM expects a list of two inputs: '
                             '[feature_maps, film_tns].')
        conv_output, FiLM_tns = x

        # (batch, 2F) -> (batch, 1, 1, 2F). The singleton spatial axes
        # broadcast over height and width, so no explicit tiling is needed.
        FiLM_tns = K.expand_dims(FiLM_tns, axis=1)
        FiLM_tns = K.expand_dims(FiLM_tns, axis=1)

        # Split into gammas and betas, F channels each.
        gammas = FiLM_tns[:, :, :, :self.n_feature_maps]
        betas = FiLM_tns[:, :, :, self.n_feature_maps:]

        # With the "1 +", gammas near zero leave the feature maps unscaled.
        return (1 + gammas) * conv_output + betas

    def compute_output_shape(self, input_shape):
        if not isinstance(input_shape, list) or len(input_shape) != 2:
            raise ValueError('FiLM expects a list of two inputs: '
                             '[feature_maps, film_tns].')
        # The modulation is element-wise, so the feature-map shape is kept.
        return input_shape[0]
If you love us? You can donate to us via Paypal or buy me a coffee so we can maintain and grow! Thank you!
Donate Us With