
How to use trained BERT model checkpoints for prediction?

I fine-tuned BERT on SQuAD 2.0 using BERT-master/run_squad.py and got model.ckpt.data, model.ckpt.meta, and model.ckpt.index (F1 score: 81) in the output directory, along with predictions.json, etc. This is the command I ran:

python run_squad.py \
  --vocab_file=$BERT_LARGE_DIR/vocab.txt \
  --bert_config_file=$BERT_LARGE_DIR/bert_config.json \
  --init_checkpoint=$BERT_LARGE_DIR/bert_model.ckpt \
  --do_train=True \
  --train_file=$SQUAD_DIR/train-v2.0.json \
  --do_predict=True \
  --predict_file=$SQUAD_DIR/dev-v2.0.json \
  --train_batch_size=24 \
  --learning_rate=3e-5 \
  --num_train_epochs=2.0 \
  --max_seq_length=384 \
  --doc_stride=128 \
  --output_dir=gs://some_bucket/squad_large/ \
  --use_tpu=True \
  --tpu_name=$TPU_NAME \
  --version_2_with_negative=True

I then copied model.ckpt.meta, model.ckpt.index, and model.ckpt.data to the $BERT_LARGE_DIR directory and changed the run_squad.py flags as follows, so that it only predicts answers and does not train:

python run_squad.py \
  --vocab_file=$BERT_LARGE_DIR/vocab.txt \
  --bert_config_file=$BERT_LARGE_DIR/bert_config.json \
  --init_checkpoint=$BERT_LARGE_DIR/model.ckpt \
  --do_train=False \
  --train_file=$SQUAD_DIR/train-v2.0.json \
  --do_predict=True \
  --predict_file=$SQUAD_DIR/dev-v2.0.json \
  --train_batch_size=24 \
  --learning_rate=3e-5 \
  --num_train_epochs=2.0 \
  --max_seq_length=384 \
  --doc_stride=128 \
  --output_dir=gs://some_bucket/squad_large/ \
  --use_tpu=True \
  --tpu_name=$TPU_NAME \
  --version_2_with_negative=True

It throws an error saying that bucket directory/model.ckpt does not exist.

How can I use the checkpoints generated by training for prediction?

Jeeva Bharathi asked Oct 15 '22 13:10



1 Answer

Trained checkpoints are normally written to the directory given by the --output_dir parameter during training (gs://some_bucket/squad_large/ in your case). Each checkpoint name ends with a number, the training step at which it was saved, for example model.ckpt-12345. Find the checkpoint with the highest number, which is the last one saved, and pass it as --init_checkpoint in your evaluation/prediction run, using the output directory as its path. In your case that would be something like --init_checkpoint=gs://some_bucket/squad_large/model.ckpt-<highest number>. Give only this prefix, without the .index, .meta, or .data suffix.
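If it helps, picking the highest-numbered checkpoint can be sketched in a few lines of Python. This is only an illustration: the helper name latest_checkpoint_prefix and the file listing are hypothetical, and the step numbers are made up. With TensorFlow available, tf.train.latest_checkpoint(output_dir) should give the same kind of result by reading the "checkpoint" file that training writes into the output directory.

```python
import re

# Matches TensorFlow-style checkpoint file names such as
# "model.ckpt-12345.index" and captures the prefix and the step number.
_CKPT_RE = re.compile(
    r"^(?P<prefix>.*model\.ckpt-(?P<step>\d+))\.(index|meta|data-\d+-of-\d+)$"
)


def latest_checkpoint_prefix(filenames):
    """Return the checkpoint prefix with the highest step number, or None."""
    best_step, best_prefix = -1, None
    for name in filenames:
        match = _CKPT_RE.match(name)
        if match is None:
            continue  # not a checkpoint file (e.g. predictions.json)
        step = int(match.group("step"))  # compare numerically, not as text
        if step > best_step:
            best_step, best_prefix = step, match.group("prefix")
    return best_prefix


# Hypothetical listing of the output directory; step numbers are invented.
listing = [
    "gs://some_bucket/squad_large/model.ckpt-9000.index",
    "gs://some_bucket/squad_large/model.ckpt-9000.meta",
    "gs://some_bucket/squad_large/model.ckpt-10000.index",
    "gs://some_bucket/squad_large/model.ckpt-10000.data-00000-of-00001",
    "gs://some_bucket/squad_large/predictions.json",
]

print("--init_checkpoint=" + latest_checkpoint_prefix(listing))
```

The step is compared as an integer because a plain string sort would rank model.ckpt-9000 above model.ckpt-10000.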

Ashwin Geet D'Sa answered Nov 01 '22 16:11
