How to continue training an object detection model using Tensorflow Object Detection API?

I'm using the Tensorflow Object Detection API to train an object detection model with transfer learning. Specifically, I'm using ssd_mobilenet_v1_fpn_coco from the model zoo with the sample pipeline provided, having replaced the placeholders with the actual paths to my training and eval tfrecords and labels.

I was able to successfully train a model on my ~5000 images (and corresponding bounding boxes) using the above pipeline (I'm mainly using Google's ML Engine on TPU, if relevant).

Now I have prepared an additional ~2000 images and would like to continue training my model with those new images, without restarting from scratch (training the initial model took ~6h of TPU time). How can I do that?

Simon Labrecque asked Nov 01 '18 15:11

1 Answer

You have two options. In both, you need to change the input_path of the train_input_reader so that it points to your new dataset:

  1. When specifying a checkpoint to fine-tune in the training configuration, specify the checkpoint of your trained model:
train_config{
    fine_tune_checkpoint: <path_to_your_checkpoint>
    fine_tune_checkpoint_type: "detection"
    load_all_detection_checkpoint_vars: true
}
  2. Keep using the same configuration (apart from the train_input_reader) with the same model_dir as your previous model. The API will then create the graph and check whether model_dir already contains a checkpoint that fits the graph. If so, it will restore that checkpoint and continue training from it.
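
Either way, the edit to the pipeline config is small. As a rough sketch only: the snippet below rewrites the config as plain text with regular expressions, which is not how the Object Detection API itself handles configs. The function names, the sample config, and all paths (new/train.record, gs://my-bucket/...) are hypothetical placeholders, and it assumes the train_input_reader block contains a quoted input_path.

```python
import re

# Hypothetical, trimmed-down pipeline config used only for illustration.
SAMPLE_CONFIG = '''
train_config {
  fine_tune_checkpoint: "PATH_TO_BE_CONFIGURED/model.ckpt"
  fine_tune_checkpoint_type: "detection"
}
train_input_reader {
  tf_record_input_reader {
    input_path: "old/train.record"
  }
}
eval_input_reader {
  tf_record_input_reader {
    input_path: "old/eval.record"
  }
}
'''


def set_train_input_path(config_text, new_path):
    """Replace the first input_path that follows train_input_reader."""
    pattern = re.compile(
        r'(train_input_reader\s*:?\s*\{.*?input_path\s*:\s*")[^"]*(")',
        re.DOTALL,
    )
    # A callable replacement avoids backslash/group escapes in new_path.
    updated, count = pattern.subn(
        lambda m: m.group(1) + new_path + m.group(2), config_text, count=1
    )
    if count != 1:
        raise ValueError("no input_path found after train_input_reader")
    return updated


def set_fine_tune_checkpoint(config_text, checkpoint):
    """Replace the fine_tune_checkpoint value (option 1 only)."""
    # The colon must follow the key directly, so this does not
    # touch fine_tune_checkpoint_type.
    pattern = re.compile(r'(fine_tune_checkpoint\s*:\s*")[^"]*(")')
    updated, count = pattern.subn(
        lambda m: m.group(1) + checkpoint + m.group(2), config_text, count=1
    )
    if count != 1:
        raise ValueError("no fine_tune_checkpoint entry found")
    return updated


if __name__ == "__main__":
    # Both options: point training at the new tfrecords.
    config = set_train_input_path(SAMPLE_CONFIG, "new/train.record")
    # Option 1 only: start from the previously trained checkpoint.
    config = set_fine_tune_checkpoint(config, "gs://my-bucket/model_dir/model.ckpt-1000")
    print(config)
```

For option 2 you would only call set_train_input_path and leave the rest of the config and model_dir unchanged. The eval_input_reader is left alone in both cases.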

Edit: fine_tune_checkpoint_type was previously set to true by mistake. In general it should be "detection" or "classification", and in this specific case "detection". Thanks to Krish for noticing.

netanel-sam answered Oct 06 '22 01:10