I upgraded to TensorFlow 2.0 and there is no tf.summary.FileWriter("tf_graphs", sess.graph). I looked through some other StackOverflow questions on this and they said to use tf.compat.v1.summary etc. Surely there must be a way to graph and visualize a tf.keras model in TensorFlow version 2. What is it? I'm looking for a TensorBoard output like the one below. Thank you!
You create and run a graph in TensorFlow by using tf.function, either as a direct call or as a decorator. tf.function takes a regular function as input and returns a Function.
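For illustration, a minimal sketch of both usages (the function names here are my own):

import tensorflow as tf

# As a decorator: the Python body is traced into a graph on the first call.
@tf.function
def square(x):
    return x * x

# As a direct call: wrap an existing callable.
cube = tf.function(lambda x: x * x * x)

print(square(tf.constant(3)))  # tf.Tensor(9, shape=(), dtype=int32)
print(cube(tf.constant(2)))    # tf.Tensor(8, shape=(), dtype=int32)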
For a .pb TensorFlow model you can use inspect_pb.py to print model info, or use the TensorFlow summarize_graph tool with the --print_structure flag; it is also nice that it can detect input and output names.
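If you just want a quick look without those tools, here is a minimal sketch that parses a frozen .pb into a GraphDef and prints its nodes (the filename is a placeholder):

import tensorflow as tf

# Parse a frozen graph (.pb) into a GraphDef and list its nodes.
graph_def = tf.compat.v1.GraphDef()
with tf.io.gfile.GFile("model.pb", "rb") as f:  # placeholder path
    graph_def.ParseFromString(f.read())

for node in graph_def.node:
    print(node.name, node.op)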
You can visualize the graph of any tf.function-decorated function, but first you have to trace its execution. Visualizing the graph of a Keras model means visualizing its call method. By default, this method is not tf.function-decorated, so you have to wrap the model call in a correctly decorated function and execute it.
import tensorflow as tf

model = tf.keras.Sequential(
    [
        tf.keras.layers.Flatten(input_shape=(28, 28)),
        tf.keras.layers.Dense(32, activation="relu"),
        tf.keras.layers.Dropout(0.2),
        tf.keras.layers.Dense(10, activation="softmax"),
    ]
)


@tf.function
def traceme(x):
    return model(x)


logdir = "log"
writer = tf.summary.create_file_writer(logdir)
tf.summary.trace_on(graph=True, profiler=True)
# Forward pass
traceme(tf.zeros((1, 28, 28, 1)))
with writer.as_default():
    tf.summary.trace_export(name="model_trace", step=0, profiler_outdir=logdir)
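Then start TensorBoard pointed at the same log directory and open the Graphs tab:

tensorboard --logdir log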
According to the docs, you can use TensorBoard to visualise graphs once your model has been trained.
First, define your model and run it. Then, open TensorBoard and switch to the Graphs tab.
Minimal Compilable Example
This example is taken from the docs. First, define your model and data.
# Relevant imports.
%load_ext tensorboard
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from datetime import datetime
from packaging import version

import tensorflow as tf
from tensorflow import keras

# Define the model.
model = keras.models.Sequential([
    keras.layers.Flatten(input_shape=(28, 28)),
    keras.layers.Dense(32, activation='relu'),
    keras.layers.Dropout(0.2),
    keras.layers.Dense(10, activation='softmax')
])

model.compile(
    optimizer='adam',
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy'])

(train_images, train_labels), _ = keras.datasets.fashion_mnist.load_data()
train_images = train_images / 255.0
Next, train your model. Here, you will need to define a callback for TensorBoard to use for visualising stats and graphs.
# Define the Keras TensorBoard callback.
logdir = "logs/fit/" + datetime.now().strftime("%Y%m%d-%H%M%S")
tensorboard_callback = keras.callbacks.TensorBoard(log_dir=logdir)

# Train the model.
model.fit(
    train_images,
    train_labels,
    batch_size=64,
    epochs=5,
    callbacks=[tensorboard_callback])
After training, in your notebook, run
%tensorboard --logdir logs
And switch to the Graphs tab in the navbar:
You will see a graph that looks a lot like this:
Here's what is working for me at the moment (TF 2.0.0), based on the tf.keras.callbacks.TensorBoard code:
# After model has been compiled
from tensorflow.python.ops import summary_ops_v2
from tensorflow.python.keras.backend import get_graph

tb_path = '/tmp/tensorboard/'
tb_writer = tf.summary.create_file_writer(tb_path)
with tb_writer.as_default():
    if not model.run_eagerly:
        summary_ops_v2.graph(get_graph(), step=0)
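Note that summary_ops_v2 and get_graph are private APIs and can change between releases. On newer versions (tf.summary.graph was added around TF 2.5), a hedged sketch of the same idea through the public API, where the traceme wrapper and the input shape are my own assumptions:

@tf.function
def traceme(x):
    return model(x)

with tb_writer.as_default():
    # The .graph of a concrete function is a tf.Graph that tf.summary.graph accepts.
    tf.summary.graph(traceme.get_concrete_function(
        tf.TensorSpec(shape=(None, 28, 28), dtype=tf.float32)).graph)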
Here is a solution for TF 2.x with graph visualization of a subclassed model/layer:
import tensorflow as tf
print("TensorFlow version:", tf.__version__)

from tensorflow.keras.layers import Dense, Flatten, Conv2D
from tensorflow.keras import Model, Input


class MyModel(Model):
    def __init__(self, dim):
        super(MyModel, self).__init__()
        self.dim = dim
        self.conv1 = Conv2D(16, 3, activation='relu')
        self.conv2 = Conv2D(32, 3, activation='relu')
        self.conv3 = Conv2D(8, 3, activation='relu')
        self.flatten = Flatten()
        self.d1 = Dense(128, activation='relu')
        self.d2 = Dense(1)

    def call(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.flatten(x)
        x = self.d1(x)
        return self.d2(x)

    def build_graph(self):
        # Wrap the subclassed model's call in a functional Model
        # so that shapes are known and can be summarized/plotted.
        x = Input(shape=self.dim)
        return Model(inputs=[x], outputs=self.call(x))


dim = (28, 28, 1)
# Create an instance of the model
model = MyModel(dim)
model.build((None, *dim))
model.build_graph().summary()

tf.keras.utils.plot_model(model.build_graph(), to_file="model.png",
                          expand_nested=True, show_shapes=True)
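The build_graph trick works because wrapping self.call in a functional Model gives Keras concrete input and output tensors; on a plain subclassed model, summary() and plot_model have no layer connectivity to show until the shapes are known.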
The output is:
TensorFlow version: 2.5.0
Model: "model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
input_1 (InputLayer)         [(None, 28, 28, 1)]       0
_________________________________________________________________
conv2d (Conv2D)              (None, 26, 26, 16)        160
_________________________________________________________________
conv2d_1 (Conv2D)            (None, 24, 24, 32)        4640
_________________________________________________________________
conv2d_2 (Conv2D)            (None, 22, 22, 8)         2312
_________________________________________________________________
flatten (Flatten)            (None, 3872)              0
_________________________________________________________________
dense (Dense)                (None, 128)               495744
_________________________________________________________________
dense_1 (Dense)              (None, 1)                 129
=================================================================
Total params: 502,985
Trainable params: 502,985
Non-trainable params: 0
Here is also a graph visualization (the model.png produced by plot_model above).