How to set the checkpoint for fine-tuning

I found that the loss when I retrain the model (ssd_mobilenet_v2) from the model zoo is very large at the beginning of training, while the accuracy on the validation set is good. The training log is shown at the bottom of this post.

Such a log couldn't come from the trained model, so I suspect the checkpoint is not being loaded for fine-tuning. Please help me figure out how to fine-tune with the trained model on the same dataset. I didn't modify the network structure at all.

I set the checkpoint path in pipeline.config as below:

fine_tune_checkpoint: "//ssd_mobilenet_v2_coco_2018_03_29/model.ckpt"

If I set model_dir to the downloaded directory, it won't train at all, because the global train step restored from the checkpoint is already larger than max_step. If I enlarge max_step, I can see the log of parameters being restored from the checkpoint, but training then fails with an error that some parameters could not be restored. So I set model_dir to an empty directory instead; it trains normally, but the loss at step 0 is very large and the validation result is very bad.
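(For reference, tf.train.list_variables can be used to inspect which variables the downloaded checkpoint actually contains; a minimal sketch, using the checkpoint path from my config:)

import tensorflow as tf

# List every variable name and shape stored in the checkpoint; useful for
# diagnosing which parameters fail to restore.
ckpt_path = '/ssd_mobilenet_v2_coco_2018_03_29/model.ckpt'
for name, shape in tf.train.list_variables(ckpt_path):
    print(name, shape)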

In pipeline.config:

fine_tune_checkpoint: "/ssd_mobilenet_v2_coco_2018_03_29/model.ckpt"
num_steps: 200000
fine_tune_checkpoint_type: "detection"

Train script:

import tensorflow as tf
from object_detection import model_lib
from object_detection import model_hparams

model_dir = '/ssd_mobilenet_v2_coco_2018_03_29/retrain0524'

pipeline_config_path = '/ssd_mobilenet_v2_coco_2018_03_29/pipeline.config'

checkpoint_dir = '/ssd_mobilenet_v2_coco_2018_03_29/model.ckpt'

num_train_steps = 300000
hparams_overrides = None  # defaults matching model_main.py; not set in the original snippet
sample_1_of_n_eval_examples = 1
sample_1_of_n_eval_on_train_examples = 5
config = tf.estimator.RunConfig(model_dir=model_dir)
train_and_eval_dict = model_lib.create_estimator_and_inputs(
    run_config=config,
    hparams=model_hparams.create_hparams(hparams_overrides),
    pipeline_config_path=pipeline_config_path,    
    sample_1_of_n_eval_examples=sample_1_of_n_eval_examples,
    sample_1_of_n_eval_on_train_examples=(sample_1_of_n_eval_on_train_examples))
estimator = train_and_eval_dict['estimator']
train_input_fn = train_and_eval_dict['train_input_fn']
eval_input_fns = train_and_eval_dict['eval_input_fns']
eval_on_train_input_fn = train_and_eval_dict['eval_on_train_input_fn']
predict_input_fn = train_and_eval_dict['predict_input_fn']
train_steps = train_and_eval_dict['train_steps']

train_spec, eval_specs = model_lib.create_train_and_eval_specs(
        train_input_fn,
        eval_input_fns,
        eval_on_train_input_fn,
        predict_input_fn,
        train_steps,
        eval_on_train_data=False)

tf.estimator.train_and_evaluate(estimator, train_spec, eval_specs[0])

INFO:tensorflow:loss = 356.25497, step = 0
INFO:tensorflow:global_step/sec: 1.89768
INFO:tensorflow:loss = 11.221423, step = 100 (52.700 sec)
INFO:tensorflow:global_step/sec: 2.21685
INFO:tensorflow:loss = 10.329516, step = 200 (45.109 sec)


1 Answer

If the initial training loss is 400, the model most likely was restored from the checkpoint successfully, just not completely: not all variables were loaded from the checkpoint.

Note that even if you set fine_tune_checkpoint_type: "detection" and provide a checkpoint of exactly the same model, by default only the variables in the feature_extractor scope are restored. To restore as many variables from the checkpoint as possible, you have to set load_all_detection_checkpoint_vars: true in your config file.
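For example, the train_config block of pipeline.config would look like this (a sketch reusing the paths and step count from the question):

train_config {
  fine_tune_checkpoint: "/ssd_mobilenet_v2_coco_2018_03_29/model.ckpt"
  fine_tune_checkpoint_type: "detection"
  load_all_detection_checkpoint_vars: true
  num_steps: 200000
}

Here is the restore_map function of the SSD models, which implements this behavior: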

def restore_map(self,
                fine_tune_checkpoint_type='detection',
                load_all_detection_checkpoint_vars=False):
  if fine_tune_checkpoint_type not in ['detection', 'classification']:
    raise ValueError('Not supported fine_tune_checkpoint_type: {}'.format(
        fine_tune_checkpoint_type))

  if fine_tune_checkpoint_type == 'classification':
    # Classification checkpoints can only initialize the feature extractor.
    return self._feature_extractor.restore_from_classification_checkpoint_fn(
        self._extract_features_scope)

  if fine_tune_checkpoint_type == 'detection':
    variables_to_restore = {}
    for variable in tf.global_variables():
      var_name = variable.op.name
      if load_all_detection_checkpoint_vars:
        # Restore every variable in the graph from the checkpoint.
        variables_to_restore[var_name] = variable
      else:
        # Default: restore only variables under the feature_extractor scope.
        if var_name.startswith(self._extract_features_scope):
          variables_to_restore[var_name] = variable

  return variables_to_restore
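For context, here is roughly how the returned map is consumed during fine-tuning (a sketch based on the TF1 Object Detection API's model_lib; detection_model and fine_tune_checkpoint stand in for the built model object and the checkpoint path from the config):

from object_detection.utils import variables_helper
import tensorflow as tf

# Build the name -> variable map exactly as restore_map above does.
asg_map = detection_model.restore_map(
    fine_tune_checkpoint_type='detection',
    load_all_detection_checkpoint_vars=True)

# Drop variables whose name or shape is absent from the checkpoint,
# which avoids the "could not restore some parameters" error.
available_var_map = variables_helper.get_variables_available_in_checkpoint(
    asg_map, fine_tune_checkpoint, include_global_step=False)

# Initialize the matching variables from the checkpoint; everything else
# keeps its fresh initialization.
tf.train.init_from_checkpoint(fine_tune_checkpoint, available_var_map)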