I'm trying to reload a DistilBertForSequenceClassification model I've fine-tuned and use it to classify some sentences into their appropriate labels (text classification).
In Google Colab, after successfully training the DistilBERT model, I saved it and then downloaded it:
trainer.train()
trainer.save_model("distilbert_classification")
The downloaded model has three files: config.json, pytorch_model.bin, training_args.bin.
I moved them into a folder named 'distilbert_classification' somewhere in my Google Drive.
Afterwards, I reloaded the model in a different Colab notebook:
reloadtrainer = DistilBertForSequenceClassification.from_pretrained('google drive directory/distilbert_classification')
Up to this point, I have succeeded without any errors.
However, how do I use this reloaded model (the 'reloadtrainer' object) to actually make predictions on sentences? What code do I need afterwards? I tried
reloadtrainer.predict("sample sentence") but it doesn't work. Would appreciate any help!
Remember that you also need to tokenize the input to your model, just like in the training phase. Merely feeding a raw sentence to the model will not work (unless you use pipeline(), but that's another discussion).
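(For reference, a pipeline-based version might look like the sketch below; the model directory and tokenizer name are placeholders for whatever you actually used.)

from transformers import pipeline

classifier = pipeline("text-classification",
                      model="path_to_model",
                      tokenizer="distilbert-base-uncased")
print(classifier("This is the text I am trying to classify."))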
You may use an AutoModelForSequenceClassification() and AutoTokenizer() to make things easier.
Note that the way I am saving the model is via model.save_pretrained("path_to_model") rather than model.save().
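For example, assuming you still have the model and tokenizer objects from training (the path is a placeholder), the save step could look like:

model.save_pretrained("path_to_model")
tokenizer.save_pretrained("path_to_model")  # saving the tokenizer alongside the model makes reloading easier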
One possible approach could be the following (say you trained with uncased distilbert):
from transformers import AutoModelForSequenceClassification, AutoTokenizer
import torch

# Load the fine-tuned model from the saved directory
model = AutoModelForSequenceClassification.from_pretrained("path_to_model")

# Replace with whatever tokenizer you used
tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased", use_fast=True)

input_text = "This is the text I am trying to classify."
tokenized_text = tokenizer(input_text,
                           truncation=True,
                           is_split_into_words=False,
                           return_tensors='pt')

# Run the model in inference mode and take the highest-scoring class
model.eval()
with torch.no_grad():
    outputs = model(**tokenized_text)
predicted_label = outputs.logits.argmax(-1)
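If you want the label name rather than the class index, you can look it up in the model config (assuming id2label was set during training; otherwise you will just get the default "LABEL_0", "LABEL_1", ... names):

predicted_label_name = model.config.id2label[predicted_label.item()]
print(predicted_label_name)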