Below is my implementation of a TensorFlow RNN cell, designed to emulate Alex Graves' Adaptive Computation Time (ACT) algorithm from this paper: http://arxiv.org/abs/1603.08983.
At a single timestep in the sequence, called via rnn.rnn (with a static sequence_length parameter, so the RNN is unrolled dynamically; I am using a fixed batch size of 20), we recursively call ACTStep, producing outputs of size (1,200), where the hidden dimension of the RNN cell is 200 and we have a batch size of 1.
Using the while loop in TensorFlow, we iterate until the accumulated halting probability is high enough. All of this works reasonably smoothly, but I am having problems accumulating states, probabilities and outputs within the while loop, which we need to do in order to create weighted combinations of these as the final cell output/state.
I have tried using a simple list, as below, but this fails when the graph is compiled because the outputs are not in the same frame. (Is it possible to use the "switch" function in control_flow_ops to forward the tensors to the point at which they are required, i.e. the add_n function just before we return the values?) I have also tried using the TensorArray structure, but I am finding it difficult to use: it seems to destroy shape information, and restoring the shapes manually hasn't worked. I also haven't been able to find much documentation on TensorArrays, presumably because they are mainly intended for internal TF use.
Any advice on how it might be possible to correctly accumulate the variables produced by ACTStep would be much appreciated.
class ACTCell(RNNCell):
    """An RNN cell implementing Graves' Adaptive Computation time algorithm"""

    def __init__(self, num_units, cell, epsilon, max_computation):
        self.one_minus_eps = tf.constant(1.0 - epsilon)
        self._num_units = num_units
        self.cell = cell
        self.N = tf.constant(max_computation)

    @property
    def input_size(self):
        return self._num_units

    @property
    def output_size(self):
        return self._num_units

    @property
    def state_size(self):
        return self._num_units

    def __call__(self, inputs, state, scope=None):
        with vs.variable_scope(scope or type(self).__name__):
            # define within cell constants/ counters used to control while loop
            prob = tf.get_variable("prob", [], tf.float32,tf.constant_initializer(0.0))
            counter = tf.get_variable("counter", [],tf.float32,tf.constant_initializer(0.0))
            tf.assign(prob,0.0)
            tf.assign(counter, 0.0)
            # the predicate for stopping the while loop. Tensorflow demands that we have
            # all of the variables used in the while loop in the predicate.
            pred = lambda prob,counter,state,input,\
                          acc_state,acc_output,acc_probs:\
                tf.logical_and(tf.less(prob,self.one_minus_eps), tf.less(counter,self.N))
            acc_probs = []
            acc_outputs = []
            acc_states = []
            _,iterations,_,_,acc_states,acc_output,acc_probs = \
                control_flow_ops.while_loop(pred,
                                            self.ACTStep,
                                            [prob,counter,state,input,acc_states,acc_outputs,acc_probs])
            # TODO:fix last part of this, need to use the remainder.
            # TODO: find a way to accumulate the regulariser
            # here we take a weighted combination of the states and outputs
            # to use as the actual output and state which is passed to the next timestep.
            next_state = tf.add_n([tf.mul(x,y) for x,y in zip(acc_probs,acc_states)])
            output = tf.add_n([tf.mul(x,y) for x,y in zip(acc_probs,acc_outputs)])
            return output, next_state

    def ACTStep(self,prob,counter,state,input, acc_states,acc_outputs,acc_probs):
        output, new_state = rnn.rnn(self.cell, [input], state, scope=type(self.cell).__name__)
        prob_w = tf.get_variable("prob_w", [self.cell.input_size,1])
        prob_b = tf.get_variable("prob_b", [1])
        p = tf.nn.sigmoid(tf.matmul(prob_w,new_state) + prob_b)
        acc_states.append(new_state)
        acc_outputs.append(output)
        acc_probs.append(p)
        return [tf.add(prob,p),tf.add(counter,1.0),new_state, input,acc_states,acc_outputs,acc_probs]
Let me preface this response by saying that this is NOT a complete solution, but rather some commentary on how to improve your cell.
To start off, in your ACTStep function, you call rnn.rnn for one timestep (as defined by [input]). If you're doing a single timestep, it is probably more efficient to simply call self.cell directly. You'll see this same mechanism used in the TensorFlow RNNCell wrappers.
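To illustrate the point about calling the cell directly, here is a minimal pure-Python sketch. ToyCell and StepWrapper are hypothetical stand-ins invented for this example, not TensorFlow classes; the only assumption carried over is that the wrapped cell is callable as cell(inputs, state) and returns (output, new_state). In ACTStep the equivalent change would presumably be output, new_state = self.cell(input, state).

```python
class ToyCell:
    """Hypothetical stand-in for an RNN cell: callable as cell(inputs, state)."""

    def __call__(self, inputs, state):
        # Blend the previous state with the input, element by element.
        new_state = [0.5 * s + 0.5 * x for s, x in zip(state, inputs)]
        output = new_state  # this toy cell returns its new state as the output
        return output, new_state


class StepWrapper:
    """Hypothetical wrapper that advances the wrapped cell by one step."""

    def __init__(self, cell):
        self.cell = cell

    def step(self, inputs, state):
        # Call the cell directly instead of unrolling a length-1 sequence.
        return self.cell(inputs, state)


output, new_state = StepWrapper(ToyCell()).step([1.0, 0.0], [0.0, 2.0])
```

The wrapper adds no sequence handling of its own, which is the reason a direct call is likely to be cheaper than building a one-element unrolled RNN.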
You mentioned that you have tried using TensorArrays. Did you pack and unpack the TensorArrays appropriately? Here is a repo where, under model.py, you'll find the TensorArrays packed and unpacked properly.
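As a possible alternative to TensorArrays (this is a different technique, not something the repo above is known to use), the loop could carry fixed-size running sums instead of growing lists, since every loop variable then keeps the same shape on each iteration. The sketch below is plain Python rather than TensorFlow; act_accumulate, toy_step and the constant halting probability of 0.4 are all assumptions made for illustration, and the output is assumed to have the same length as the state.

```python
def act_accumulate(step_fn, x, state, epsilon=0.01, max_steps=5):
    """Run step_fn until halting, keeping only fixed-size running sums."""
    cum_prob = 0.0                    # halting probability consumed so far
    n = 0                             # number of ponder steps taken
    acc_state = [0.0] * len(state)    # running weighted sum of states
    acc_output = [0.0] * len(state)   # running weighted sum of outputs
    while True:
        output, state, h = step_fn(x, state)
        n += 1
        last = cum_prob + h >= 1.0 - epsilon or n >= max_steps
        # The final step is weighted by the remainder so the weights sum to 1.
        weight = 1.0 - cum_prob if last else h
        acc_state = [a + weight * s for a, s in zip(acc_state, state)]
        acc_output = [a + weight * o for a, o in zip(acc_output, output)]
        cum_prob += h
        if last:
            return acc_output, acc_state, n, weight


def toy_step(x, state):
    """Hypothetical step: add x to the state, halt with probability 0.4."""
    new_state = [s + x for s in state]
    return new_state, new_state, 0.4


acc_output, acc_state, n_steps, remainder = act_accumulate(
    toy_step, 1.0, [0.0, 10.0])
```

Here the loop halts on the third step, because the first two halting probabilities only sum to 0.8. The loop state is a fixed tuple of scalars and vectors, which is the shape of thing a graph-mode while loop can carry from one iteration to the next.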
You also asked if there is a function in control_flow_ops that will require all the tensors to be accumulated. I think you are looking for tf.control_dependencies. You can list all of your output tensors and operations in control_dependencies, and that will require TensorFlow to compute all of those tensors up to that point.
Also, it looks like your counter variable is trainable. Are you sure you want this to be the case? If you're adding one to your counter on each step, training it probably wouldn't yield the correct result. On the other hand, you could have purposely kept it trainable in order to differentiate through it at the end for the ponder cost function.
Also, I believe the remainder should be computed in your script:
remainder = 1.0 - tf.add_n(acc_probs[:-1])
# note the [:-1] slice: the last probability is left out, because the remainder takes its place
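A small worked example may make the role of the remainder clearer. The halting probabilities below are made-up numbers for a single timestep; in plain Python the same formula looks like this:

```python
halting_probs = [0.1, 0.3, 0.8]  # hypothetical h_1..h_N for one timestep

# Everything except the last probability is used as-is; the last step
# receives whatever probability mass is left over.
remainder = 1.0 - sum(halting_probs[:-1])
weights = halting_probs[:-1] + [remainder]
```

Using the raw probabilities as weights would give a total of 1.2 here, whereas swapping the last one for the remainder (0.6) makes the weights sum to 1, so the combined state and output stay a proper weighted average.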
Here is my edited version of your code:
class ACTCell(RNNCell):
    """An RNN cell implementing Graves' Adaptive Computation time algorithm
    Notes: https://www.evernote.com/shard/s189/sh/fd165646-b630-48b7-844c-86ad2f07fcda/c9ab960af967ef847097f21d94b0bff7
    """

    def __init__(self, num_units, cell, max_computation = 5.0, epsilon = 0.01):
        self.one_minus_eps = tf.constant(1.0 - epsilon) # epsilon is 0.01 as found in the paper
        self._num_units = num_units
        self.cell = cell
        self.N = tf.constant(max_computation)

    @property
    def input_size(self):
        return self._num_units

    @property
    def output_size(self):
        return self._num_units

    @property
    def state_size(self):
        return self._num_units

    def __call__(self, inputs, state, scope=None):
        with vs.variable_scope(scope or type(self).__name__):
            # define within cell constants/ counters used to control while loop
            # NOTE: batch_size is not defined in this snippet; the surrounding code must supply it
            prob = tf.constant(0.0, shape = [batch_size])
            counter = tf.constant(0.0, shape = [batch_size])
            # the predicate for stopping the while loop. Tensorflow demands that we have
            # all of the variables used in the while loop in the predicate.
            pred = lambda prob,counter,state,input,acc_states,acc_output,acc_probs:\
                tf.logical_and(tf.less(prob,self.one_minus_eps), tf.less(counter,self.N))
            acc_probs, acc_outputs, acc_states = [], [], []
            _,iterations,_,_,acc_states,acc_output,acc_probs = \
                control_flow_ops.while_loop(
                    pred,
                    self.ACTStep, # looks like he purposely makes the while loop here
                    [prob, counter, state, inputs, acc_states, acc_outputs, acc_probs])
            # mean-field updates for states and outputs
            next_state = tf.add_n([x*y for x,y in zip(acc_probs,acc_states)])
            output = tf.add_n([x*y for x,y in zip(acc_probs,acc_outputs)])
            # you take the last probability off to avoid a negative ponder cost
            # the problem here is we need to take the sum of all the remainders
            remainder = 1.0 - tf.add_n(acc_probs[:-1])
            tf.add_to_collection("ACT_remainder", remainder) # if this doesn't work then you can do self.list based upon timesteps
            tf.add_to_collection("ACT_iterations", iterations)
            return output, next_state

    def ACTStep(self,prob, counter, state, input, acc_states, acc_outputs, acc_probs):
        '''run rnn once'''
        output, new_state = rnn.rnn(self.cell, [input], state, scope=type(self.cell).__name__)
        prob_w = tf.get_variable("prob_w", [self.cell.input_size,1])
        prob_b = tf.get_variable("prob_b", [1])
        # state [batch, units] times weights [units, 1] gives one halting probability per example
        halting_probability = tf.nn.sigmoid(tf.matmul(new_state,prob_w) + prob_b)
        acc_states.append(new_state)
        acc_outputs.append(output)
        acc_probs.append(halting_probability)
        return [halting_probability + prob, counter + 1.0, new_state, input,acc_states,acc_outputs,acc_probs]

    def PonderCostFunction(self, time_penalty = 0.01):
        '''
        note: ponder is completely different than probability and ponder = rho
        the ponder cost function prohibits the rnn from cycling endlessly on each timestep when not much is needed
        '''
        n_iterations = tf.get_collection_ref("ACT_iterations")
        remainder = tf.get_collection_ref("ACT_remainder")
        return tf.reduce_sum(n_iterations + remainder) # completely different from probability
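For reference, the quantity PonderCostFunction is aiming at can be sketched in plain Python. In the paper, the ponder at each timestep is the number of computation steps plus the remainder, and the total is scaled by the time penalty before being added to the loss. The function name ponder_cost and the per-timestep values below are hypothetical, and unlike the method above this sketch applies time_penalty itself:

```python
def ponder_cost(iterations, remainders, time_penalty=0.01):
    """Time penalty times the sum over timesteps of (steps taken + remainder)."""
    return time_penalty * sum(n + r for n, r in zip(iterations, remainders))


# Hypothetical values for a sequence of three timesteps.
iterations = [3, 1, 2]
remainders = [0.6, 1.0, 0.25]
cost = ponder_cost(iterations, remainders)
```

Note that a timestep that halts after a single computation step has a remainder of 1, so its ponder is 2 rather than 1.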
This is a complicated paper to implement that I have been working on myself. I wouldn't mind collaborating with you to get it done in Tensorflow. If you're interested, please add me at LeavesBreathe on Skype and we can go from there.