 

Wrapping Tensorflow For Use in Keras

I'm using Keras for the rest of my project, but I'm also hoping to make use of the Bahdanau attention module that TensorFlow has implemented (see tf.contrib.seq2seq.BahdanauAttention). I've been attempting to wrap it as a custom Keras Layer, but I'm not sure whether this is an appropriate fit.

Is there some convention for wrapping Tensorflow components in this way to be compatible with the computation graph?

I've included the code that I've written thus far (it isn't working yet) and would appreciate any pointers.

from keras import backend as K
from keras.engine.topology import Layer
from keras.models import Model
import numpy as np
import tensorflow as tf

class BahdanauAttention(Layer):

    # The Bahdanau attention layer has to attend to a particular set of memory states.
    # These are usually the output of some encoder process, e.g. the sequence of
    # GRU states produced by the encoder.
    def __init__(self, memory, num_units, **kwargs):
        self.memory = memory
        self.num_units = num_units
        super(BahdanauAttention, self).__init__(**kwargs)

    def build(self, input_shape):
        # The attention mechanism is in control of attending to the given memory;
        # build the attention-wrapped cell once here, not on every call
        attention = tf.contrib.seq2seq.BahdanauAttention(self.num_units, self.memory)
        cell = tf.contrib.rnn.GRUCell(self.num_units)

        self.cell_with_attention = tf.contrib.seq2seq.DynamicAttentionWrapper(
            cell, attention, self.num_units)

        super(BahdanauAttention, self).build(input_shape)

    def call(self, x):
        # Run the attention-wrapped cell over the actual layer input
        outputs, _ = tf.nn.dynamic_rnn(self.cell_with_attention, x, dtype=tf.float32)
        return outputs

    def compute_output_shape(self, input_shape):
        # (batch_size, time_steps, num_units)
        return (input_shape[0], input_shape[1], self.num_units)
PF1 asked Jun 01 '17 06:06




1 Answer

Newer versions of Keras ship tf.keras.layers.AdditiveAttention(), which implements Bahdanau-style (additive) attention. This should work off the shelf.
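For example, assuming TensorFlow 2.x, a minimal use of the built-in layer looks like this (the shapes are illustrative):

```python
import tensorflow as tf

# Illustrative shapes: (batch, query_steps, dim) and (batch, value_steps, dim)
query = tf.random.normal((2, 5, 16))
value = tf.random.normal((2, 8, 16))

# AdditiveAttention implements Bahdanau-style (additive) scoring
attention = tf.keras.layers.AdditiveAttention()

# One context vector per query step: shape (2, 5, 16)
context = attention([query, value])
```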

Alternatively, a custom Bahdanau layer can be written in half a dozen lines of code, as shown here: Custom Attention Layer using in Keras
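For intuition, the additive score such a layer computes, score_t = v . tanh(W1 h_t + W2 s), can be sketched in plain NumPy; the names W1, W2, v and all shapes below are illustrative, not taken from either library:

```python
import numpy as np

def bahdanau_score(query, keys, W1, W2, v):
    # query: (d,) decoder state; keys: (T, d) encoder memory
    # score_t = v . tanh(W1 @ keys[t] + W2 @ query)
    return np.tanh(keys @ W1.T + query @ W2.T) @ v

rng = np.random.default_rng(0)
d, T, u = 4, 3, 5                      # feature dim, memory length, num_units
W1 = rng.standard_normal((u, d))
W2 = rng.standard_normal((u, d))
v = rng.standard_normal(u)
query = rng.standard_normal(d)
keys = rng.standard_normal((T, d))

scores = bahdanau_score(query, keys, W1, W2, v)   # one score per memory slot
weights = np.exp(scores) / np.exp(scores).sum()   # softmax over memory
context = weights @ keys                          # attention-weighted memory
```

The softmax weights sum to 1, and the context vector is a convex combination of the memory rows, which is exactly what the wrapped TensorFlow mechanism produces per decoder step.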

Allohvk answered Nov 15 '22 07:11