 

Why does this tensorflow loop require so much memory?

I have a contrived version of a complicated network:

import tensorflow as tf

a = tf.ones([1000])
b = tf.ones([1000])

for i in range(int(1e6)):
    a = a * b

My intuition is that this should require very little memory. Just the space for the initial array allocation and a string of commands that utilizes the nodes and overwrites the memory stored in tensor 'a' at each step. But memory usage grows quite rapidly.

What is going on here, and how can I decrease memory usage when I compute a tensor and overwrite it a bunch of times?

Edit:

Thanks to Yaroslav's suggestions, the solution turned out to be using tf.while_loop to keep the number of nodes on the graph small. This works great: it is much faster, requires far less memory, and runs entirely in-graph.

import tensorflow as tf

a = tf.ones([1000])
b = tf.ones([1000])

cond = lambda i, _a, _b: tf.less(i, int(1e6))    # loop while i < 1e6
body = lambda i, a, b: [tf.add(i, 1), a * b, b]  # i += 1; a *= b; b unchanged

i = tf.constant(0)
output = tf.while_loop(cond, body, [i, a, b])

with tf.Session() as sess:
    result = sess.run(output)
    print(result)
asked Dec 25 '22 by jstaker7
1 Answer

Your a * b command translates to tf.mul(a, b), which adds a new Mul node to the current default graph (tf.get_default_graph()). So this loop tries to add 1 million Mul nodes to the graph, and graph construction itself is what consumes the memory. It is also problematic because you can't serialize a Graph object larger than 2GB, and some internal checks may fail once you are dealing with a graph that large.

I'd recommend reading Programming Models for Deep Learning by MXNet folks. TensorFlow is "symbolic" programming in their terminology, and you are treating it as imperative.

To get what you want using a Python loop, you could construct the multiplication op once and run it repeatedly, using feed_dict to feed updates:

mul_op = a * b                 # build the Mul node once

with tf.Session() as sess:
    result = sess.run(a)       # initial value of a
    for i in range(int(1e6)):
        result = sess.run(mul_op, feed_dict={a: result})

For more efficiency you could use tf.Variable objects and var.assign() to avoid Python <-> TensorFlow data transfers.

answered Dec 30 '22 by Yaroslav Bulatov