
How to visualize an attention LSTM using the keras-self-attention package?

I'm using keras-self-attention to implement an attention LSTM in Keras. How can I visualize the attention part after training the model? This is a time series forecasting case.

from keras.models import Sequential
from keras_self_attention import SeqSelfAttention
from keras.layers import LSTM, Dense, Flatten

model = Sequential()
model.add(LSTM(activation = 'tanh' ,units = 200, return_sequences = True, 
               input_shape = (TrainD[0].shape[1], TrainD[0].shape[2])))
model.add(SeqSelfAttention())
model.add(Flatten())    
model.add(Dense(1, activation = 'relu'))

model.compile(optimizer = 'adam', loss = 'mse')
Eghbal asked Oct 12 '19 17:10


1 Answer

One approach is to fetch the outputs of SeqSelfAttention for a given input and organize them so as to display predictions per channel (see below). For something more advanced, have a look at the iNNvestigate library (usage examples included).

Update: I can also recommend See RNN, a package I wrote.


Explanation: show_features_1D fetches the outputs of the layer matching layer_name (which can be a substring of the layer's name) and shows predictions per channel (labeled), with timesteps along the x-axis and output values along the y-axis.
  • input_data = single batch of data of shape (1, input_shape)
  • prefetched_outputs = already-acquired layer outputs; overrides input_data
  • max_timesteps = max # of timesteps to show
  • max_col_subplots = max # of subplots along horizontal
  • equate_axes = force all x- and y- axes to be equal (recommended for fair comparison)
  • show_y_zero = whether to show y=0 as a red line
  • channel_axis = layer features dimension (e.g. units for LSTM, which is last)
  • scale_width, scale_height = scale displayed image width & height
  • dpi = image quality (dots per inch)
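
As a rough illustration of how the subplot grid comes out, the sketch below mirrors the divisor loop used in show_features_1D (grid_shape is a hypothetical helper name, not part of any library). With the loop as posted, the largest divisor of the feature count that does not exceed max_col_subplots becomes the number of rows, and the remaining factor becomes the number of columns.

```python
def grid_shape(n_features, max_col_subplots=10):
    """Mirror the divisor loop in show_features_1D; return (n_rows, n_cols)."""
    n_cols = n_features  # k = 1 always divides n_features
    for k in range(1, max_col_subplots + 1):
        if n_features % k == 0:
            n_cols = n_features // k  # keeps the value for the largest divisor found
    n_rows = n_features // n_cols
    return n_rows, n_cols


print(grid_shape(60))   # (10, 6)  -> 60 LSTM units
print(grid_shape(240))  # (10, 24) -> 240 channels
print(grid_shape(7))    # (7, 1)   -> a prime count gives a single column
```

The grid always has exactly n_features cells, so a feature count with no divisor in the allowed range other than 1 produces a single very wide row.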

Visuals (below) explanation:

  • First is useful to see the shapes of extracted features, regardless of magnitude - giving information about e.g. frequency contents
  • Second is useful to see feature relationships - e.g. relative magnitudes, biases, and frequencies. The result below stands in stark contrast with the image above it: running print(outs_1) reveals that all magnitudes are very small and don't vary much, so including the y=0 point and equating axes yields a line-like visual, which can be interpreted as self-attention being bias-oriented.
  • Third is useful for visualizing features too numerous to be shown as above. Defining the model with batch_shape instead of input_shape removes all ? in printed shapes, and we can see that the first output's shape is (10, 240, 60) and the second's is (10, 240, 240). In other words, the first output returns LSTM channel attention, and the second a "timesteps attention". The heatmap result below can be interpreted as showing attention "cooling down" w.r.t. timesteps.
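
To make the effect of equate_axes concrete, here is a minimal sketch of picking one symmetric y-range for all subplots. It works on the raw feature values rather than on Matplotlib's get_ylim() as the function further down does, and symmetric_ylim is a hypothetical name used only for illustration.

```python
def symmetric_ylim(series_per_subplot):
    """Return (-m, m), where m is the largest absolute value across all series."""
    m = max(abs(v) for series in series_per_subplot for v in series)
    return -m, m


# Three made-up channels with small, nearly constant values
features = [
    [0.01, 0.02, 0.015],
    [-0.03, -0.01, 0.0],
    [0.005, 0.004, 0.006],
]
print(symmetric_ylim(features))  # (-0.03, 0.03)
```

Because every subplot then spans the same range around y=0, channels whose values barely move show up as nearly flat lines, which is the line-like look described above.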

SeqWeightedAttention is a lot easier to visualize, but there isn't much to visualize; you'll need to get rid of Flatten above to make it work. The attention's output shapes then become (10, 60) and (10, 240) - for which you can use a simple histogram, plt.hist (just make sure you exclude the batch dimension - i.e. feed (60,) or (240,)).
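
A minimal sketch of the batch-dimension point, assuming NumPy only and random stand-in weights rather than real layer outputs: index the first sample to get a 1D vector, then bin it. np.histogram is used here so the example runs without a display; plt.hist accepts the same 1D input.

```python
import numpy as np

rng = np.random.default_rng(0)

# Stand-in for attention weights of shape (batch, timesteps) = (10, 240);
# real values would come from the layer outputs, not from random numbers.
logits = rng.normal(size=(10, 240))
weights = np.exp(logits) / np.exp(logits).sum(axis=1, keepdims=True)

sample = weights[0]                            # drop the batch dimension -> (240,)
counts, edges = np.histogram(sample, bins=20)  # 20 bins -> 21 bin edges

print(sample.shape)   # (240,)
print(counts.sum())   # 240, every weight lands in exactly one bin
```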


from keras.layers import Input, Dense, LSTM, Flatten, concatenate
from keras.models import Model
from keras.optimizers import Adam
from keras_self_attention import SeqSelfAttention
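import numpy as np
import matplotlib.pyplot as plt   # used by the plotting helpers below
from keras import backend as K    # used by get_layer_outputs below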

ipt   = Input(shape=(240,4))
x     = LSTM(60, activation='tanh', return_sequences=True)(ipt)
x     = SeqSelfAttention(return_attention=True)(x)
x     = concatenate(x)
x     = Flatten()(x)
out   = Dense(1, activation='sigmoid')(x)
model = Model(ipt,out)
model.compile(Adam(lr=1e-2), loss='binary_crossentropy')

X = np.random.rand(10,240,4) # dummy data
Y = np.random.randint(0,2,(10,1)) # dummy labels
model.train_on_batch(X, Y)

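# NOTE: get_layer_outputs, show_features_1D and show_features_2D are defined
# further down; define them before running the lines below
outs = get_layer_outputs(model, 'seq', X[0:1], 1)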
outs_1 = outs[0]
outs_2 = outs[1]

show_features_1D(model,'lstm',X[0:1],max_timesteps=100,equate_axes=False,show_y_zero=False)
show_features_1D(model,'lstm',X[0:1],max_timesteps=100,equate_axes=True, show_y_zero=True)
show_features_2D(outs_2[0])  # [0] for 2D since 'outs_2' is 3D


def show_features_1D(model=None, layer_name=None, input_data=None,
                     prefetched_outputs=None, max_timesteps=100,
                     max_col_subplots=10, equate_axes=False,
                     show_y_zero=True, channel_axis=-1,
                     scale_width=1, scale_height=1, dpi=76):
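    if prefetched_outputs is None:
        layer_outputs = get_layer_outputs(model, layer_name, input_data, 1)[0]
    else:
        layer_outputs = prefetched_outputs
    if layer_outputs.ndim == 3:
        # drop the batch dimension -> (timesteps, channels), so that
        # layer_outputs[:, i] below is one channel over time
        layer_outputs = layer_outputs[0]
    n_features    = layer_outputs.shape[channel_axis]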

    for _int in range(1, max_col_subplots+1):
      if (n_features/_int).is_integer():
        n_cols = int(n_features/_int)
    n_rows = int(n_features/n_cols)

    fig, axes = plt.subplots(n_rows,n_cols,sharey=equate_axes,dpi=dpi)
    fig.set_size_inches(24*scale_width,16*scale_height)

    subplot_idx = 0
    for row_idx in range(axes.shape[0]):
      for col_idx in range(axes.shape[1]): 
        subplot_idx += 1
        feature_output = layer_outputs[:,subplot_idx-1]
        feature_output = feature_output[:max_timesteps]
        ax = axes[row_idx,col_idx]

        if show_y_zero:
            ax.axhline(0,color='red')
        ax.plot(feature_output)

        ax.axis(xmin=0,xmax=len(feature_output))
        ax.axis('off')

        ax.annotate(str(subplot_idx),xy=(0,.99),xycoords='axes fraction',
                    weight='bold',fontsize=14,color='g')
    if equate_axes:
        y_new = []
        for row_axis in axes:
            y_new += [np.max(np.abs([col_axis.get_ylim() for 
                                     col_axis in row_axis]))]
        y_new = np.max(y_new)
        for row_axis in axes:
            [col_axis.set_ylim(-y_new,y_new) for col_axis in row_axis]
    plt.show()

def show_features_2D(data, cmap='bwr', norm=None,
                     scale_width=1, scale_height=1):
    if norm is not None:
        vmin, vmax = norm
    else:
        vmin, vmax = None, None  # scale automatically per min-max of 'data'

    plt.imshow(data, cmap=cmap, vmin=vmin, vmax=vmax)
    plt.xlabel('Timesteps', weight='bold', fontsize=14)
    plt.ylabel('Attention features', weight='bold', fontsize=14)
    plt.colorbar(fraction=0.046, pad=0.04)  # works for any size plot

    plt.gcf().set_size_inches(8*scale_width, 8*scale_height)
    plt.show()

def get_layer_outputs(model, layer_name, input_data, learning_phase=1):
    outputs   = [layer.output for layer in model.layers if layer_name in layer.name]
    layers_fn = K.function([model.input, K.learning_phase()], outputs)
    return layers_fn([input_data, learning_phase])
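
Because get_layer_outputs selects layers with a plain substring test (layer_name in layer.name), a short or generic string can match more than one layer. The sketch below shows this with hypothetical layer names; the names Keras actually assigns depend on the version and on how many layers were created in the session, so check model.summary() for the real ones.

```python
def match_layers(layer_names, layer_name):
    """Return the names that contain layer_name as a substring."""
    return [name for name in layer_names if layer_name in name]


# Hypothetical auto-generated names for the model above
names = ['input_1', 'lstm_1', 'seq_self_attention_1',
         'concatenate_1', 'flatten_1', 'dense_1']

print(match_layers(names, 'seq'))   # ['seq_self_attention_1']
print(match_layers(names, 'lstm'))  # ['lstm_1']
print(match_layers(names, 'at'))    # ['seq_self_attention_1', 'concatenate_1', 'flatten_1']
```

A substring that matches several layers returns several outputs, so indexing the result with [0] may not pick the layer you had in mind.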

SeqWeightedAttention example per request:

from keras_self_attention import SeqWeightedAttention  # other imports as above

ipt   = Input(batch_shape=(10,240,4))
x     = LSTM(60, activation='tanh', return_sequences=True)(ipt)
x     = SeqWeightedAttention(return_attention=True)(x)
x     = concatenate(x)
out   = Dense(1, activation='sigmoid')(x)
model = Model(ipt,out)
model.compile(Adam(lr=1e-2), loss='binary_crossentropy')

X = np.random.rand(10,240,4) # dummy data
Y = np.random.randint(0,2,(10,1)) # dummy labels
model.train_on_batch(X, Y)

outs = get_layer_outputs(model, 'seq', X, 1)
outs_1 = outs[0][0] # additional index since using batch_shape
outs_2 = outs[1][0]

plt.hist(outs_1, bins=500); plt.show()
plt.hist(outs_2, bins=500); plt.show()
OverLordGoldDragon answered Oct 12 '22 23:10