Extract intermediate variable from a custom Tensorflow/Keras layer during inference (TF 2.0)

A bit of background:

I've implemented an NLP classification model using mostly the Keras functional API in TensorFlow 2.0. The model architecture is a pretty straightforward LSTM network with the addition of an Attention layer between the LSTM and the Dense output layer. The Attention layer comes from this Kaggle kernel (starting around line 51).

I wrapped the trained model in a simple Flask app and get reasonably accurate predictions. In addition to predicting a class for a specific input, I also output the value of the attention weight vector "a" from the aforementioned Attention layer so I can visualize the weights applied to the input sequence.

My current method of extracting the attention weights works, but it seems incredibly inefficient: I predict the output class and then manually recalculate the attention vector using an intermediate Keras model. In the Flask app, inference looks something like this:

# Load the trained model
model = tf.keras.models.load_model('saved_model.h5')

# Extract the trained weights and biases of the attention layer
attention_weights = model.get_layer('attention').get_weights()

# Create an intermediate model that outputs the activations of the LSTM layer
intermediate_model = tf.keras.Model(inputs=model.input, outputs=model.get_layer('bi-lstm').output)

# Predict the output class using the trained model
model_score = model.predict(input)

# Obtain LSTM activations by predicting the output again using the intermediate model
lstm_activations = intermediate_model.predict(input)

# Use the intermediate LSTM activations and the trained attention layer weights and biases to calculate the attention vector.
# Maths from the custom Attention layer (heavily condensed for the sake of brevity)
eij = tf.keras.backend.dot(lstm_activations, attention_weights)
a = tf.keras.backend.exp(eij)
attention_vector = a

I think I should be able to include the attention vector as part of the model output, but I'm struggling to figure out how to accomplish this. Ideally, I'd extract the attention vector from the custom attention layer in a single forward pass rather than extracting the various intermediate values and calculating the vector a second time.

For example:

model_score = model.predict(input)

model_score[0] # The predicted class label or probability
model_score[1] # The attention vector, a

I think I'm missing some basic knowledge around how TensorFlow/Keras pass variables around, and when/how I can access those values to include them as model output. Any advice would be appreciated.

Asked Oct 16 '22 by Marcin


1 Answer

After a little more research I've managed to cobble together a working solution. I'll summarize here for any future weary internet travelers that come across this post.

The first clues came from this GitHub thread. The attention layer defined there seems to build on the attention layer in the previously mentioned Kaggle kernel. The GitHub user adds a return_attention flag to the layer's __init__ which, when enabled, includes the attention vector alongside the weighted RNN output vector in the layer's output.

I also added the get_config method suggested by another user in the same GitHub thread, which makes it possible to save and reload trained models. I had to include the return_attention flag in get_config; otherwise TF would throw a list-iteration error when trying to load a saved model with return_attention=True.
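
As a side note, reloading the saved model then requires pointing Keras at the custom layer class. A minimal sketch, assuming the modified Attention class shown at the bottom of this answer lives in an importable module (the module name here is just a placeholder):

import tensorflow as tf
from attention_layer import Attention  # placeholder module containing the custom layer

# custom_objects maps the layer name stored in the saved config back to the class,
# so Keras can rebuild the layer via get_config() when loading
model = tf.keras.models.load_model(
    'saved_model.h5',
    custom_objects={'Attention': Attention}
)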

With those changes made, the model definition needed to be updated to capture the additional layer output.

inputs = Input(shape=(max_sequence_length,))
lstm = Bidirectional(LSTM(lstm1_units, return_sequences=True))(inputs)
# Added 'attention_vector' to capture the second layer output
attention, attention_vector = Attention(max_sequence_length, return_attention=True)(lstm)
x = Dense(dense_units, activation="softmax")(attention)

The final, and most important, piece of the puzzle came from this Stack Overflow answer. The method described there lets us output multiple results while only optimizing on one of them. The code changes are subtle but very important; I've added comments below at the spots where I made changes to implement this functionality.

model = Model(
    inputs=inputs,
    outputs=[x, attention_vector] # Original value:  outputs=x
    )

model.compile(
    loss=['categorical_crossentropy', None], # Original value: loss='categorical_crossentropy'
    optimizer=optimizer,
    metrics=[BinaryAccuracy(name='accuracy')])
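
The None in the loss list is what makes this work: Keras skips loss computation for the second output, so training is still driven entirely by the classification output, while the attention vector is simply carried along as an extra model output at inference time.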

With those changes in place, I retrained the model and voila! The output of model.predict() is now a list containing the score and its associated attention vector.
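
In the Flask app, inference now collapses to a single forward pass, roughly along these lines (a minimal sketch; variable names are illustrative):

# One call returns both outputs defined on the model
model_score, attention_vector = model.predict(input)

model_score       # Predicted class probabilities from the Dense softmax layer
attention_vector  # Attention weights over the input sequence, for visualization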

The results of the change were pretty dramatic. Running inference on 10k examples took about 20 minutes using this new method. The old method utilizing intermediate models took ~33 minutes to perform inference on the same dataset.

And for anyone that's interested, here is my modified Attention layer:

from tensorflow.keras.layers import Layer
from tensorflow.keras import initializers, regularizers, constraints
from tensorflow.keras import backend as K


class Attention(Layer):
    def __init__(self, step_dim,
                W_regularizer=None, b_regularizer=None,
                W_constraint=None, b_constraint=None,
                bias=True, return_attention=True, **kwargs):
        self.supports_masking = True
        self.init = initializers.get('glorot_uniform')

        self.W_regularizer = regularizers.get(W_regularizer)
        self.b_regularizer = regularizers.get(b_regularizer)

        self.W_constraint = constraints.get(W_constraint)
        self.b_constraint = constraints.get(b_constraint)

        self.bias = bias

        self.step_dim = step_dim
        self.features_dim = 0
        self.return_attention = return_attention
        super(Attention, self).__init__(**kwargs)

    def build(self, input_shape):
        assert len(input_shape) == 3

        self.W = self.add_weight(shape=(input_shape[-1],),
                                 initializer=self.init,
                                 name='{}_W'.format(self.name),
                                 regularizer=self.W_regularizer,
                                 constraint=self.W_constraint)
        self.features_dim = input_shape[-1]

        if self.bias:
            self.b = self.add_weight(shape=(input_shape[1],),
                                     initializer='zero',
                                     name='{}_b'.format(self.name),
                                     regularizer=self.b_regularizer,
                                     constraint=self.b_constraint)
        else:
            self.b = None

        self.built = True

    def compute_mask(self, input, input_mask=None):
        return None

    def call(self, x, mask=None):
        features_dim = self.features_dim
        step_dim = self.step_dim

        # Unnormalized alignment score for each timestep: dot product of each
        # timestep's features with the learned weight vector W
        eij = K.reshape(K.dot(K.reshape(x, (-1, features_dim)),
                              K.reshape(self.W, (features_dim, 1))), (-1, step_dim))

        if self.bias:
            eij += self.b

        eij = K.tanh(eij)

        # Softmax over the timestep axis, with optional masking of padded steps
        a = K.exp(eij)

        if mask is not None:
            a *= K.cast(mask, K.floatx())

        a /= K.cast(K.sum(a, axis=1, keepdims=True) + K.epsilon(), K.floatx())

        # Weight each timestep's features by its attention score and sum over time
        a = K.expand_dims(a)
        weighted_input = x * a
        result = K.sum(weighted_input, axis=1)

        # Optionally return the attention weights alongside the weighted sum
        if self.return_attention:
            return [result, a]
        return result

    def compute_output_shape(self, input_shape):
        if self.return_attention:
            return [(input_shape[0], self.features_dim),
                    (input_shape[0], input_shape[1])]
        else:
            return input_shape[0], self.features_dim

    def get_config(self):
        config = {
            'step_dim': self.step_dim,
            'W_regularizer': regularizers.serialize(self.W_regularizer),
            'b_regularizer': regularizers.serialize(self.b_regularizer),
            'W_constraint': constraints.serialize(self.W_constraint),
            'b_constraint': constraints.serialize(self.b_constraint),
            'bias': self.bias,
            'return_attention': self.return_attention
        }

        base_config = super(Attention, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))

Answered Oct 20 '22 by Marcin