
fastai learner requirements and batch prediction

I previously trained a resnet34 model using the fastai library, and have the weights.h5 file saved. With the latest version of fastai, do I still need to have non-empty train and valid folders in order to import my learner and predict on the test set?

Also, I’m currently looping through every test image and using learn.predict_array, but is there a way to predict in batches on a test folder?

Example of what I’m currently doing just to load/predict:

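import os
from glob import glob
import numpy as np
from fastai.conv_learner import *  # fastai 0.7.x API (ConvLearner, ImageClassifierData)

PATH = '/path/to/model/'
test_path = os.path.join(PATH, 'test')  # assumed location of the test images
sz = 224
arch = resnet34
tfms = tfms_from_model(arch, sz, aug_tfms=transforms_side_on, max_zoom=1.1)
data = ImageClassifierData.from_paths(PATH, tfms=tfms, bs=64)
learn = ConvLearner.pretrained(arch, data, precompute=False)
learn.load('weights')  # assumes the saved file is PATH/models/weights.h5
imgs = sorted(glob(os.path.join(test_path, '*.jpg')))
preds = []
_, val_tfms = tfms_from_model(arch, sz)
for i in imgs:
    im = val_tfms(open_image(i))[None]  # add a batch dimension of size 1
    preds.append(learn.predict_array(im))  # one forward pass per image
preds = np.concatenate(preds)  # shape (n_images, n_classes), log-probabilities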

There must be a cleaner way to do this by now, no?

Asked by Austin, Nov 11 '18 21:11




2 Answers

In fastai v1, you can export a learner and load it again to predict on a test set without needing non-empty training and validation sets. To do that, use the Learner.export method and the load_learner function (both are defined in fastai.basic_train).

In your situation, you may have to load your learner the old way (with a train/valid dataset) once, export it, and from then on use load_learner to run predictions on your test set.
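A minimal sketch of that workflow, assuming fastai v1 (a 1.0.x release recent enough to have ImageList and the test argument of load_learner) and a learner that was itself built with fastai v1. The wrapper names export_learner and predict_test_folder are hypothetical; only Learner.export, load_learner, ImageList.from_folder and get_preds come from fastai.

```python
def export_learner(learn, fname="export.pkl"):
    """Step 1: run once, where the learner still has its train/valid data.

    Learner.export writes the model and what is needed to rebuild the
    data pipeline (not the training images themselves) to learn.path/fname.
    """
    learn.export(fname)


def predict_test_folder(model_dir, test_dir, fname="export.pkl"):
    """Step 2: reload without train/valid folders and predict a whole folder in batches."""
    # Imported here so that defining this helper does not require fastai.
    from fastai.vision import DatasetType, ImageList, load_learner

    # Attach every image found in test_dir as the learner's test set.
    learn = load_learner(model_dir, fname, test=ImageList.from_folder(test_dir))
    # get_preds runs batched inference and returns (predictions, targets);
    # for an unlabeled test set the targets are only placeholders.
    preds, _ = learn.get_preds(ds_type=DatasetType.Test)
    # Rows of preds follow the order of these file paths.
    return preds, learn.data.test_ds.x.items
```

Call export_learner(learn) once in the environment that has the training data; after that, predict_test_folder(learn.path, '/path/to/test') needs only the exported file and the test images.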

I'll leave a link to the documentation:


This should answer any follow-up questions.

Answered by Statistic Dean, Oct 13 '22 15:10

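from fastai.conv_learner import *  # fastai 0.7.x; PATH and tfms as defined in the question
data = ImageClassifierData.from_paths(PATH, tfms=tfms, bs=64, test_name='test')  # reads PATH/test
learn = ConvLearner.pretrained(resnet34, data, precompute=False)
learn.load('weights')  # loads PATH/models/weights.h5
log_preds = learn.predict(is_test=True)  # batched over the test set; returns log-probabilities
probs = np.exp(log_preds)  # rows follow data.test_ds.fnames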
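If you would rather keep the predict_array approach from the question, note that predict_array takes an (N, C, H, W) array (the question already adds the batch axis with [None]), so you can stack several preprocessed images and run one forward pass per batch. This is a minimal sketch: the helper names chunks and predict_in_batches are hypothetical, and load_fn and predict_fn stand in for lambda p: val_tfms(open_image(p)) and learn.predict_array.

```python
from typing import Callable, Iterator, List, Sequence

import numpy as np


def chunks(items: Sequence, bs: int) -> Iterator[Sequence]:
    """Yield consecutive slices of at most bs items."""
    for start in range(0, len(items), bs):
        yield items[start:start + bs]


def predict_in_batches(paths: Sequence[str],
                       load_fn: Callable[[str], np.ndarray],
                       predict_fn: Callable[[np.ndarray], np.ndarray],
                       bs: int = 64) -> np.ndarray:
    """Preprocess images one at a time, but call the model once per batch.

    load_fn:    path -> preprocessed array of shape (C, H, W)
    predict_fn: (N, C, H, W) batch -> (N, n_classes) predictions
    """
    if len(paths) == 0:
        raise ValueError("no images to predict")
    outputs: List[np.ndarray] = []
    for batch_paths in chunks(paths, bs):
        batch = np.stack([load_fn(p) for p in batch_paths])  # (N, C, H, W)
        outputs.append(predict_fn(batch))
    return np.concatenate(outputs)  # rows follow the order of paths
```

With a fastai 0.7 learner this would be called as predict_in_batches(imgs, lambda p: val_tfms(open_image(p)), learn.predict_array); the result holds the same log-probabilities as the per-image loop, with far fewer forward passes.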
Answered by Sunhwan Jo, Oct 13 '22 16:10