I want to visualize attention scores in the latest version of TensorFlow (1.2). I use AttentionWrapper from contrib.seq2seq to build an RNNCell, with BasicDecoder as the decoder, and then call dynamic_decode() to generate outputs step by step.
How can I access the attention weights of all steps? Thanks!
You can access the attention weights by setting the alignment_history=True flag in the AttentionWrapper definition.
Here is an example:
# Define the attention mechanism over the encoder outputs (the memory)
attn_mech = tf.contrib.seq2seq.LuongMonotonicAttention(
    num_units=attention_unit_size,
    memory=encoder_outputs,
    memory_sequence_length=input_lengths)

# Wrap the decoder cell with attention and keep the alignment history
attn_cell = tf.contrib.seq2seq.AttentionWrapper(
    cell=decoder_cell,
    attention_mechanism=attn_mech,
    alignment_history=True)

# Define the training helper (feeds the ground-truth decoder inputs)
train_helper = tf.contrib.seq2seq.TrainingHelper(
    inputs=decoder_inputs,
    sequence_length=output_lengths)

# Define the decoder
decoder = tf.contrib.seq2seq.BasicDecoder(
    cell=attn_cell,
    helper=train_helper,
    initial_state=decoder_initial_state)

# Dynamic decoding
dec_outputs, dec_states, _ = tf.contrib.seq2seq.dynamic_decode(decoder)
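The decoder_initial_state passed to BasicDecoder has to be an AttentionWrapperState. A minimal sketch, assuming you already have batch_size and the encoder's final state encoder_final_state (and that your TF version's AttentionWrapperState supports clone()):

# Sketch (assumption): wrap the encoder's final state into the attention
# cell's zero state so the structure matches what the AttentionWrapper expects.
decoder_initial_state = attn_cell.zero_state(
    batch_size=batch_size, dtype=tf.float32).clone(
        cell_state=encoder_final_state)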
And then inside the session, you can access the weights as below:
with tf.Session() as sess:
    ...
    # Stack the TensorArray of per-step alignments into a single tensor
    alignments = sess.run(dec_states.alignment_history.stack(), feed_dict)
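The stacked alignment_history should have shape [decoder_steps, batch_size, encoder_steps], so alignments[:, i, :] is the attention map for the i-th example in the batch.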
Finally, you can visualize the attention weights (alignments) like this:
import matplotlib.pyplot as plt

def plot_attention(attention_map, input_tags=None, output_tags=None):
    # attention_map has shape [output_length, input_length]
    out_len, in_len = attention_map.shape

    # Plot the attention_map
    plt.clf()
    f = plt.figure(figsize=(15, 10))
    ax = f.add_subplot(1, 1, 1)

    # Add image
    i = ax.imshow(attention_map, interpolation='nearest', cmap='Blues')

    # Add colorbar
    cbaxes = f.add_axes([0.2, 0, 0.6, 0.03])
    cbar = f.colorbar(i, cax=cbaxes, orientation='horizontal')
    cbar.ax.set_xlabel('Alpha value (probability output of the "softmax")',
                       labelpad=2)

    # Add labels
    ax.set_yticks(range(out_len))
    if output_tags is not None:
        ax.set_yticklabels(output_tags[:out_len])
    ax.set_xticks(range(in_len))
    if input_tags is not None:
        ax.set_xticklabels(input_tags[:in_len], rotation=45)
    ax.set_xlabel('Input Sequence')
    ax.set_ylabel('Output Sequence')

    # Add grid and legend
    ax.grid()
    plt.show()
# input_tags  - word representations of the input sequence; use None to skip
# output_tags - word representations of the output sequence; use None to skip
# i           - index of the example in the batch
plot_attention(alignments[:, i, :], input_tags, output_tags)
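If you want word labels on the axes, a minimal sketch (assuming a hypothetical rev_vocab dict that maps token ids back to words, and input_ids / output_ids holding the id sequences of the batch):

# Hypothetical helper (assumption): rev_vocab maps token ids to words,
# input_ids[i] / output_ids[i] are the id sequences of the i-th example.
input_tags = [rev_vocab[idx] for idx in input_ids[i]]
output_tags = [rev_vocab[idx] for idx in output_ids[i]]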