I came across this problem while trying to solve another error. The first error (original problem) was that when I tried to restore a meta graph I would get Cannot find KeyError: "The name 'multi_rnn_cell_6' refers to an Operation not in the graph."
. In trying to create the MVCE for that problem I found this error.
A simple script which creates some ops, saves the meta graph and variables, and then tries to load the graph and variables fails. The problem seems to be related to the format TF is using.
import tensorflow as tf
import numpy as np
import os
import glob
class ImportIssue(object):
def __init__(self,load=False,model_scope = 'model',checkpoint='checkpoint'):
try:
os.makedirs(checkpoint)
except:
pass
save_file = os.path.join(checkpoint,'model')
print("Save file: {}".format(save_file))
graph = tf.Graph()
with graph.as_default():
if load:
# load model if requested
model_to_load = "{}.meta".format(tf.train.latest_checkpoint(checkpoint))
print("Loading model: {}".format(model_to_load))
rest = tf.train.import_meta_graph(model_to_load)
else:
# else create one
with tf.variable_scope(model_scope):
inputs = tf.placeholder(shape=(None,10,10),dtype=tf.float32)
cell = self._build_cell(10)
# this cell is failing to be fond
#print(cell.name)
rnn,state = tf.nn.dynamic_rnn(cell,inputs,dtype=tf.float32)
train_op = self._build_training_op(inputs,rnn)
saver = tf.train.Saver(tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES), max_to_keep=1)
with tf.Session(graph=graph) as sess:
if load:
rest.restore(sess, model_to_load)
else:
sess.run(tf.global_variables_initializer())
sess.run(train_op,feed_dict={inputs:np.random.normal(size=[3,10,10])})
saver.save(sess, save_file)
print("Saved model and graph")
print("Files in checkpoint dir: {}".format(glob.glob("{}/*".format(checkpoint))))
def _build_cell(self,size):
with tf.variable_scope("decoder"):
cells = []
cells.append(tf.nn.rnn_cell.GRUCell(size,activation=tf.nn.tanh))
for res_block_i in range(1):
res_block = tf.nn.rnn_cell.MultiRNNCell([tf.nn.rnn_cell.LSTMCell(size, use_peepholes=True) for i in range(2)])
res_block = tf.nn.rnn_cell.ResidualWrapper(res_block)
res_block = tf.nn.rnn_cell.DropoutWrapper(res_block, input_keep_prob = 1.0,
output_keep_prob = 0.5, state_keep_prob = 0.5,
variational_recurrent = True, dtype=tf.float32)
cells.append(res_block)
cell = tf.nn.rnn_cell.MultiRNNCell(cells)
return cell
def _build_training_op(self,inputs,rnn):
o = tf.train.AdamOptimizer(1e-3)
loss = tf.reduce_mean(tf.square(inputs - rnn))
return o.minimize(loss)
if __name__ == '__main__':
ImportIssue()
ImportIssue(load=True)
Prints
Saved model and graph
Files in checkpoint dir: ['checkpoint/model.data-00000-of-00001', 'checkpoint/model.meta', 'checkpoint/checkpoint', 'checkpoint/model.index']
Save file: checkpoint/model
Loading model: checkpoint/model.meta
The error is:
tensorflow.python.framework.errors_impl.DataLossError: Unable to open table file checkpoint/model.meta: Data loss: not an sstable (bad magic number): perhaps your file is in a different file format and you need to use a different restore operator?
Python 3.6 Fedora 64bit Linux TF 1.4
Yep checkpoint has to be specified without the .data-00000-of-00001 that seems to be added to end of all checkpoints created in the V2 tf graph save methods.
If you love us? You can donate to us via Paypal or buy me a coffee so we can maintain and grow! Thank you!
Donate Us With