tf.cond lowers the training speed

I was using TensorFlow input pipelines like the ones in the CIFAR-10 model in TensorFlow, and tried to use tf.cond to switch to the validation data. I wrote something like this:

import tensorflow as tf

# `model` is a CIFAR-10-style module that provides input(); its code is not shown.
train_data = model.input(istrain=True)
val_data = model.input(istrain=False)

# This selects which stream to use.
select_val = tf.placeholder(dtype=tf.bool, shape=[], name='select_test')
data = tf.cond(
    select_val,
    lambda: val_data,
    lambda: train_data
)

# Here is the model.
loss = ...
train_op = ...
...

with tf.Session() as sess:
    ...

If I remove the cond and use only the training data, the speed is 4000 samples/s; with the code above, it drops to 2300 samples/s. The capacity of the validation pipeline is set very small so that it won't take much GPU memory, and validation is run only rarely. I'm not sure what is going wrong; please help me out.

zhr1201 asked Mar 30 '17 02:03


1 Answer

tf.cond is not fully lazy. Any operation that either branch of the cond depends on, but that was created outside the branch functions, runs every time the cond runs, even when the branch that needs it is not the one selected. In your case, both model.input(istrain=True) and model.input(istrain=False) are executed every time your data op is evaluated, and the result of one of them is simply thrown away.
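To make this concrete, here is a minimal pure-Python sketch (no TensorFlow involved). `make_stream` and `run_training_steps` are hypothetical names: `make_stream` stands in for `model.input(...)`, and the `lazy=False` path models what the graph does when both input tensors are created outside the cond's lambdas, namely both streams are pulled on every step and one batch is discarded. It is a simplified model of the semantics, not TensorFlow's implementation.

```python
from collections import Counter
from itertools import count


def make_stream(name, pulls):
    """Hypothetical stand-in for model.input(): yields batches, counting each pull."""
    for i in count():
        pulls[name] += 1
        yield f"{name}-batch-{i}"


def run_training_steps(num_steps, lazy):
    """Run num_steps training steps (select_val=False); return pull counts and batches used."""
    pulls = Counter()
    train = make_stream("train", pulls)
    val = make_stream("val", pulls)
    select_val = False  # every step here is a training step
    batches = []
    for _ in range(num_steps):
        if lazy:
            # What users often expect: only the selected stream is pulled.
            batch = next(val) if select_val else next(train)
        else:
            # Model of the TF1 graph: both inputs feed the cond, so both are
            # evaluated before the predicate picks one; the other is dropped.
            train_batch, val_batch = next(train), next(val)
            batch = val_batch if select_val else train_batch
        batches.append(batch)
    return pulls, batches


pulls, batches = run_training_steps(100, lazy=False)
print(dict(pulls))  # {'train': 100, 'val': 100}
```

With `lazy=True` the validation stream is never touched; in the non-lazy model its work is done on every training step regardless of the predicate.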

The documentation for cond gives a minimal code example:

Note that the conditional execution applies only to the operations defined in fn1 and fn2. Consider the following simple program:

z = tf.multiply(a, b)
result = tf.cond(x < y, lambda: tf.add(x, z), lambda: tf.square(y))

If x < y, the tf.add operation will be executed and tf.square operation will not be executed. Since z is needed for at least one branch of the cond, the tf.multiply operation is always executed, unconditionally. Although this behavior is consistent with the dataflow model of TensorFlow, it has occasionally surprised some users who expected a lazier semantics.
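The documented example can be traced the same way. This pure-Python sketch uses a hypothetical helper `traced` and plain arithmetic in place of `tf.multiply`, `tf.add` and `tf.square`, and records which operations run. The product built outside the branches always runs; an operation defined inside a branch runs only when that branch is taken. Moving the multiply into the first branch (`doc_example_inside_branch`, also hypothetical) makes it conditional too, in line with the note that conditional execution covers only operations defined in fn1 and fn2.

```python
def traced(log, name, fn):
    """Record that the operation called `name` executed, then run it."""
    log.append(name)
    return fn()


def doc_example(a, b, x, y):
    """z is created outside the branches, as in the documentation snippet."""
    log = []
    z = traced(log, "multiply", lambda: a * b)  # always runs
    if x < y:
        result = traced(log, "add", lambda: x + z)
    else:
        result = traced(log, "square", lambda: y * y)
    return result, log


def doc_example_inside_branch(a, b, x, y):
    """Same computation, but the multiply is defined inside the first branch."""
    log = []
    if x < y:
        result = traced(log, "add", lambda: x + traced(log, "multiply", lambda: a * b))
    else:
        result = traced(log, "square", lambda: y * y)
    return result, log


print(doc_example(2, 3, 5, 1))                # (1, ['multiply', 'square'])
print(doc_example_inside_branch(2, 3, 5, 1))  # (1, ['square'])
```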

Also note that if your model.input pulls data from a larger pool (say, a batch from an entire dataset), then each time the cond runs, data is pulled from both the validation and the training streams, and one set is thrown away. In some cases this causes problems more serious than inefficiency. For example, if you are processing only a fixed number of epochs, this code will not actually process that many epochs, because some of the data that was pulled was never used.
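The epoch-accounting problem can be sketched with a small simulation. The assumptions are hypothetical, not from the question: the training input stops after `num_epochs * batches_per_epoch` batches (much like a `num_epochs` limit on a TF1 input producer), and every `val_every`-th step is a validation step. In the non-lazy case each validation step still dequeues a training batch that is then discarded; the validation stream itself is not modeled here.

```python
def simulate(num_epochs, batches_per_epoch, val_every, lazy):
    """Return (training batches used, training batches pulled but discarded)."""
    if val_every < 2:
        raise ValueError("val_every must be >= 2 so that training steps exist")
    train = iter(range(num_epochs * batches_per_epoch))  # finite training input
    used = discarded = 0
    step = 0
    while True:
        select_val = (step + 1) % val_every == 0
        # Non-lazy: every step dequeues a training batch.
        # Lazy: only training steps do.
        if not lazy or not select_val:
            if next(train, None) is None:
                break  # training input exhausted (OutOfRangeError in TF1)
            if select_val:
                discarded += 1  # pulled by the cond, then thrown away
            else:
                used += 1
        step += 1
    return used, discarded


print(simulate(3, 10, 5, lazy=False))  # (24, 6): only 2.4 epochs actually trained on
print(simulate(3, 10, 5, lazy=True))   # (30, 0)
```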

golmschenk answered Nov 03 '22 01:11