I fine-tuned BERT on SQuAD 2.0 using BERT-master/run_squad.py and got model.ckpt.data, model.ckpt.meta, and model.ckpt.index (F1 score: 81) in the output directory, along with predictions.json, etc. The command was:
python run_squad.py \
--vocab_file=$BERT_LARGE_DIR/vocab.txt \
--bert_config_file=$BERT_LARGE_DIR/bert_config.json \
--init_checkpoint=$BERT_LARGE_DIR/bert_model.ckpt \
--do_train=True \
--train_file=$SQUAD_DIR/train-v2.0.json \
--do_predict=True \
--predict_file=$SQUAD_DIR/dev-v2.0.json \
--train_batch_size=24 \
--learning_rate=3e-5 \
--num_train_epochs=2.0 \
--max_seq_length=384 \
--doc_stride=128 \
--output_dir=gs://some_bucket/squad_large/ \
--use_tpu=True \
--tpu_name=$TPU_NAME \
--version_2_with_negative=True
I then copied model.ckpt.meta, model.ckpt.index, and model.ckpt.data to the $BERT_LARGE_DIR directory and changed the run_squad.py flags as follows, so that it only predicts answers and does not train on a dataset:
python run_squad.py \
--vocab_file=$BERT_LARGE_DIR/vocab.txt \
--bert_config_file=$BERT_LARGE_DIR/bert_config.json \
--init_checkpoint=$BERT_LARGE_DIR/model.ckpt \
--do_train=False \
--train_file=$SQUAD_DIR/train-v2.0.json \
--do_predict=True \
--predict_file=$SQUAD_DIR/dev-v2.0.json \
--train_batch_size=24 \
--learning_rate=3e-5 \
--num_train_epochs=2.0 \
--max_seq_length=384 \
--doc_stride=128 \
--output_dir=gs://some_bucket/squad_large/ \
--use_tpu=True \
--tpu_name=$TPU_NAME \
--version_2_with_negative=True
It throws a "bucket directory/model.ckpt does not exist" error.
How can I use the checkpoints generated during training for prediction?
Usually, the trained checkpoints are written to the directory specified by the --output_dir parameter during training (gs://some_bucket/squad_large/ in your case). Every checkpoint has a number in its name. Identify the one with the highest number, for example model.ckpt-12345. Then set the --init_checkpoint parameter of your evaluation/prediction run to the output directory plus that last saved checkpoint (the one with the highest number). In your case, it should be something like --init_checkpoint=gs://some_bucket/squad_large/model.ckpt-<highest number>
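Picking the highest-numbered checkpoint by eye is easy to get wrong, since a plain alphabetical sort puts model.ckpt-999 after model.ckpt-10000. Here is a minimal sketch of doing it in Python. It assumes you already have a listing of the file names in the output directory (for example, pasted from your bucket browser). The function name latest_checkpoint_prefix and the step numbers in the sample listing are made up for illustration; they are not part of run_squad.py.

```python
import re

# Matches TensorFlow 1.x checkpoint files such as
#   model.ckpt-5474.index, model.ckpt-5474.meta,
#   model.ckpt-5474.data-00000-of-00001
# and captures the shared prefix ("model.ckpt-5474") and the step number.
_CKPT_FILE = re.compile(r"^(model\.ckpt-(\d+))\.(?:index|meta|data-\d+-of-\d+)$")


def latest_checkpoint_prefix(filenames, output_dir):
    """Return output_dir joined with the highest-numbered checkpoint prefix.

    Returns None if no numbered checkpoint file is present in filenames.
    This helper is hypothetical; it only parses names and touches no storage.
    """
    best_step = -1
    best_prefix = None
    for name in filenames:
        match = _CKPT_FILE.match(name)
        if match is None:
            continue  # e.g. "checkpoint", "predictions.json", "graph.pbtxt"
        step = int(match.group(2))  # compare numerically, not alphabetically
        if step > best_step:
            best_step = step
            best_prefix = match.group(1)
    if best_prefix is None:
        return None
    return output_dir.rstrip("/") + "/" + best_prefix


if __name__ == "__main__":
    # Hypothetical listing of the output directory; your step numbers will differ.
    listing = [
        "checkpoint",
        "model.ckpt-5000.index",
        "model.ckpt-5000.meta",
        "model.ckpt-5000.data-00000-of-00001",
        "model.ckpt-5474.index",
        "model.ckpt-5474.meta",
        "model.ckpt-5474.data-00000-of-00001",
        "predictions.json",
    ]
    prefix = latest_checkpoint_prefix(listing, "gs://some_bucket/squad_large/")
    print("--init_checkpoint=" + str(prefix))
```

Note that the value passed to --init_checkpoint is the prefix without the .index, .meta, or .data-... suffix. If TensorFlow is installed and can read the bucket, tf.train.latest_checkpoint on the output directory should give you the same prefix, as it reads the "checkpoint" file that the training run wrote there.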