How to visualize attention weights from AttentionWrapper

I want to visualize attention scores in the latest version of TensorFlow (1.2). I use AttentionWrapper from contrib.seq2seq to build an RNNCell, wrap it with BasicDecoder as the decoder, and then call dynamic_decode() to generate outputs step by step.

How can I access the attention weights of all steps? Thanks!

asked Jun 18 '17 by 宋卫平


People also ask

How do you visualize a BERT model?

BertViz is an interactive tool for visualizing attention in Transformer language models such as BERT, GPT2, or T5. It can be run inside a Jupyter or Colab notebook through a simple Python API that supports most Huggingface models.
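For reference, here is a minimal sketch of that API, assuming the bertviz and transformers packages are installed (the bert-base-uncased checkpoint and the input sentence are just illustrative choices):

from transformers import AutoModel, AutoTokenizer
from bertviz import head_view

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
model = AutoModel.from_pretrained("bert-base-uncased", output_attentions=True)

inputs = tokenizer("The cat sat on the mat", return_tensors="pt")
attentions = model(**inputs).attentions  # one attention tensor per layer
tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])

head_view(attentions, tokens)  # renders an interactive view in a notebook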

How do you get your attention score?

Take the query vector for a word and calculate its dot product with the key vector of each word in the sequence, including itself; this gives the raw attention score. Then divide each of the results by the square root of the dimension of the key vector, and apply a softmax to turn the scores into attention weights.
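A short NumPy sketch of that computation, with the softmax included for completeness (the names are illustrative):

import numpy as np

def attention_weights(Q, K):
    """Q, K: [seq_len, d_k] arrays of query and key vectors."""
    d_k = K.shape[-1]
    scores = Q @ K.T / np.sqrt(d_k)  # scaled dot-product scores
    # Row-wise softmax turns the scores into attention weights
    e = np.exp(scores - scores.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)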


1 Answer

You can access the attention weights by setting the alignment_history=True flag in the AttentionWrapper definition.

Here is an example:

# Define the attention mechanism; the memory is the encoder outputs
attn_mech = tf.contrib.seq2seq.LuongMonotonicAttention(
    num_units=attention_unit_size, memory=encoder_outputs,
    memory_sequence_length=input_lengths)

# Wrap the decoder cell with attention, keeping the alignment history
attn_cell = tf.contrib.seq2seq.AttentionWrapper(
    cell=decoder_cell, attention_mechanism=attn_mech,
    alignment_history=True)

# Define the training helper, fed with the (embedded) decoder inputs
train_helper = tf.contrib.seq2seq.TrainingHelper(
    inputs=decoder_inputs,
    sequence_length=output_lengths)

# Define the decoder
decoder = tf.contrib.seq2seq.BasicDecoder(
    cell=attn_cell,
    helper=train_helper, initial_state=decoder_initial_state)

# Dynamic decoding
dec_outputs, dec_states, _ = tf.contrib.seq2seq.dynamic_decode(decoder)
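For context, here is a minimal sketch of the surrounding graph this snippet assumes; encoder_embedded_inputs, batch_size, decoder_inputs, and output_lengths are placeholder names, not part of any fixed API:

import tensorflow as tf

# Hypothetical encoder producing the attention memory
encoder_cell = tf.contrib.rnn.LSTMCell(attention_unit_size)
encoder_outputs, encoder_final_state = tf.nn.dynamic_rnn(
    encoder_cell, encoder_embedded_inputs,
    sequence_length=input_lengths, dtype=tf.float32)

decoder_cell = tf.contrib.rnn.LSTMCell(attention_unit_size)

# AttentionWrapper has its own state type, so the initial state must be
# built from the wrapper and seeded with the encoder's final state
decoder_initial_state = attn_cell.zero_state(batch_size, tf.float32).clone(
    cell_state=encoder_final_state)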

And then, inside the session, you can access the weights as below. Note that alignment_history is a TensorArray, so it has to be stacked into a regular tensor first:

with tf.Session() as sess:
    ...
    # Shape: [decoder_steps, batch_size, max_input_length]
    alignments = sess.run(dec_states.alignment_history.stack(), feed_dict)

Finally, you can visualize the attention weights (alignments) like this:

import matplotlib.pyplot as plt

def plot_attention(attention_map, input_tags=None, output_tags=None):
    # Rows are decoder (output) steps, columns are encoder (input) steps
    output_len, input_len = attention_map.shape

    # Plot the attention map
    f = plt.figure(figsize=(15, 10))
    ax = f.add_subplot(1, 1, 1)

    # Draw the map as an image
    i = ax.imshow(attention_map, interpolation='nearest', cmap='Blues')

    # Add a colorbar
    cbaxes = f.add_axes([0.2, 0, 0.6, 0.03])
    cbar = f.colorbar(i, cax=cbaxes, orientation='horizontal')
    cbar.ax.set_xlabel('Alpha value (probability output of the "softmax")',
                       labelpad=2)

    # Add labels
    ax.set_yticks(range(output_len))
    if output_tags is not None:
        ax.set_yticklabels(output_tags[:output_len])

    ax.set_xticks(range(input_len))
    if input_tags is not None:
        ax.set_xticklabels(input_tags[:input_len], rotation=45)

    ax.set_xlabel('Input Sequence')
    ax.set_ylabel('Output Sequence')

    # Add a grid
    ax.grid()

    plt.show()

# input_tags  - word labels of the input sequence; pass None to skip
# output_tags - word labels of the output sequence; pass None to skip
# i           - index of the example in the batch

plot_attention(alignments[:, i, :], input_tags, output_tags)

[Example attention heatmap produced by plot_attention]

answered Oct 17 '22 by aboev