
How to access tensor_content values in TensorProto in TensorFlow?

Tags:

tensorflow

This is similar to "How to access values in protos in TensorFlow?", but that question doesn't cover this case.

I see a bytes attribute, tensor_content, in TensorProto. I'm trying to get information about the nodes like this:

for node in tf.get_default_graph().as_graph_def().node:
    node.attr['value'].tensor.tensor_content  # how do I decode these bytes?

For reference, printing one of these nodes gives something like this:

name: "conv2d/convolution/Shape"
op: "Const"
device: "/device:GPU:0"
attr {
  key: "dtype"
  value {
    type: DT_INT32
  }
}
attr {
  key: "value"
  value {
    tensor {
      dtype: DT_INT32
      tensor_shape {
        dim {
          size: 4
        }
      }
      tensor_content: "\003\000\000\000\003\000\000\000\001\000\000\000 \000\000\000"
    }
  }
}
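The escaped tensor_content string is just the packed element bytes. As a quick check, here is a minimal sketch that decodes the 16 bytes above with Python's standard struct module, assuming little-endian byte order (as on x86 machines). The result is four int32 values, presumably the 3x3 filter size plus 1 input and 32 output channels of the conv2d kernel:

```python
import struct

# Raw bytes copied from the tensor_content field of the node printed above.
tensor_content = b"\x03\x00\x00\x00\x03\x00\x00\x00\x01\x00\x00\x00 \x00\x00\x00"

# dtype is DT_INT32 (4 bytes per element) and tensor_shape has one dim of size 4.
# "<" assumes little-endian byte order; "4i" reads four 32-bit signed ints.
values = struct.unpack("<4i", tensor_content)
print(values)  # (3, 3, 1, 32)
```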
asked May 23 '17 at 03:05 by jkschin


2 Answers

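import tensorflow as tf
from tensorflow.python.framework import tensor_util

for n in tf.get_default_graph().as_graph_def().node:
    if n.op == 'Const':  # only Const nodes carry a TensorProto in their 'value' attr
        print(tensor_util.MakeNdarray(n.attr['value'].tensor))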
answered Oct 12 '22 at 22:10 by jkschin


Decode the tensor_content bytes with tf.decode_raw, then reshape the result to the node's tensor_shape:

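import tensorflow as tf

for node in tf.get_default_graph().as_graph_def().node:
    tensor = node.attr["value"].tensor
    if node.op != "Const" or not tensor.tensor_content:
        continue  # only Const nodes whose values are packed into tensor_content
    tensor_dtype = tf.as_dtype(tensor.dtype)
    tensor_shape = [x.size for x in tensor.tensor_shape.dim]
    # decode_raw/reshape only build graph ops; run tensor_array in a tf.Session to get values
    tensor_array = tf.decode_raw(tensor.tensor_content, tensor_dtype)
    tensor_array = tf.reshape(tensor_array, tensor_shape)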
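The same decode-then-reshape steps can also be done eagerly, without building graph ops. Below is a minimal sketch using only the standard library. The helper names decode_tensor_content and _reshape are hypothetical, only a few DataType enum values are mapped, little-endian byte order is assumed, and only the packed tensor_content case is handled (not values stored in repeated fields such as int_val):

```python
import struct
from math import prod  # requires Python 3.8+

# Assumed subset of the TensorFlow DataType enum (types.proto) mapped to
# struct format characters: DT_FLOAT=1, DT_DOUBLE=2, DT_INT32=3, DT_INT64=9.
_STRUCT_CODES = {1: "f", 2: "d", 3: "i", 9: "q"}


def _reshape(flat, dims):
    """Nest a flat list into lists of shape `dims`, in row-major (C) order."""
    if not dims:
        return flat[0]  # scalar tensor
    if len(dims) == 1:
        return flat
    step = len(flat) // dims[0] if dims[0] else 0
    return [_reshape(flat[i * step:(i + 1) * step], dims[1:]) for i in range(dims[0])]


def decode_tensor_content(content, dtype_enum, dims):
    """Decode packed tensor_content bytes (assumed little-endian) into nested lists."""
    code = _STRUCT_CODES[dtype_enum]
    count = prod(dims)  # prod([]) == 1, so a scalar decodes one element
    flat = list(struct.unpack("<%d%s" % (count, code), content))
    return _reshape(flat, dims)


# The bytes from the question's Const node: DT_INT32 (3) with shape [4].
print(decode_tensor_content(b"\x03\x00\x00\x00\x03\x00\x00\x00\x01\x00\x00\x00 \x00\x00\x00", 3, [4]))
# A 2x2 example built with struct.pack.
print(decode_tensor_content(struct.pack("<4i", 1, 2, 3, 4), 3, [2, 2]))
```

With a real node, the arguments would be tensor.tensor_content, tensor.dtype (an integer enum in the Python protobuf API), and [d.size for d in tensor.tensor_shape.dim].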
answered Oct 12 '22 at 22:10 by zong fan