Tensorflow Changing Batch Size for RNN During Text Generation

I built a vanilla character-level RNN and trained it on some data. Everything worked fine up to that point.

But now I want to use the model to generate text. The problem is that during this text-generation phase, batch_size is 1 and num_steps per batch is also different from what I used in training.

This is leading to several errors and I tried some hacky fixes but they aren't working. What's the usual way to deal with this?

Edit: More specifically my input placeholders have a shape of [None, num_steps], but the problem is with the initial state which doesn't accept a shape of [None, hidden_size].

Silver asked Nov 05 '16 12:11


1 Answer

I have dealt with this same problem. There are two issues to handle. The first is getting the batch size and step count down to 1. You can do this by setting both the batch and length dimensions of the input placeholder to None, i.e. [None, None, 128]. The 128 is the number of ASCII characters (you could probably use fewer, since you likely only need a subset of them).
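As a rough illustration of what a [None, None, 128] input holds, the sketch below one-hot encodes an ASCII string into a batch containing one sequence. The helper name one_hot_ascii is hypothetical (it is not from my project), and plain Python lists stand in for the array you would actually feed to the placeholder.

```python
VOCAB_SIZE = 128  # one slot per ASCII code point, matching the last input dimension


def one_hot_ascii(text):
    """Return a nested list of shape [1, len(text), VOCAB_SIZE]."""
    sequence = []
    for ch in text:
        code = ord(ch)
        if code >= VOCAB_SIZE:
            raise ValueError("non-ASCII character: %r" % ch)
        row = [0.0] * VOCAB_SIZE
        row[code] = 1.0  # mark the position of this character
        sequence.append(row)
    return [sequence]  # leading batch dimension of size 1


batch = one_hot_ascii("hi")
print(len(batch), len(batch[0]), len(batch[0][0]))  # prints: 1 2 128
```

Because the first two dimensions are None, the same placeholder accepts this [1, 2, 128] batch at generation time and a larger [batch_size, num_steps, 128] batch during training.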

Dealing with the initial state is the trickiest part, because you need to preserve it between calls to session.run(). With num_steps set to one, the state would otherwise be reinitialized to zero at the beginning of every step. What I recommend is allowing the initial state to be passed in as a placeholder and returned from session.run(). This way the user of the model can carry the current state over between batches. The easiest way to do this is to make sure state_is_tuple is set to False for every RNN cell you use, so that you simply get a single final state tensor back from the dynamic RNN function.
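To see why the state has to be carried over, here is a toy sketch that does not use TensorFlow at all. The scalar "RNN" and its weights (w_x, w_h) are made up for illustration; the point is only that stepping one input at a time matches a full-sequence run when the returned state is fed back in, and diverges when the state resets to zero on every call.

```python
import math


def rnn_step(x, h, w_x=0.5, w_h=0.9):
    """One step of a toy scalar RNN: h' = tanh(w_x * x + w_h * h)."""
    return math.tanh(w_x * x + w_h * h)


def run(inputs, state=0.0):
    """Process a chunk of inputs and return (outputs, final_state)."""
    outputs = []
    for x in inputs:
        state = rnn_step(x, state)
        outputs.append(state)
    return outputs, state


inputs = [1.0, -0.5, 0.25, 2.0]

# Whole sequence in a single call.
full_outputs, _ = run(inputs)

# One step per call, feeding the returned state back in each time.
carried, state = [], 0.0
for x in inputs:
    out, state = run([x], state)
    carried.extend(out)

# One step per call, with the state silently reset to zero every time.
reset = [run([x])[0][0] for x in inputs]

print(carried == full_outputs)  # True
print(reset == full_outputs)    # False
```

In the TensorFlow version, run corresponds to one session.run() call, the state argument to the fed placeholder, and the second return value to the fetched final state.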

I personally don't like setting state_is_tuple to False since it is deprecated, so I wrote my own code to flatten the state tuple. The following code is from my project to generate sound.

        # Batch size is read at run time, so the same graph serves training and generation.
        batch_size = tf.shape(self.input_sound)[0]
        rnn = tf.nn.rnn_cell.MultiRNNCell([tf.nn.rnn_cell.LSTMCell(self.hidden_size) for _ in range(self.n_hidden)])
        # Flatten the zero state into one tensor; it is the default when no state is fed.
        zero_state = pack_state_tupel(rnn.zero_state(batch_size, tf.float32))
        self.input_state = tf.placeholder_with_default(zero_state, None)
        state = unpack_state_tupel(self.input_state, rnn.state_size)

        # When training, drop the last step from the input; when generating, use the whole sequence.
        rnn_input_seq = tf.cond(self.is_training, lambda: self.input_sound[:, :-1], lambda: self.input_sound)
        output, final_state = tf.nn.dynamic_rnn(rnn, rnn_input_seq, initial_state=state)

        with tf.variable_scope('output_layer'):
            # Apply one shared linear layer to every time step, then restore [batch, time, sample_length].
            output = tf.reshape(output, (-1, self.hidden_size))
            W = tf.get_variable('W', (self.hidden_size, self.sample_length))
            b = tf.get_variable('b', (self.sample_length,))
            output = tf.matmul(output, W) + b
            output = tf.reshape(output, (batch_size, -1, self.sample_length))

        # Expose the final state as a single flat tensor so it can be fetched and fed back in.
        self.output_state = pack_state_tupel(final_state)
        self.output_sound = output

It uses the following two functions, which should work for any type of RNN cell, although I have only tested them with this model.

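import tensorflow as tf

def pack_state_tupel(state_tupel):
    """Recursively concatenate a (possibly nested) state tuple into one [batch, total_size] tensor."""
    if isinstance(state_tupel, tf.Tensor) or not hasattr(state_tupel, '__iter__'):
        return state_tupel
    else:
        # TF 1.x argument order is tf.concat(values, axis); TF 0.x used tf.concat(axis, values).
        return tf.concat([pack_state_tupel(item) for item in state_tupel], axis=1)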

def unpack_state_tupel(state_tensor, sizes):
    """Split a packed [batch, total_size] tensor back into the nesting described by sizes."""
    def _unpack_state_tupel(state_tensor_, sizes_, offset_):
        if isinstance(sizes_, tf.Tensor) or not hasattr(sizes_, '__iter__'):
            # Leaf: slice sizes_ columns starting at offset_ and advance the offset.
            return tf.reshape(state_tensor_[:, offset_ : offset_ + sizes_], (-1, sizes_)), offset_ + sizes_
        else:
            result = []
            for size in sizes_:
                s, offset_ = _unpack_state_tupel(state_tensor_, size, offset_)
                result.append(s)
            # Preserve LSTMStateTuple so the cell gets back the type it produced.
            if isinstance(sizes_, tf.nn.rnn_cell.LSTMStateTuple):
                return tf.nn.rnn_cell.LSTMStateTuple(*result), offset_
            else:
                return tuple(result), offset_
    return _unpack_state_tupel(state_tensor, sizes, 0)[0]
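The offset-walking logic is easier to follow without tensors. Below is a simplified pure-Python analogue, assuming a single example with no batch dimension: lists stand in for tensors, ints for state sizes, and the names pack and unpack are hypothetical stand-ins for the two functions above.

```python
def pack(state):
    """Flatten a nested tuple of lists into one flat list."""
    if isinstance(state, list):
        return list(state)
    flat = []
    for item in state:
        flat.extend(pack(item))
    return flat


def unpack(flat, sizes):
    """Rebuild the nesting described by sizes (ints or nested tuples of ints)."""
    def _unpack(sizes_, offset):
        if isinstance(sizes_, int):
            # Leaf: take sizes_ values starting at offset, then advance the offset.
            return flat[offset:offset + sizes_], offset + sizes_
        result = []
        for size in sizes_:
            piece, offset = _unpack(size, offset)
            result.append(piece)
        return tuple(result), offset
    return _unpack(sizes, 0)[0]


# Two "layers", each holding a (c, h) pair; widths 2 and 3 are arbitrary.
state = (([1, 2], [3, 4]), ([5, 6, 7], [8, 9, 10]))
sizes = ((2, 2), (3, 3))

flat = pack(state)
print(flat)                         # [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
print(unpack(flat, sizes) == state)  # True
```

The TensorFlow versions do the same thing along axis 1, which is why the packed state can be fetched from one session.run() call and fed straight into the next.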

Finally, my generate function shows how I manage the hidden state s.

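def generate(self, seed, steps):
    # Assumes `import numpy as np` at module level.
    def _step(x, s=None):
        feed_dict = {self.input_sound: np.reshape(x, (1, -1, self.sample_length))}
        # Feed the previous state only if there is one; otherwise the zero-state default applies.
        if s is not None:
            feed_dict[self.input_state] = s
        return self.session.run([self.output_sound, self.output_state], feed_dict)

    # Left-pad the seed to a whole number of samples (no padding if it already is one).
    seed_pad = (-len(seed)) % self.sample_length
    if seed_pad:
        seed = np.pad(seed, (seed_pad, 0), 'constant')

    # Prime the network on the seed and keep only the last predicted sample.
    y, s = _step(seed)
    y = y[:, -1:]

    result = [seed, y.flatten()]
    for _ in range(steps):
        # Feed the last output back in together with the carried state.
        y, s = _step(y, s)
        result.append(y.flatten())

    return np.concatenate(result)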
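The seed has to be a whole number of samples before it can be reshaped to (1, -1, sample_length). Here is a small sketch of that left-padding arithmetic, using plain lists and a hypothetical helper left_pad_to_multiple rather than np.pad:

```python
def left_pad_to_multiple(seed, block):
    """Prepend zeros so that len(result) is a multiple of block."""
    pad = (-len(seed)) % block  # 0 when the length is already a multiple
    return [0] * pad + list(seed)


print(left_pad_to_multiple([1, 2, 3, 4, 5], 4))  # [0, 0, 0, 1, 2, 3, 4, 5]
print(left_pad_to_multiple([1, 2, 3, 4], 4))     # [1, 2, 3, 4]
```

Note that (-len(seed)) % block evaluates to zero for an already aligned seed, whereas block - len(seed) % block would prepend a full extra block of zeros in that case.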
chasep255 answered Sep 23 '22 16:09