How to print part of a tensor using tf.Print?

Tags:

tensorflow

In TensorFlow, when a tensor has large dimensions, it is useful for debugging to print only part of it, such as the diagonal of a 2-D matrix. I only know how to print the entire tensor, as follows:

import tensorflow as tf

sess = tf.InteractiveSession()
a = tf.constant(1.0, shape=[1000, 1000])
a = tf.Print(a, [a], "print entire a\n", summarize=1000000)
b = a + 1.
ret = sess.run(b)

The code above prints the entire 'a' tensor, but I'm not sure how to print only part of it. For example, if I want to print only a[0,0] without calling sess.run(a), the following code doesn't work:

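import tensorflow as tf

sess = tf.InteractiveSession()
a = tf.constant(1.0, shape=[1000, 1000])
# Fails: a Tensor does not support item assignment, so a[0,0] cannot be the target
a[0,0] = tf.Print(a[0,0], [a[0,0]], "print part of a\n", summarize=1000000)
b = a + 1.
ret = sess.run(b)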
user131379 asked Oct 29 '22 00:10


1 Answer

From the documentation:

Print(
    input_,
    data,
    message=None,
    first_n=None,
    summarize=None,
    name=None
)

Prints a list of tensors. This is an identity op with the side effect of printing data when evaluating.

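Keep 'a' as the input_ argument (tf.Print returns it unchanged) and put only the tensor you want to print, here the slice a[0, 0], in the data list. The print is a side effect of evaluating the returned tensor, so assign the result back to 'a' and use it downstream, as 'b' does below. Assigning to a[0,0] cannot work because tensors do not support item assignment.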

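import tensorflow as tf

sess = tf.InteractiveSession()
a = tf.constant(1.0, shape=[1000, 1000])
# input_ is still a; only a[0, 0] is listed in data, so only that element is printed
a = tf.Print(a, [a[0, 0]], "Print part of a\n", summarize=100000)
b = a + 1.  # b depends on the Print op, so sess.run(b) triggers the print
ret = sess.run(b)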
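Since the question also mentions the diagonal, here is a minimal sketch (not part of the original answer) that prints a[0, 0] and the leading diagonal entries in one call. It assumes TensorFlow 2.x, where this graph-mode API lives under tf.compat.v1; the use of tf.linalg.diag_part, first_n=1 and summarize=10 are illustrative choices, not something the answer prescribes.

```python
import tensorflow as tf

# tf.Print is a graph-mode op; in TF 2.x it is available as tf.compat.v1.Print.
tf.compat.v1.disable_eager_execution()

a = tf.constant(1.0, shape=[1000, 1000])

# input_ stays a (returned unchanged); data lists only the pieces to print.
a = tf.compat.v1.Print(
    a,
    [a[0, 0], tf.linalg.diag_part(a)],  # one scalar plus the 1000-long diagonal
    message="a[0,0] and diag(a): ",
    first_n=1,     # log only the first time the op is evaluated
    summarize=10,  # show at most 10 entries of each printed tensor
)
b = a + 1.0  # b depends on the Print op, so evaluating b triggers the print

with tf.compat.v1.Session() as sess:
    ret = sess.run(b)
```

Because Print is an identity op, ret is exactly what it would be without the print (every entry is 2.0); the only difference is the log line written to standard error. In TF 2.x, tf.compat.v1.Print is deprecated in favor of tf.print.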
Julio Daniel Reyes answered Nov 22 '22 21:11