I have a model saved in a .pb file, and I would like to calculate its FLOPs. My example code is as follows:
```python
import tensorflow as tf
import sys
from tensorflow.python.platform import gfile
from tensorflow.core.protobuf import saved_model_pb2
from tensorflow.python.util import compat

pb_path = 'themodel.pb'  # one name for the path, used below
run_meta = tf.RunMetadata()
with tf.Session() as sess:
    print("load graph")
    with gfile.FastGFile(pb_path, 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
        # Note: calling as_default() without "with" has no effect;
        # the import below goes into the default graph anyway.
        sess.graph.as_default()
        tf.import_graph_def(graph_def, name='')
        flops = tf.profiler.profile(tf.get_default_graph(), run_meta=run_meta,
                                    options=tf.profiler.ProfileOptionBuilder.float_operation())
        print("test flops:{:,}".format(flops.total_float_ops))
```
The output is strange: my model has dozens of layers, yet the profiler reports only 18 FLOPs. I'm quite sure the model is loaded correctly, because if I print the names of all the nodes like this:

```python
print([n.name for n in tf.get_default_graph().as_graph_def().node])
```

the output shows exactly the right network.
What's wrong with my code?
Thank you!
I think I have found the reason and a solution. The following code prints the FLOPs of the given .pb file:
```python
import os
import tensorflow as tf
from tensorflow.core.framework import graph_pb2
from tensorflow.python.framework import importer

os.environ['CUDA_VISIBLE_DEVICES'] = '0'
pb_path = 'mymodel.pb'
run_meta = tf.RunMetadata()
with tf.Graph().as_default():
    output_graph_def = graph_pb2.GraphDef()
    with open(pb_path, "rb") as f:
        output_graph_def.ParseFromString(f.read())
    _ = importer.import_graph_def(output_graph_def, name="")
    print('model loaded!')
    all_keys = sorted([n.name for n in tf.get_default_graph().as_graph_def().node])
    # for k in all_keys:
    #     print(k)
    with tf.Session() as sess:
        flops = tf.profiler.profile(tf.get_default_graph(), run_meta=run_meta,
                                    options=tf.profiler.ProfileOptionBuilder.float_operation())
        print("test flops:{:,}".format(flops.total_float_ops))
```
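The reason the code in the question printed only 18 FLOPs is that, when generating the .pb file, I set the input image shape to `[None, None, 3]`. If I change it to a fixed size, say `[500, 500, 3]`, the printed FLOPs are correct. The profiler computes FLOPs from the static shapes stored in the graph, so ops whose shapes are not fully defined (such as convolutions on an input of unknown height and width) are skipped rather than counted.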
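To see why unknown dimensions shrink the total, here is a small pure-Python sketch of the counting rule, assuming it mirrors what TensorFlow 1.x does: as far as I know, Conv2D is counted as 2 * (number of output elements) * filter height * filter width * input channels, and the profiler skips any op whose shapes are incomplete. The function names and the two-layer network below are hypothetical, made up only for illustration.

```python
from typing import List, Optional, Sequence, Tuple

Shape = Sequence[Optional[int]]  # None marks an unknown dimension


def conv2d_flops(output_shape: Shape, filter_shape: Shape) -> Optional[int]:
    """FLOPs of one Conv2D op, or None if a needed shape is not fully known.

    output_shape: [batch, out_height, out_width, out_channels]
    filter_shape: [filter_height, filter_width, in_channels, out_channels]
    Each output element needs filter_height * filter_width * in_channels
    multiply-adds, each counted as 2 FLOPs.
    """
    if any(d is None for d in output_shape) or any(d is None for d in filter_shape):
        return None  # assumed profiler behaviour: incomplete shape -> no count
    output_count = 1
    for d in output_shape:
        output_count *= d
    f_h, f_w, in_c, _ = filter_shape
    return 2 * output_count * f_h * f_w * in_c


def total_flops(ops: List[Tuple[Shape, Shape]]) -> Tuple[int, int]:
    """Sum FLOPs over ops; return (total, number of ops skipped)."""
    total, skipped = 0, 0
    for output_shape, filter_shape in ops:
        flops = conv2d_flops(output_shape, filter_shape)
        if flops is None:
            skipped += 1  # silently left out of the total
        else:
            total += flops
    return total, skipped


# Hypothetical network: two 3x3 convs, 'SAME' padding, stride 1, batch 1.
fixed = [([1, 500, 500, 64], [3, 3, 3, 64]),
         ([1, 500, 500, 64], [3, 3, 64, 64])]
unknown = [([1, None, None, 64], [3, 3, 3, 64]),
           ([1, None, None, 64], [3, 3, 64, 64])]

print(total_flops(fixed))    # (19296000000, 0)
print(total_flops(unknown))  # (0, 2)
```

With a fixed 500x500 input every convolution contributes, while with `None` height and width both are skipped, which is consistent with a large network reporting only a handful of FLOPs from the few ops whose shapes happen to be known.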