
Can't import frozen graph with BatchNorm layer

I have trained a Keras model based on this repo.

After the training I save the model as checkpoint files like this:

 sess=tf.keras.backend.get_session() 
 saver = tf.train.Saver()
 saver.save(sess, current_run_path + '/checkpoint_files/model_{}.ckpt'.format(date))

Then I restore the graph from the checkpoint files and freeze it using the standard tf freeze_graph script. When I want to restore the frozen graph I get the following error:

Input 0 of node Conv_BN_1/cond/ReadVariableOp/Switch was passed float from Conv_BN_1/gamma:0 incompatible with expected resource

How can I fix this issue?

Edit: My problem is related to this question. Unfortunately, I can't use the workaround.

Edit 2: I have opened an issue on github and created a gist to reproduce the error. https://github.com/keras-team/keras/issues/11032

asked Aug 15 '18 11:08 by ninja

1 Answer

Just resolved the same issue. I combined these few answers: 1, 2, 3 and realized that the issue originates from the batch norm layer's working state: training or inference. So, to resolve the issue, you just need to place one line before loading your model:

keras.backend.set_learning_phase(0)
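Why the learning phase matters can be illustrated without TensorFlow at all: in training mode batch norm normalizes with the current batch's statistics, while in inference mode it uses the accumulated moving averages. A frozen graph must bake in exactly one of those branches, which is what the line above pins down. A toy sketch (hypothetical helper, not the Keras API):

```python
def batch_norm(x, moving_mean, moving_var, batch, training, eps=1e-5):
    """Toy scalar batch norm illustrating the training/inference branch."""
    if training:
        # training branch: statistics of the current batch
        mean = sum(batch) / len(batch)
        var = sum((v - mean) ** 2 for v in batch) / len(batch)
    else:
        # inference branch: frozen moving statistics
        mean, var = moving_mean, moving_var
    return (x - mean) / (var + eps) ** 0.5

# training mode uses the batch; inference mode ignores it
print(batch_norm(2.0, 0.0, 1.0, [1.0, 2.0, 3.0], training=True))   # 0.0
print(batch_norm(3.0, 1.0, 4.0, [], training=False, eps=0.0))      # 1.0
```

With `set_learning_phase(0)`, Keras builds only the inference branch, so the freezing script never sees the training-time conditional (`cond/Switch`) that caused the error.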

Complete example to export the model:

import tensorflow as tf
from tensorflow.python.framework import graph_io
from tensorflow.keras.applications.inception_v3 import InceptionV3


def freeze_graph(graph, session, output):
    with graph.as_default():
        # strip training-only nodes, then bake variable values into constants
        graphdef_inf = tf.graph_util.remove_training_nodes(graph.as_graph_def())
        graphdef_frozen = tf.graph_util.convert_variables_to_constants(session, graphdef_inf, output)
        graph_io.write_graph(graphdef_frozen, ".", "frozen_model.pb", as_text=False)

tf.keras.backend.set_learning_phase(0)  # this line is the most important

base_model = InceptionV3()

session = tf.keras.backend.get_session()

INPUT_NODE = base_model.inputs[0].op.name
OUTPUT_NODE = base_model.outputs[0].op.name
freeze_graph(session.graph, session, [out.op.name for out in base_model.outputs])
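What `convert_variables_to_constants` does conceptually can be sketched in plain Python: it walks the graph and replaces every variable node with a constant holding the variable's current session value. A toy model of that idea (the dict-based "graph" here is entirely hypothetical, not the real GraphDef format):

```python
def freeze_toy_graph(nodes, variable_values):
    """Replace each ('name', 'variable') node with a constant carrying its value."""
    frozen = []
    for name, kind in nodes:
        if kind == "variable":
            # bake the session value into the graph as a constant
            frozen.append((name, "const", variable_values[name]))
        else:
            frozen.append((name, kind))
    return frozen

nodes = [("conv1/kernel", "variable"), ("relu", "op")]
values = {"conv1/kernel": [0.5, -0.2]}
print(freeze_toy_graph(nodes, values))
```

After this substitution the graph no longer contains `ReadVariableOp` nodes at all, which is why the resource/float mismatch from the question disappears once only the inference branch is exported.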

To load the *.pb model:

from PIL import Image
import numpy as np
import tensorflow as tf

# https://i.imgur.com/tvOB18o.jpg
im = Image.open("/home/chichivica/Pictures/eagle.jpg").resize((299, 299), Image.BICUBIC)
im = np.array(im) / 255.0
im = im[None, ...]

graph_def = tf.GraphDef()

with tf.gfile.GFile("frozen_model.pb", "rb") as f:
    graph_def.ParseFromString(f.read())

graph = tf.Graph()

with graph.as_default():
    net_inp, net_out = tf.import_graph_def(
        graph_def, return_elements=["input_1", "predictions/Softmax"]
    )
    with tf.Session(graph=graph) as sess:
        out = sess.run(net_out.outputs[0], feed_dict={net_inp.outputs[0]: im})
        print(np.argmax(out))
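The final `print(np.argmax(out))` only gives the winning class index. Turning the softmax vector into human-readable top-k predictions is a small pure-Python step; the label list below is a placeholder, not the real ImageNet mapping:

```python
def top_k(probs, labels, k=3):
    """Pair each probability with its label and return the k highest."""
    ranked = sorted(zip(probs, labels), reverse=True)
    return ranked[:k]

# placeholder probabilities and labels for illustration
probs = [0.10, 0.70, 0.05, 0.15]
labels = ["cat", "eagle", "dog", "fox"]
print(top_k(probs, labels, k=2))  # [(0.7, 'eagle'), (0.15, 'fox')]
```

For the real InceptionV3 class names you would use `tf.keras.applications.inception_v3.decode_predictions` on the output array instead.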
answered Sep 20 '22 18:09 by Ivan Talalaev