
Concatenating two saved models in tensorflow 1.13 [duplicate]

Tags: tensorflow

I've trained a DCGAN model and would now like to load it into a library that visualizes the drivers of neuron activation through image space optimization.

The following code works, but forces me to work with (1, width, height, channels) images in all subsequent image analysis. That is a pain, because the library makes its own assumptions about the shape of the network input.

import tensorflow as tf

# create a TensorFlow session and load the model
# (model_fn is the path to the saved .meta file)
graph = tf.Graph()
sess = tf.InteractiveSession(graph=graph)

new_saver = tf.train.import_meta_graph(model_fn)
new_saver.restore(sess, './')

I'd like to change the input_map. After reading the source, I expected this code to work:

graph = tf.Graph()
sess = tf.InteractiveSession(graph=graph)

t_input = tf.placeholder(np.float32, name='images') # define the input tensor
t_preprocessed = tf.expand_dims(t_input, 0)

new_saver = tf.train.import_meta_graph(model_fn, input_map={'images': t_input})
new_saver.restore(sess, './')

But got an error:

ValueError: tf.import_graph_def() requires a non-empty name if input_map is used.

By the time the call stack reaches tf.import_graph_def(), the name field is set to import_scope, so I tried the following:

graph = tf.Graph()
sess = tf.InteractiveSession(graph=graph)

t_input = tf.placeholder(np.float32, name='images') # define the input tensor
t_preprocessed = tf.expand_dims(t_input, 0)

new_saver = tf.train.import_meta_graph(model_fn, input_map={'images': t_input}, import_scope='import')
new_saver.restore(sess, './')

That produced the following KeyError:

KeyError: "The name 'gradients/discriminator/minibatch/map/while/TensorArrayWrite/TensorArrayWriteV3_grad/TensorArrayReadV3/RefEnter:0' refers to a Tensor which does not exist. The operation, 'gradients/discriminator/minibatch/map/while/TensorArrayWrite/TensorArrayWriteV3_grad/TensorArrayReadV3/RefEnter', does not exist in the graph."

If I set 'import_scope', I get the same error whether or not I set 'input_map'.

I'm not sure where to go from here.

asked Nov 15 '22 20:11 by Sevenless


1 Answer

In newer versions of TensorFlow (>=1.2.0), the following works:

t_input = tf.placeholder(np.float32, shape=[None, width, height, channels], name='new_input')  # define the new input tensor

# The input_map key must be the name of the original model's input placeholder.
# For example, if the model defines its input as:
#   input_original = tf.placeholder(tf.float32, shape=(1, width, height, channels), name='original_placeholder_name')
new_saver = tf.train.import_meta_graph('/path/to/checkpoint_file.meta', input_map={'original_placeholder_name:0': t_input})
new_saver.restore(sess, '/path/to/checkpointfile')
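
To try the remapping end to end without the original DCGAN checkpoint, here is a minimal sketch on a toy model. Everything in it is hypothetical and only for illustration: the one-variable "model", the tensor names 'images', 'scale' and 'output', and the helper functions. It assumes a TensorFlow build that provides tf.compat.v1 (1.13 or later). It saves a graph with a fixed (1, 4, 4, 3) input, lists the placeholder names stored in the .meta file (one way to find the key for input_map), and re-imports the graph so that it accepts a single (4, 4, 3) image.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf

tf1 = tf.compat.v1  # TF 1.x-style API


def build_and_save(ckpt_prefix):
    """Build a toy model with a fixed-batch input and save a checkpoint."""
    graph = tf.Graph()
    with graph.as_default():
        images = tf1.placeholder(tf.float32, shape=(1, 4, 4, 3), name="images")
        scale = tf1.get_variable("scale", initializer=2.0)
        tf.identity(images * scale, name="output")
        saver = tf1.train.Saver()
        with tf1.Session(graph=graph) as sess:
            sess.run(tf1.global_variables_initializer())
            saver.save(sess, ckpt_prefix)  # also writes ckpt_prefix + ".meta"


def placeholder_names(meta_path):
    """Return the names of all Placeholder ops stored in a .meta file."""
    meta = tf1.MetaGraphDef()
    with open(meta_path, "rb") as f:
        meta.ParseFromString(f.read())
    return [node.name for node in meta.graph_def.node if node.op == "Placeholder"]


def load_with_new_input(ckpt_prefix):
    """Re-import the saved graph, replacing its input with an unbatched one."""
    graph = tf.Graph()
    with graph.as_default():
        # New input without the batch dimension; expand_dims adds it back.
        t_input = tf1.placeholder(tf.float32, shape=(4, 4, 3), name="new_input")
        t_preprocessed = tf.expand_dims(t_input, 0)
        # Key: "<original placeholder name>:0"; value: the replacement tensor.
        saver = tf1.train.import_meta_graph(
            ckpt_prefix + ".meta",
            input_map={"images:0": t_preprocessed},
        )
        sess = tf1.Session(graph=graph)
        saver.restore(sess, ckpt_prefix)
        output = graph.get_tensor_by_name("output:0")
    return sess, t_input, output


ckpt_prefix = os.path.join(tempfile.mkdtemp(), "toy_model")
build_and_save(ckpt_prefix)
print(placeholder_names(ckpt_prefix + ".meta"))

sess, t_input, output = load_with_new_input(ckpt_prefix)
result = sess.run(output, feed_dict={t_input: np.ones((4, 4, 3), np.float32)})
sess.close()
print(result.shape)
```

Note that the replacement tensor passed in input_map here is the expand_dims output, not the raw placeholder, so the imported graph still sees a batched tensor while the caller feeds a single image. Whether this also avoids the KeyError from the question is not something the toy model can show, since it has no while-loop gradient ops.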

answered Jun 18 '23 21:06 by Ishant Mrinal