tensorflow: run model evaluation over multiple checkpoints

In my current project I train a model and save checkpoints every 100 iteration steps. The checkpoint files are all saved to the same directory (model.ckpt-100, model.ckpt-200, model.ckpt-300, etc.). After that I would like to evaluate the model on validation data for all of the saved checkpoints, not just the latest one.

Currently my piece of code for restoring the checkpoint file looks like this:

ckpt = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir)
ckpt_list = saver.last_checkpoints
print(ckpt_list)
if ckpt and ckpt.model_checkpoint_path:
    print("Reading model parameters from %s" % ckpt.model_checkpoint_path)
    saver.restore(sess, ckpt.model_checkpoint_path)
    # extract global_step from it.
    global_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1]
    print('Successfully loaded model from %s at step=%s.' %
            (ckpt.model_checkpoint_path, global_step))
else:
    print('No checkpoint file found')
    return

However, this restores only the latest saved checkpoint file. So how do I write a loop over all saved checkpoint files? I tried to get a list of the checkpoint files with saver.last_checkpoints; however, the returned list is empty.

Any help would be highly appreciated, thanks in advance!

asked Mar 10 '23 by TheJude
2 Answers

Fastest solution:

tensor2tensor ships a script, utils/avg_checkpoints.py, that averages the weights of several checkpoints and saves the result as a new checkpoint. Say you have a list of checkpoints that you want to average. You have two options for usage:

  1. From command line

    TRAIN_DIR=path_to_your_model_folder
    FNC_PATH=path_to_tensor2tensor/utils/avg_checkpoints.py
    CKPTS=model.ckpt-10000,model.ckpt-20000,model.ckpt-100000

    python3 $FNC_PATH --prefix=$TRAIN_DIR --checkpoints=$CKPTS \
        --output_path="${TRAIN_DIR}averaged.ckpt"
    
  2. From your own code (use os.system):

    import os
    os.system(
        "python3 " + FNC_PATH + " --prefix=" + TRAIN_DIR + " --checkpoints=" + CKPTS +
        " --output_path=" + TRAIN_DIR + "averaged.ckpt"
    )
    

As an alternative to specifying a list of checkpoints and using the --checkpoints argument, you can just use --num_checkpoints=10 to average the last 10 checkpoints.

If you do not want a dependency on tensor2tensor:

Here is a code snippet that does not rely on tensor2tensor, but can still average a variable number of checkpoints (as opposed to ted's answer). Assume steps is a list of checkpoints that should be merged (e.g. [10000, 20000, 30000, 40000]).

Then:

import numpy as np
import tensorflow as tf

# Restore each checkpoint in turn and collect its weight matrices
values = []
for step in steps:
    tf.reset_default_graph()
    path = model_path + '/model.ckpt-' + str(step)
    with tf.Session() as sess:
        saver = tf.train.import_meta_graph(path + '.meta')
        saver.restore(sess, path)
        values.append(sess.run(tf.global_variables()))

# Average the weights variable by variable
variables = tf.global_variables()
all_assign = []
for ind, var in enumerate(variables):
    weights = np.concatenate(
        [np.expand_dims(w[ind], axis=0) for w in values],
        axis=0
    )
    all_assign.append(tf.assign(var, np.mean(weights, axis=0)))

Then you can proceed however you prefer, e.g. by saving the averaged checkpoint:

# Now save the averaged values into a separate checkpoint
with tf.Session() as sess_test:
    sess_test.run(all_assign)
    saver = tf.train.Saver()
    saver.save(sess_test, model_path + '/average_' + str(len(steps)))
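
The averaging step itself is just an element-wise mean over the per-checkpoint copies of each variable. As a minimal, framework-free sketch of that arithmetic (toy arrays standing in for the restored weight lists, not real checkpoint data):

```python
import numpy as np

# values[i] is the list of weight arrays restored from checkpoint i
# (here: two toy checkpoints, each with two "variables")
values = [
    [np.array([1.0, 2.0]), np.array([[1.0]])],
    [np.array([3.0, 4.0]), np.array([[3.0]])],
]

# For each variable index, stack the per-checkpoint copies and average them
averaged = [
    np.mean(np.stack([ckpt[ind] for ckpt in values], axis=0), axis=0)
    for ind in range(len(values[0]))
]

print(averaged[0])  # [2. 3.]
print(averaged[1])  # [[2.]]
```

Each entry of `averaged` has the same shape as the corresponding variable, which is why it can be fed straight into a `tf.assign` as above.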
answered Apr 06 '23 by dopexxx
You can iterate through the files in the directory:

import os

dir_path = './'  # change this to wherever your checkpoint files are
ckpt_files = [f for f in os.listdir(dir_path) if os.path.isfile(
    os.path.join(dir_path, f)) and 'ckpt' in f]

for ckpt_file in ckpt_files:
    ckpt_path = os.path.join(dir_path, ckpt_file)
    saver.restore(sess, ckpt_path)
    global_step = ckpt_path.split('/')[-1].split('-')[-1]
    print('Successfully loaded model from %s at step=%s.' %
          (ckpt_path, global_step))

    # Do your thing

Add more conditions to the above list comprehension to be more selective, e.g. and 'meta' not in f and so on, depending on what is in that directory and on the saver version you have.
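
Note that os.listdir returns files in arbitrary order, so if you want to evaluate the checkpoints in training order it helps to parse the step number out of each filename and sort on it. A small pure-Python sketch (the filenames below are hypothetical; a V2 saver writes several files per step):

```python
import re

# Hypothetical directory listing
files = ['model.ckpt-300.meta', 'model.ckpt-100.index',
         'model.ckpt-100.data-00000-of-00001', 'model.ckpt-200.index',
         'checkpoint']

# Collect the unique step numbers and sort them
steps = sorted({int(m.group(1)) for f in files
                for m in [re.match(r'model\.ckpt-(\d+)', f)] if m})

# Rebuild the checkpoint prefixes in training order
prefixes = ['model.ckpt-%d' % s for s in steps]

print(prefixes)  # ['model.ckpt-100', 'model.ckpt-200', 'model.ckpt-300']
```

Each prefix (without the .meta/.index/.data suffix) is what saver.restore expects as a path.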

answered Apr 06 '23 by ted