 

Feature-wise scaling and shifting (FiLM layer) in Keras

I am trying to apply feature-wise scaling and shifting (also called an affine transformation; the idea is described in the Nomenclature section of this Distill article) to a Keras tensor (with the TensorFlow backend).

The tensor I would like to transform, call it X, is the output of a convolutional layer, and has shape (B,H,W,F), representing (batch size, height, width, number of feature maps).

The parameters of my transformation are two tensors of shape (B,F), beta and gamma.

I want X * gamma + beta, or to be more specific,

for b in range(B):
    for f in range(F):
        X[b,:,:,f] = X[b,:,:,f] * gamma[b,f] + beta[b,f]
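For reference, the loop above can be written without any item assignment by giving gamma and beta two singleton axes for height and width, so that they broadcast over every spatial position. A minimal NumPy sketch, with arbitrary example sizes (in Keras the extra axes could be added with K.expand_dims or K.reshape):

```python
import numpy as np

B, H, W, F = 2, 4, 5, 3  # arbitrary example sizes
rng = np.random.default_rng(0)
X = rng.standard_normal((B, H, W, F))
gamma = rng.standard_normal((B, F))
beta = rng.standard_normal((B, F))

# Reference: the explicit loop (on a NumPy copy, where item assignment is allowed)
loop_out = X.copy()
for b in range(B):
    for f in range(F):
        loop_out[b, :, :, f] = X[b, :, :, f] * gamma[b, f] + beta[b, f]

# Vectorized: (B, F) -> (B, 1, 1, F), which broadcasts over H and W
vec_out = X * gamma[:, None, None, :] + beta[:, None, None, :]

print(np.allclose(loop_out, vec_out))  # True
```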

However, neither of these two approaches works in Keras. The second, with element-wise assignment, fails with

TypeError: 'Tensor' object does not support item assignment

and would be fairly inefficient anyway.

How the first fails is more cryptic to me, but my guess is that it is a broadcasting issue. You can see my attempt in the full code and traceback below.

Two things to note are that the error only happens at training time (and not when compiling), and that the 'transform_vars' input is seemingly never used, at least according to the model summary.
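The guess about broadcasting looks right, and it also fits the first observation. Broadcasting aligns shapes from the right, so a (B, F) tensor is matched against the trailing (W, F) axes of a (B, H, W, F) tensor, not against (B, F). When the graph is built, the batch dimension is None, so the static check cannot see that it will later be compared with W = 28; the conflict only surfaces once real batches of 8 are fed. A NumPy sketch using the shapes from the traceback below:

```python
import numpy as np

x = np.zeros((8, 28, 28, 32))  # conv output, (B, H, W, F)
t = np.zeros((8, 32))          # gamma or beta, (B, F)

try:
    x * t                      # right-aligned: 32 vs 32 is fine, 8 vs 28 (the W axis) conflicts
    failed = False
except ValueError:
    failed = True
print(failed)  # True

# Inserting explicit H and W axes makes the shapes compatible
ok = x * t[:, None, None, :]
print(ok.shape)  # (8, 28, 28, 32)
```

The closure is also a likely reason transform_vars is missing from the summary: scale_and_shift reads tns_transform from the enclosing scope instead of receiving it as an input to the Lambda layer, so Keras records no connection between that branch and the output.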

Any ideas on how to implement this?

import numpy as np
import keras as ks
import keras.backend as K

print(ks.__version__)

# Load example data (here MNIST)
from keras.datasets import mnist
(x_img_train, y_train), _ = mnist.load_data()
x_img_train = np.expand_dims(x_img_train,-1)

# Generate some data to use for the transformations
n_transform_vars = 10
x_transform_train = np.random.randn(y_train.shape[0], n_transform_vars)

# Inputs
input_transform = ks.layers.Input(x_transform_train.shape[1:], name='transform_vars')
input_img = ks.layers.Input(x_img_train.shape[1:], name='imgs')

# Number of feature maps
n_features = 32

# Create network that calculates the transformations
tns_transform = ks.layers.Dense(2 * n_features)(input_transform)
tns_transform = ks.layers.Reshape((2, 32))(tns_transform)

# Do a convolution
tns_conv = ks.layers.Conv2D(filters=n_features, kernel_size=3, padding='same')(input_img)

# Apply batch norm
bn = ks.layers.BatchNormalization()

# Freeze the weights of the batch norm, as they are going to be overwritten
bn.trainable = False

# Apply
tns_conv = bn(tns_conv)

# Attempt to apply the affine transformation
def scale_and_shift(x):
    return x * tns_transform[:,0] + tns_transform[:,1]

tns_conv = ks.layers.Lambda(scale_and_shift, name='affine_transform')(tns_conv)
tns_conv = ks.layers.Flatten()(tns_conv)

output = ks.layers.Dense(1)(tns_conv)

model = ks.models.Model(inputs=[input_img, input_transform], outputs=output)
model.compile(loss='mse', optimizer='Adam')
model.summary()

model.fit([x_img_train, x_transform_train], y_train, batch_size=8)

This results in

2.2.4
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
imgs (InputLayer)            (None, 28, 28, 1)         0         
_________________________________________________________________
conv2d_25 (Conv2D)           (None, 28, 28, 32)        320       
_________________________________________________________________
batch_normalization_22 (Batc (None, 28, 28, 32)        128       
_________________________________________________________________
affine_transform (Lambda)    (None, 28, 28, 32)        0         
_________________________________________________________________
flatten_6 (Flatten)          (None, 25088)             0         
_________________________________________________________________
dense_33 (Dense)             (None, 1)                 25089     
=================================================================
Total params: 25,537
Trainable params: 25,409
Non-trainable params: 128
_________________________________________________________________
Epoch 1/1
---------------------------------------------------------------------------
InvalidArgumentError                      Traceback (most recent call last)
<ipython-input-35-14724d9432ef> in <module>
     49 model.summary()
     50 
---> 51 model.fit([x_img_train, x_transform_train], y_train, batch_size=8)

~/miniconda3/envs/py3/lib/python3.6/site-packages/keras/engine/training.py in fit(self, x, y, batch_size, epochs, verbose, callbacks, validation_split, validation_data, shuffle, class_weight, sample_weight, initial_epoch, steps_per_epoch, validation_steps, **kwargs)
   1037                                         initial_epoch=initial_epoch,
   1038                                         steps_per_epoch=steps_per_epoch,
-> 1039                                         validation_steps=validation_steps)
   1040 
   1041     def evaluate(self, x=None, y=None,

~/miniconda3/envs/py3/lib/python3.6/site-packages/keras/engine/training_arrays.py in fit_loop(model, f, ins, out_labels, batch_size, epochs, verbose, callbacks, val_f, val_ins, shuffle, callback_metrics, initial_epoch, steps_per_epoch, validation_steps)
    197                     ins_batch[i] = ins_batch[i].toarray()
    198 
--> 199                 outs = f(ins_batch)
    200                 outs = to_list(outs)
    201                 for l, o in zip(out_labels, outs):

~/miniconda3/envs/py3/lib/python3.6/site-packages/keras/backend/tensorflow_backend.py in __call__(self, inputs)
   2713                 return self._legacy_call(inputs)
   2714 
-> 2715             return self._call(inputs)
   2716         else:
   2717             if py_any(is_tensor(x) for x in inputs):

~/miniconda3/envs/py3/lib/python3.6/site-packages/keras/backend/tensorflow_backend.py in _call(self, inputs)
   2673             fetched = self._callable_fn(*array_vals, run_metadata=self.run_metadata)
   2674         else:
-> 2675             fetched = self._callable_fn(*array_vals)
   2676         return fetched[:len(self.outputs)]
   2677 

~/miniconda3/envs/py3/lib/python3.6/site-packages/tensorflow/python/client/session.py in __call__(self, *args, **kwargs)
   1437           ret = tf_session.TF_SessionRunCallable(
   1438               self._session._session, self._handle, args, status,
-> 1439               run_metadata_ptr)
   1440         if run_metadata:
   1441           proto_data = tf_session.TF_GetBuffer(run_metadata_ptr)

~/miniconda3/envs/py3/lib/python3.6/site-packages/tensorflow/python/framework/errors_impl.py in __exit__(self, type_arg, value_arg, traceback_arg)
    526             None, None,
    527             compat.as_text(c_api.TF_Message(self.status.status)),
--> 528             c_api.TF_GetCode(self.status.status))
    529     # Delete the underlying status object from memory otherwise it stays alive
    530     # as there is a reference to status from this from the traceback due to

InvalidArgumentError: Incompatible shapes: [8,28,28,32] vs. [8,32]
     [[{{node training_5/Adam/gradients/affine_transform_18/mul_grad/BroadcastGradientArgs}} = BroadcastGradientArgs[T=DT_INT32, _class=["loc:@training_5/Adam/gradients/batch_normalization_22/cond/Merge_grad/cond_grad"], _device="/job:localhost/replica:0/task:0/device:CPU:0"](training_5/Adam/gradients/affine_transform_18/mul_grad/Shape, training_5/Adam/gradients/affine_transform_18/mul_grad/Shape_1)]]
asked Mar 17 '19 18:03 by Bobson Dugnutt


1 Answer

I managed to implement the affine transformation as a custom layer (here called a FiLM layer, as in the literature):

class FiLM(ks.layers.Layer):

    def __init__(self, widths=[64,64], activation='leakyrelu',
                 initialization='glorot_uniform', **kwargs):
        self.widths = widths
        self.activation = activation
        self.initialization = initialization
        super(FiLM, self).__init__(**kwargs)

    def build(self, input_shape):
        assert isinstance(input_shape, list)
        feature_map_shape, FiLM_vars_shape = input_shape
        self.n_feature_maps = feature_map_shape[-1]
        self.height = feature_map_shape[1]
        self.width = feature_map_shape[2]

        # Collect trainable weights
        trainable_weights = []

        # Create weights for hidden layers
        self.hidden_dense_layers = []
        for i,width in enumerate(self.widths):
            dense = ks.layers.Dense(width,
                                    kernel_initializer=self.initialization,
                                    name=f'FiLM_dense_{i}')
            if i==0:
                build_shape = FiLM_vars_shape[:2]
            else:
                build_shape = (None,self.widths[i-1])
            dense.build(build_shape)
            trainable_weights += dense.trainable_weights
            self.hidden_dense_layers.append(dense)

        # Create weights for output layer
        self.output_dense = ks.layers.Dense(2 * self.n_feature_maps, # assumes channel_last
                                            kernel_initializer=self.initialization,
                                            name=f'FiLM_dense_output')
        self.output_dense.build((None,self.widths[-1]))
        trainable_weights += self.output_dense.trainable_weights

        # Pass on all collected trainable weights
        self._trainable_weights = trainable_weights

        super(FiLM, self).build(input_shape)

    def call(self, x):
        assert isinstance(x, list)
        conv_output, FiLM_vars = x

        # Generate FiLM outputs
        tns = FiLM_vars
        for i in range(len(self.widths)):
            tns = self.hidden_dense_layers[i](tns)
            tns = get_activation(activation=self.activation)(tns)
        FiLM_output = self.output_dense(tns)

        # Duplicate in order to apply to entire feature maps
        # Taken from https://github.com/GuessWhatGame/neural_toolbox/blob/master/film_layer.py
        FiLM_output = K.expand_dims(FiLM_output, axis=1)
        FiLM_output = K.expand_dims(FiLM_output, axis=1)
        FiLM_output = K.tile(FiLM_output, [1, self.height, self.width, 1])

        # Split into gammas and betas
        gammas = FiLM_output[:, :, :, :self.n_feature_maps]
        betas = FiLM_output[:, :, :, self.n_feature_maps:]

        # Apply affine transformation
        return (1 + gammas) * conv_output + betas

    def compute_output_shape(self, input_shape):
        assert isinstance(input_shape, list)
        return input_shape[0]
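To make the tiling in call concrete, here is a NumPy sketch of the same steps on random data, assuming channels-last feature maps as the layer does. It also checks that the K.tile step is not strictly needed: broadcasting a (B, 1, 1, 2F) tensor should give the same result without materialising an H x W copy of the parameters.

```python
import numpy as np

B, H, W, F = 2, 4, 4, 3
rng = np.random.default_rng(1)
conv_output = rng.standard_normal((B, H, W, F))
film_output = rng.standard_normal((B, 2 * F))  # stands in for output_dense(...)

# Steps mirroring FiLM.call: (B, 2F) -> (B, 1, 1, 2F) -> (B, H, W, 2F)
tiled = np.expand_dims(film_output, axis=1)
tiled = np.expand_dims(tiled, axis=1)
tiled = np.tile(tiled, (1, H, W, 1))

# The first F channels are gammas, the last F are betas
gammas = tiled[..., :F]
betas = tiled[..., F:]
out_tiled = (1 + gammas) * conv_output + betas

# Same computation without tiling, relying on broadcasting
g = film_output[:, None, None, :F]
b = film_output[:, None, None, F:]
out_broadcast = (1 + g) * conv_output + b

print(np.allclose(out_tiled, out_broadcast))  # True
```

The (1 + gammas) form presumably makes a generator output near zero correspond to leaving the feature maps unscaled.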

It depends on the function get_activation, which returns a Keras activation layer instance (or, if given a tensor, the activated tensor). You can see the full working example below.
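Besides the built-in Keras activations, get_activation defines a few custom ones. A NumPy transcription of three of them, useful for seeing their shapes; note that 'mrelu' works out to a triangular "hat" function, max(1 - |x|, 0):

```python
import numpy as np

def mrelu(x):
    # Transcribed from get_activation: min(max(1 - x, 0), max(1 + x, 0))
    return np.minimum(np.maximum(1 - x, 0), np.maximum(1 + x, 0))

def swish(x):
    # sigmoid(x) * x, written as x / (1 + exp(-x))
    return x / (1 + np.exp(-x))

def gaussian(x):
    return np.exp(-x ** 2)

x = np.array([-2.0, -0.5, 0.0, 0.5, 2.0])
print(mrelu(x))  # values: 0, 0.5, 1, 0.5, 0
print(np.allclose(mrelu(x), np.maximum(1 - np.abs(x), 0)))  # True
print(gaussian(0.0), swish(0.0))  # 1.0 0.0
```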

Note that this layer processes the transform_vars itself. If you want to process these variables in a separate network, see the edit below.
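Inside the layer, the transform_vars pass through the hidden Dense layers (each followed by the chosen activation) and then a final Dense layer with 2 * n_feature_maps units, whose output is split into gammas and betas. A NumPy sketch of that forward pass, using hypothetical random weights and no biases (equivalent to zero biases) in place of trained ones, widths=[12, 24] as in the example below, and LeakyReLU with the Keras default slope of 0.3:

```python
import numpy as np

def leaky_relu(x, alpha=0.3):
    # Same default slope as keras.layers.LeakyReLU
    return np.where(x > 0, x, alpha * x)

def film_generator(film_vars, widths, n_feature_maps, rng):
    """Hypothetical NumPy stand-in for the Dense stack built in FiLM.build."""
    tns = film_vars
    for width in widths:
        kernel = rng.standard_normal((tns.shape[1], width)) * 0.1
        tns = leaky_relu(tns @ kernel)  # hidden Dense (bias omitted) + activation
    kernel = rng.standard_normal((tns.shape[1], 2 * n_feature_maps)) * 0.1
    return tns @ kernel  # output Dense, no activation

rng = np.random.default_rng(2)
transform_vars = rng.standard_normal((8, 10))  # (batch, n_transform_vars)
params = film_generator(transform_vars, widths=[12, 24], n_feature_maps=32, rng=rng)
gammas, betas = params[:, :32], params[:, 32:]
print(params.shape, gammas.shape, betas.shape)  # (8, 64) (8, 32) (8, 32)
```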

import numpy as np
import keras as ks
import keras.backend as K


def get_activation(tns=None, activation='relu'):
    '''
    Adds an activation layer to a graph.

    Args :
        tns :
            *Keras tensor or None*

            Input tensor. If not None, then the graph will be connected through
            it, and a tensor will be returned. If None, the activation layer
            will be returned.
        activation :
            *str, optional (default='relu')*

            The name of an activation function.
            One of 'relu', 'leakyrelu', 'prelu', 'elu', 'mrelu' or 'swish',
            or anything that Keras will recognize as an activation function
            name.

    Returns :
        *Keras tensor or layer instance* (see tns argument)
    '''

    if activation == 'relu':
        act = ks.layers.ReLU()

    elif activation == 'leakyrelu':
        act = ks.layers.LeakyReLU()

    elif activation == 'prelu':
        act = ks.layers.PReLU()

    elif activation == 'elu':
        act = ks.layers.ELU()

    elif activation == 'swish':
        def swish(x):
            return K.sigmoid(x) * x
        act = ks.layers.Activation(swish)

    elif activation == 'mrelu':
        def mrelu(x):
            return K.minimum(K.maximum(1-x, 0), K.maximum(1+x, 0))
        act = ks.layers.Activation(mrelu)

    elif activation == 'gaussian':
        def gaussian(x):
            return K.exp(-x**2)
        act = ks.layers.Activation(gaussian)

    elif activation == 'flipped_gaussian':
        def flipped_gaussian(x):
            return 1 - K.exp(-x**2)
        act = ks.layers.Activation(flipped_gaussian)

    else:
        act = ks.layers.Activation(activation)

    if tns is not None:
        return act(tns)
    else:
        return act


class FiLM(ks.layers.Layer):

    def __init__(self, widths=[64,64], activation='leakyrelu',
                 initialization='glorot_uniform', **kwargs):
        self.widths = widths
        self.activation = activation
        self.initialization = initialization
        super(FiLM, self).__init__(**kwargs)

    def build(self, input_shape):
        assert isinstance(input_shape, list)
        feature_map_shape, FiLM_vars_shape = input_shape
        self.n_feature_maps = feature_map_shape[-1]
        self.height = feature_map_shape[1]
        self.width = feature_map_shape[2]

        # Collect trainable weights
        trainable_weights = []

        # Create weights for hidden layers
        self.hidden_dense_layers = []
        for i,width in enumerate(self.widths):
            dense = ks.layers.Dense(width,
                                    kernel_initializer=self.initialization,
                                    name=f'FiLM_dense_{i}')
            if i==0:
                build_shape = FiLM_vars_shape[:2]
            else:
                build_shape = (None,self.widths[i-1])
            dense.build(build_shape)
            trainable_weights += dense.trainable_weights
            self.hidden_dense_layers.append(dense)

        # Create weights for output layer
        self.output_dense = ks.layers.Dense(2 * self.n_feature_maps, # assumes channel_last
                                            kernel_initializer=self.initialization,
                                            name=f'FiLM_dense_output')
        self.output_dense.build((None,self.widths[-1]))
        trainable_weights += self.output_dense.trainable_weights

        # Pass on all collected trainable weights
        self._trainable_weights = trainable_weights

        super(FiLM, self).build(input_shape)

    def call(self, x):
        assert isinstance(x, list)
        conv_output, FiLM_vars = x

        # Generate FiLM outputs
        tns = FiLM_vars
        for i in range(len(self.widths)):
            tns = self.hidden_dense_layers[i](tns)
            tns = get_activation(activation=self.activation)(tns)
        FiLM_output = self.output_dense(tns)

        # Duplicate in order to apply to entire feature maps
        # Taken from https://github.com/GuessWhatGame/neural_toolbox/blob/master/film_layer.py
        FiLM_output = K.expand_dims(FiLM_output, axis=1)
        FiLM_output = K.expand_dims(FiLM_output, axis=1)
        FiLM_output = K.tile(FiLM_output, [1, self.height, self.width, 1])

        # Split into gammas and betas
        gammas = FiLM_output[:, :, :, :self.n_feature_maps]
        betas = FiLM_output[:, :, :, self.n_feature_maps:]

        # Apply affine transformation
        return (1 + gammas) * conv_output + betas

    def compute_output_shape(self, input_shape):
        assert isinstance(input_shape, list)
        return input_shape[0]


print(ks.__version__)

# Load example data (here MNIST)
from keras.datasets import mnist
(x_img_train, y_train), _ = mnist.load_data()
x_img_train = np.expand_dims(x_img_train,-1)

# Generate some data to use for the transformations
n_transform_vars = 10
x_transform_train = np.random.randn(y_train.shape[0], n_transform_vars)

# Inputs
input_transform = ks.layers.Input(x_transform_train.shape[1:], name='transform_vars')
input_img = ks.layers.Input(x_img_train.shape[1:], name='imgs')

# Number of feature maps
n_features = 32

# Do a convolution
tns = ks.layers.Conv2D(filters=n_features, kernel_size=3, padding='same')(input_img)

# Create a batch norm layer
bn = ks.layers.BatchNormalization()

# Freeze the weights of the batch norm, as they are going to be overwritten
bn.trainable = False

# Apply batch norm
tns = bn(tns)

# Apply FiLM layer
tns = FiLM(widths=[12,24], name='FiLM_layer')([tns, input_transform])

# Make 1D output
tns = ks.layers.Flatten()(tns)
output = ks.layers.Dense(1)(tns)

# Compile and plot
model = ks.models.Model(inputs=[input_img, input_transform], outputs=output)
model.compile(loss='mse', optimizer='Adam')
model.summary()
ks.utils.plot_model(model, './model_with_FiLM.png')

# Train
model.fit([x_img_train, x_transform_train], y_train, batch_size=8)

EDIT: Here is the "non-active" FiLM layer, which takes in the predictions of another network (the FiLM generator) and uses them as gammas and betas.

This approach is equivalent but simpler: all the trainable weights live in the FiLM generator, which makes weight sharing straightforward.

class FiLM(ks.layers.Layer):

    def __init__(self, **kwargs):
        super(FiLM, self).__init__(**kwargs)

    def build(self, input_shape):
        assert isinstance(input_shape, list)
        feature_map_shape, FiLM_tns_shape = input_shape
        self.height = feature_map_shape[1]
        self.width = feature_map_shape[2]
        self.n_feature_maps = feature_map_shape[-1]
        assert(int(2 * self.n_feature_maps)==FiLM_tns_shape[1])
        super(FiLM, self).build(input_shape)

    def call(self, x):
        assert isinstance(x, list)
        conv_output, FiLM_tns = x

        # Duplicate in order to apply to entire feature maps
        # Taken from https://github.com/GuessWhatGame/neural_toolbox/blob/master/film_layer.py
        FiLM_tns = K.expand_dims(FiLM_tns, axis=1)
        FiLM_tns = K.expand_dims(FiLM_tns, axis=1)
        FiLM_tns = K.tile(FiLM_tns, [1, self.height, self.width, 1])

        # Split into gammas and betas
        gammas = FiLM_tns[:, :, :, :self.n_feature_maps]
        betas = FiLM_tns[:, :, :, self.n_feature_maps:]

        # Apply affine transformation
        return (1 + gammas) * conv_output + betas

    def compute_output_shape(self, input_shape):
        assert isinstance(input_shape, list)
        return input_shape[0]
answered Oct 19 '22 00:10 by Bobson Dugnutt