I am creating neural nets with TensorFlow and skflow; for some reason I want to get the values of some inner tensors for a given input, so I am using myClassifier.get_layer_value(input, "tensorName"), myClassifier being a skflow.estimators.TensorFlowEstimator.
However, I find it difficult to work out the correct syntax for a tensor's name, even when I know the name (and I keep confusing operations with tensors), so I'm using TensorBoard to plot the graph and look the name up.
Is there a way to enumerate all the tensors in a graph without using TensorBoard?
Finally, we close the TensorFlow session to release the TensorFlow resources we used within the session. That is how you can get a TensorFlow tensor by name using the tf.get_default_graph operation and then the TensorFlow get_tensor_by_name operation.
tf.get_default_graph() returns the default graph for the current thread.
Graphs are data structures that contain a set of tf.Operation objects, which represent units of computation, and tf.Tensor objects, which represent the units of data that flow between operations. They are defined in a tf.Graph context.
The easiest way to evaluate the actual value of a Tensor object is to pass it to the Session.run() method, or to call Tensor.eval() when you have a default session (i.e. in a with tf.Session(): block).
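As a minimal sketch of the lookup-then-evaluate flow described above: the snippet assumes a TensorFlow 2.x install, where the TF1-style calls live under tf.compat.v1, and the names "x" and "doubled" are made up for illustration. It builds a tiny graph, fetches a tensor by name, and evaluates it in a session that is closed on exit from the with block.

```python
import tensorflow as tf

# Build an isolated graph so the example does not touch the global default graph.
graph = tf.Graph()
with graph.as_default():
    # "x" and "doubled" are hypothetical names chosen for this example.
    x = tf.compat.v1.placeholder(tf.float32, shape=[None, 2], name="x")
    doubled = tf.multiply(x, 2.0, name="doubled")

# The operation is named "doubled"; its first output tensor is "doubled:0".
tensor = graph.get_tensor_by_name("doubled:0")

# The session is closed automatically when the with block exits.
with tf.compat.v1.Session(graph=graph) as sess:
    result = sess.run(tensor, feed_dict={x: [[1.0, 2.0]]})

print(tensor.name)
print(result)
```

Note the ":0" suffix: get_tensor_by_name expects a tensor name (operation name plus output index), not a bare operation name.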
You can list the names of all nodes (operations) in the default graph with:
[n.name for n in tf.get_default_graph().as_graph_def().node]
Also, if you are prototyping in an IPython notebook, you can show the graph directly in the notebook; see the show_graph function in Alexander's Deep Dream notebook.
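To illustrate what the node listing returns, here is a small sketch on a throwaway graph. It assumes TensorFlow 2.x with the TF1 API under tf.compat.v1, and the names "a", "b" and "total" are invented for the example.

```python
import tensorflow as tf

graph = tf.Graph()
with graph.as_default():
    # Hypothetical ops, named explicitly so they are easy to spot in the output.
    a = tf.constant(1.0, name="a")
    b = tf.constant(2.0, name="b")
    total = tf.add(a, b, name="total")

    # Inside the with block, the default graph is the graph built above.
    default_graph = tf.compat.v1.get_default_graph()
    node_names = [n.name for n in default_graph.as_graph_def().node]

print(node_names)
```

These are operation names, so they carry no ":0" suffix; the corresponding output tensor of "total" would be "total:0".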
I'll try to summarize the answers:
To get all nodes in the graph (type tensorflow.core.framework.node_def_pb2.NodeDef):
all_nodes = [n for n in tf.get_default_graph().as_graph_def().node]
To get all ops in the graph (type tensorflow.python.framework.ops.Operation):
all_ops = tf.get_default_graph().get_operations()
To get all variables in the graph (type tensorflow.python.ops.resource_variable_ops.ResourceVariable):
all_vars = tf.global_variables()
To get all tensors in the graph (type tensorflow.python.framework.ops.Tensor):
all_tensors = [tensor for op in tf.get_default_graph().get_operations() for tensor in op.values()]
To get all placeholders in the graph (type tensorflow.python.framework.ops.Tensor):
all_placeholders = [placeholder for op in tf.get_default_graph().get_operations() if op.type=='Placeholder' for placeholder in op.values()]
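The snippets above can be tried together on a small graph. This is only a sketch: it assumes TensorFlow 2.x (hence the tf.compat.v1 prefix for the TF1-style calls), and the names "x", "w" and "y" are hypothetical.

```python
import tensorflow as tf

graph = tf.Graph()
with graph.as_default():
    # A toy model: one placeholder, one variable, one matmul.
    x = tf.compat.v1.placeholder(tf.float32, shape=[None, 3], name="x")
    w = tf.compat.v1.get_variable("w", shape=[3, 1])
    y = tf.matmul(x, w, name="y")

    all_ops = graph.get_operations()
    all_vars = tf.compat.v1.global_variables()
    # Every output of every operation is a tensor.
    all_tensors = [t for op in all_ops for t in op.values()]
    all_placeholders = [
        t for op in all_ops if op.type == "Placeholder" for t in op.values()
    ]

print([op.name for op in all_ops])
print([v.name for v in all_vars])
print([t.name for t in all_tensors])
print([p.name for p in all_placeholders])
```

Comparing the first and third lists shows the naming rule that usually causes the confusion: each tensor name is its producing operation's name followed by ":" and the output index.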
TensorFlow 2
To get the graph in TensorFlow 2, instead of calling tf.get_default_graph() you need to create a tf.function first, obtain a concrete function from it, and access that concrete function's graph attribute, for example:
graph = func.get_concrete_function().graph
where func is a tf.function.
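A minimal sketch of the TensorFlow 2 route follows. The function body, the name "doubled", and the TensorSpec passed to get_concrete_function are assumptions made for illustration; a tf.function that takes arguments needs some input signature before it can be traced.

```python
import tensorflow as tf

@tf.function
def func(x):
    # Hypothetical computation; the op is named so it is easy to find below.
    return tf.multiply(x, 2.0, name="doubled")

# Tracing requires knowing the input type and shape, given here as a TensorSpec.
concrete = func.get_concrete_function(
    tf.TensorSpec(shape=[None], dtype=tf.float32)
)
graph = concrete.graph

# The same enumeration calls as in TF1 work on the traced graph.
op_names = [op.name for op in graph.get_operations()]
tensor_names = [t.name for op in graph.get_operations() for t in op.values()]

print(op_names)
print(tensor_names)
```

The traced graph also contains ops that TensorFlow adds itself (for example for the function's input and output), so the lists are usually longer than the function body suggests.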