
In Tensorflow, get the names of all the Tensors in a graph

I am creating neural nets with Tensorflow and skflow; for some reason I want to get the values of some inner tensors for a given input, so I am using myClassifier.get_layer_value(input, "tensorName"), myClassifier being a skflow.estimators.TensorFlowEstimator.

However, I find it difficult to get the syntax of the tensor name right, even when I know the name (and I'm getting confused between operations and tensors), so I'm using TensorBoard to plot the graph and look up the name.

Is there a way to enumerate all the tensors in a graph without using tensorboard?

asked Oct 19 '22 09:10 by P. Camilleri


People also ask

How do I get a TensorFlow tensor name?

You can get a TensorFlow tensor by name by calling tf.get_default_graph() to obtain the graph and then calling get_tensor_by_name() on it. Tensor names have the form "<operation name>:<output index>". When you have finished evaluating tensors, close the TensorFlow session to release the resources it used.
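A minimal sketch of the lookup, assuming TensorFlow 2.x with the tf.compat.v1 API (it should also run on late 1.x releases); the node names a, b and c are illustrative. It also shows the operation/tensor distinction from the question: an operation is looked up by its bare name, a tensor by name plus output index.

```python
import tensorflow as tf

# Build a TF1-style graph explicitly so the sketch also runs under TF 2.x.
graph = tf.Graph()
with graph.as_default():
    a = tf.constant(2.0, name="a")  # illustrative names
    b = tf.constant(3.0, name="b")
    c = tf.multiply(a, b, name="c")  # creates an operation named "c"

    default_graph = tf.compat.v1.get_default_graph()  # this is `graph` here
    # An operation is looked up by its bare name...
    c_op = default_graph.get_operation_by_name("c")
    # ...while a tensor needs the ":<output index>" suffix.
    c_tensor = default_graph.get_tensor_by_name("c:0")

# Leaving the `with` block closes the session and releases its resources.
with tf.compat.v1.Session(graph=graph) as sess:
    value = sess.run(c_tensor)

print(c_op.name, c_tensor.name, value)  # c c:0 6.0
```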

What does tf.get_default_graph() do?

Returns the default graph for the current thread.
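A short sketch of what "default" means in practice, assuming TensorFlow 2.x, where the TF1 function is available as tf.compat.v1.get_default_graph(); the graph and constant names are illustrative. Inside a graph's as_default() block, that graph is the default and new ops are added to it; outside the block, a different graph is returned.

```python
import tensorflow as tf

my_graph = tf.Graph()  # illustrative graph

with my_graph.as_default():
    # Here my_graph is the default graph for the current thread,
    # so newly created ops are added to it.
    inside = tf.compat.v1.get_default_graph()
    x = tf.constant(1.0, name="x")

# Outside the block, the global default graph is returned instead.
outside = tf.compat.v1.get_default_graph()

print(inside is my_graph, x.graph is my_graph, outside is my_graph)  # True True False
```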

What is tensor graph?

Graphs are data structures that contain a set of tf.Operation objects, which represent units of computation, and tf.Tensor objects, which represent the units of data that flow between operations. They are defined in a tf.Graph context.
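To make the operation/tensor split concrete, here is a small sketch, assuming TensorFlow 2.x with tf.compat.v1-style graph building; the names x and y are illustrative. It inspects one operation, the tensor flowing into it and the tensor it produces.

```python
import tensorflow as tf

graph = tf.Graph()
with graph.as_default():
    x = tf.constant([1.0, 2.0], name="x")  # illustrative names
    y = tf.square(x, name="y")

op = graph.get_operation_by_name("y")  # tf.Operation: a unit of computation
out = op.outputs[0]                    # tf.Tensor: the unit of data it produces

print(op.type)                      # Square
print([t.name for t in op.inputs])  # ['x:0']
print(out.name, out.op.name)        # y:0 y
```

Note that the operation is called "y" while the tensor it produces is "y:0"; mixing up the two is a common source of name lookup errors.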

How do you find the value of tensors?

The easiest way to evaluate the actual value of a Tensor object is to pass it to the Session.run() method, or to call Tensor.eval() when you have a default session (i.e. inside a with tf.Session(): block).
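Both options, sketched under the assumption of TensorFlow 2.x with the tf.compat.v1 API (tf.Session is tf.compat.v1.Session there); the placeholder x and tensor doubled are illustrative. Entering the session's with block makes it the default session, which is what Tensor.eval() relies on.

```python
import tensorflow as tf

graph = tf.Graph()
with graph.as_default():
    x = tf.compat.v1.placeholder(tf.float32, shape=[None], name="x")  # illustrative
    doubled = tf.multiply(x, 2.0, name="doubled")

with tf.compat.v1.Session(graph=graph) as sess:
    # Option 1: pass the tensor to Session.run(), feeding the placeholder.
    via_run = sess.run(doubled, feed_dict={x: [1.0, 2.0]})
    # Option 2: Tensor.eval() uses the default session installed by the with block.
    via_eval = doubled.eval(feed_dict={x: [3.0]})

print(via_run, via_eval)  # [2. 4.] [6.]
```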


2 Answers

You can do

[n.name for n in tf.get_default_graph().as_graph_def().node]
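The one-liner above uses the TF1 API. A sketch of it running against a concrete graph, assuming TensorFlow 2.x, where the function is reached as tf.compat.v1.get_default_graph() (TF2 no longer exposes tf.get_default_graph at the top level); the layer names are illustrative.

```python
import tensorflow as tf

graph = tf.Graph()
with graph.as_default():
    # Illustrative names for a tiny linear layer.
    inputs = tf.compat.v1.placeholder(tf.float32, shape=[None, 3], name="inputs")
    weights = tf.constant([[1.0], [2.0], [3.0]], name="weights")
    logits = tf.matmul(inputs, weights, name="logits")

    # Same idea as the one-liner, run inside the graph we just built.
    node_names = [n.name for n in tf.compat.v1.get_default_graph().as_graph_def().node]

print(sorted(node_names))  # ['inputs', 'logits', 'weights']
```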

Also, if you are prototyping in an IPython notebook, you can show the graph directly in the notebook; see the show_graph function in Alexander's Deep Dream notebook.

answered Oct 21 '22 22:10 by Yaroslav Bulatov


I'll try to summarize the answers:

To get all nodes in the graph: (type tensorflow.core.framework.node_def_pb2.NodeDef)

all_nodes = [n for n in tf.get_default_graph().as_graph_def().node]

To get all ops in the graph: (type tensorflow.python.framework.ops.Operation)

all_ops = tf.get_default_graph().get_operations()
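The nodes and the ops describe the same graph in two forms. A sketch comparing them, assuming TensorFlow 2.x with tf.compat.v1-style graphs; the names a and b are illustrative. A NodeDef is a serialized protobuf that stores the op type in its op field, while an Operation is the live Python object that exposes it as type.

```python
import tensorflow as tf

graph = tf.Graph()
with graph.as_default():
    a = tf.constant(1.0, name="a")  # illustrative names
    b = tf.negative(a, name="b")

    all_nodes = [n for n in tf.compat.v1.get_default_graph().as_graph_def().node]
    all_ops = tf.compat.v1.get_default_graph().get_operations()

# NodeDef (protobuf): the op type is stored in the `op` field.
node_summary = sorted((n.name, n.op) for n in all_nodes)
# Operation (Python object): the same information is in `name` and `type`.
op_summary = sorted((op.name, op.type) for op in all_ops)

print(node_summary == op_summary)  # True
print(op_summary)                  # [('a', 'Const'), ('b', 'Neg')]
```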

To get all global variables in the graph, i.e. those in the GLOBAL_VARIABLES collection: (type tensorflow.python.ops.resource_variable_ops.ResourceVariable)

all_vars = tf.global_variables()
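A sketch of the variable listing, assuming TensorFlow 2.x, where the TF1 function is tf.compat.v1.global_variables(); the variables w and step are illustrative. It reads the collections of the current default graph, so it is called inside the graph context here; trainable_variables() is shown alongside for comparison, since a variable created with trainable=False appears only in the global list.

```python
import tensorflow as tf

graph = tf.Graph()
with graph.as_default():
    # Illustrative variables: one trainable, one not.
    w = tf.compat.v1.Variable(tf.zeros([2, 2]), name="w")
    step = tf.compat.v1.Variable(0, trainable=False, name="step")

    global_names = [v.name for v in tf.compat.v1.global_variables()]
    trainable_names = [v.name for v in tf.compat.v1.trainable_variables()]

print(global_names)     # ['w:0', 'step:0']
print(trainable_names)  # ['w:0']
```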

To get all tensors in the graph: (type tensorflow.python.framework.ops.Tensor)

all_tensors = [tensor for op in tf.get_default_graph().get_operations() for tensor in op.values()]
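Every tensor name is its producing operation's name plus ":<output index>", and an operation can have several outputs. A sketch assuming TensorFlow 2.x with tf.compat.v1-style graphs; splitting a vector into two halves is just an illustrative multi-output operation, and the name "halves" is hypothetical.

```python
import tensorflow as tf

graph = tf.Graph()
with graph.as_default():
    x = tf.constant([1.0, 2.0, 3.0, 4.0], name="x")
    halves = tf.split(x, 2, name="halves")  # one operation with two output tensors

all_ops = graph.get_operations()
all_tensors = [tensor for op in all_ops for tensor in op.values()]
tensor_names = [t.name for t in all_tensors]

print([t.name for t in halves])  # ['halves:0', 'halves:1']
```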

To get all placeholders in the graph: (type tensorflow.python.framework.ops.Tensor)

all_placeholders = [placeholder for op in tf.get_default_graph().get_operations() if op.type=='Placeholder' for placeholder in op.values()]
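A sketch of finding the placeholders and then feeding them, assuming TensorFlow 2.x with the tf.compat.v1 API; the placeholder names features and scale are illustrative. The discovered tensors can be indexed by name and used directly as feed_dict keys.

```python
import tensorflow as tf

graph = tf.Graph()
with graph.as_default():
    # Illustrative placeholders and a computation that consumes them.
    features = tf.compat.v1.placeholder(tf.float32, shape=[None, 2], name="features")
    scale = tf.compat.v1.placeholder(tf.float32, shape=[], name="scale")
    scaled = tf.multiply(features, scale, name="scaled")

all_placeholders = [placeholder for op in graph.get_operations()
                    if op.type == "Placeholder" for placeholder in op.values()]
print([p.name for p in all_placeholders])  # ['features:0', 'scale:0']

# Index the discovered placeholders by tensor name and feed them.
by_name = {p.name: p for p in all_placeholders}
with tf.compat.v1.Session(graph=graph) as sess:
    result = sess.run(scaled, feed_dict={by_name["features:0"]: [[1.0, 2.0]],
                                         by_name["scale:0"]: 10.0})
print(result)  # [[10. 20.]]
```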

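TensorFlow 2

To get the graph in TensorFlow 2, instead of calling tf.get_default_graph(), you need a tf.function: get a concrete function from it and access that function's graph attribute, for example:

graph = func.get_concrete_function().graph

where func is a tf.function. Calling get_concrete_function() without arguments works when func was defined with an input_signature; otherwise, pass example inputs or tf.TensorSpec objects.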
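A sketch of the TF2 route, assuming TensorFlow 2.x; the function func and the op name add_one are illustrative. Because the input_signature is fixed, get_concrete_function() takes no arguments, and the TF1-style listing snippets then work on the returned graph.

```python
import tensorflow as tf

@tf.function(input_signature=[tf.TensorSpec(shape=[None], dtype=tf.float32)])
def func(x):
    return tf.add(x, 1.0, name="add_one")  # illustrative op name

# With an input_signature, get_concrete_function() needs no arguments.
graph = func.get_concrete_function().graph

op_names = [op.name for op in graph.get_operations()]
all_tensors = [t for op in graph.get_operations() for t in op.values()]
placeholders = [op.name for op in graph.get_operations() if op.type == "Placeholder"]

print(op_names)
print(placeholders)  # one placeholder, standing in for the argument x
```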

answered Oct 21 '22 23:10 by Szabolcs