
How to calculate the FLOPs of a TensorFlow model loaded from a pb file

I have a model saved in a pb file, and I want to calculate its FLOPs. My example code is as follows:

import tensorflow as tf
from tensorflow.python.platform import gfile

pb_path = 'themodel.pb'

run_meta = tf.RunMetadata()
with tf.Session() as sess:
    print("load graph")
    with gfile.FastGFile(pb_path, 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
        # Inside `with tf.Session()`, the session's graph is already the default graph,
        # so import_graph_def adds the nodes to it.
        tf.import_graph_def(graph_def, name='')
        flops = tf.profiler.profile(tf.get_default_graph(), run_meta=run_meta,
            options=tf.profiler.ProfileOptionBuilder.float_operation())
        print("test flops:{:,}".format(flops.total_float_ops))

The printed result is strange. My model has dozens of layers, but the profiler reports only 18 FLOPs. I'm quite sure the model is loaded correctly, because if I print the names of every node as follows:

print([n.name for n in tf.get_default_graph().as_graph_def().node])

the output shows exactly the expected network.
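Printing node names confirms the graph structure, but not the tensor shapes. A sketch of a complementary check, assuming the `tf.compat.v1` API (recent TF 1.x releases and TF 2.x): import the GraphDef into a fresh graph and list every tensor whose static shape is not fully defined. The one-convolution toy graph and its node names (`input`, `w`, `conv`) are hypothetical stand-ins for the real pb file.

```python
import numpy as np
import tensorflow as tf

tf1 = tf.compat.v1  # graph-mode API, available in TF 1.x and TF 2.x


def toy_graph_def():
    """Hypothetical stand-in for the .pb: one conv on an input with unknown H and W."""
    g = tf1.Graph()
    with g.as_default():
        x = tf1.placeholder(tf.float32, shape=[None, None, None, 3], name="input")
        w = tf1.constant(np.zeros((3, 3, 3, 8), np.float32), name="w")
        tf1.nn.conv2d(x, w, strides=[1, 1, 1, 1], padding="SAME", name="conv")
    return g.as_graph_def()


def incomplete_shapes(graph_def):
    """Return (tensor name, static shape) for every tensor whose shape is not fully known."""
    graph = tf1.Graph()
    with graph.as_default():
        tf1.import_graph_def(graph_def, name="")
    return [
        (t.name, t.shape)
        for op in graph.get_operations()
        for t in op.outputs
        if not t.shape.is_fully_defined()
    ]


for name, shape in incomplete_shapes(toy_graph_def()):
    print(name, shape)
```

For the real model, one would pass a GraphDef parsed from the file instead of `toy_graph_def()` (e.g. `graph_def = tf1.GraphDef()` followed by `graph_def.ParseFromString(f.read())`). Many tensors with unknown dimensions would suggest that shape-dependent statistics cannot be computed for the ops producing them.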

What's wrong with my code?

Thank you!

Asked by pfc, Nov 07 '22 02:11


1 Answer

I think I have found the cause of my problem and a solution. The following code prints the FLOPs of the given pb file:

import os
import tensorflow as tf
from tensorflow.core.framework import graph_pb2
from tensorflow.python.framework import importer

os.environ['CUDA_VISIBLE_DEVICES'] = '0'

pb_path = 'mymodel.pb'

run_meta = tf.RunMetadata()
with tf.Graph().as_default():
    output_graph_def = graph_pb2.GraphDef()
    with open(pb_path, "rb") as f:
        output_graph_def.ParseFromString(f.read())
        _ = importer.import_graph_def(output_graph_def, name="")
        print('model loaded!')
    all_keys = sorted([n.name for n in tf.get_default_graph().as_graph_def().node])
    # for k in all_keys:
    #   print(k)

    with tf.Session() as sess:
        flops = tf.profiler.profile(tf.get_default_graph(), run_meta=run_meta,
            options=tf.profiler.ProfileOptionBuilder.float_operation())
        print("test flops:{:,}".format(flops.total_float_ops))

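The FLOPs printed in the question were only 18 because, when generating the pb file, I set the input image shape to [None, None, 3]. The profiler only counts an op's float operations when that op's shapes are fully known (TensorFlow's Conv2D statistic, for instance, is 2 × output elements × kernel height × kernel width × input channels), so ops whose shapes depend on the unknown height and width are skipped. If I change the input to a fixed shape, say [500, 500, 3], the printed FLOPs are correct.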
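If regenerating the pb file is inconvenient, a possible alternative is to rebind the dynamic input to a fixed-shape placeholder at import time with `input_map`, so that shape inference propagates concrete sizes before profiling. This is a sketch assuming the `tf.compat.v1` API; the toy graph, the tensor name `input:0`, and the 1×500×500×3 size are hypothetical, and the real input tensor name has to be looked up (for example from the node-name printout in the question).

```python
import numpy as np
import tensorflow as tf

tf1 = tf.compat.v1


def toy_graph_def():
    """Hypothetical stand-in for a .pb exported with input shape [None, None, None, 3]."""
    g = tf1.Graph()
    with g.as_default():
        x = tf1.placeholder(tf.float32, shape=[None, None, None, 3], name="input")
        w = tf1.constant(np.zeros((3, 3, 3, 8), np.float32), name="w")
        tf1.nn.conv2d(x, w, strides=[1, 1, 1, 1], padding="SAME", name="conv")
    return g.as_graph_def()


def count_flops(graph_def, input_name=None, fixed_shape=None):
    """Import graph_def into a fresh graph and return tfprof's total float-op count.

    If input_name is given, that tensor is replaced (via input_map) by a new
    placeholder with the fully defined fixed_shape.
    """
    graph = tf1.Graph()
    with graph.as_default():
        input_map = None
        if input_name is not None:
            fixed = tf1.placeholder(tf.float32, shape=fixed_shape, name="fixed_input")
            input_map = {input_name: fixed}
        tf1.import_graph_def(graph_def, input_map=input_map, name="")
    # float_operation() options, with the per-node report suppressed.
    opts = (tf1.profiler.ProfileOptionBuilder(
                tf1.profiler.ProfileOptionBuilder.float_operation())
            .with_empty_output()
            .build())
    return tf1.profiler.profile(graph, options=opts).total_float_ops


gd = toy_graph_def()
print("dynamic input:", count_flops(gd))
print("fixed input:  ", count_flops(gd, "input:0", [1, 500, 500, 3]))
```

Note that the resulting count is for that particular input size: convolution FLOPs scale with the spatial dimensions, so a different fixed shape gives a different number.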

Answered by pfc, Nov 15 '22 08:11