I trained a model with fastai.tabular. Now I have a fitted learner. Ultimately, models are there to be applied to new data, not just fitted on a training set and evaluated on a test set. I tried several things, all of which resulted in errors or odd behaviour. Is there a way to apply a model trained with fastai to previously unavailable data? Or do I have to train the model again and again and feed new test data in? That does not seem likely.
df_test = pd.read_parquet('generated_test.parquet').head(100)
test_data = TabularList.from_df(df_test, cat_names=cat_names, cont_names=cont_names)
prediction = learn.predict(test_data)
KeyError: 'atomic_distance'
atomic_distance is the name of a column present in both the training and test data, and it is also contained in cont_names.
prediction = learn.get_preds(kaggle_test_data)
This does something, but it returns something weird:
[tensor([[136.0840],
         [ -2.0286],
         [ -2.0944],
         ...,
         [135.6165],
         [  2.7626],
         [  8.0316]]),
 tensor([ 84.8076, -11.2570, -11.2548,  ...,  81.0491,   0.8874,   4.1235])]
The documentation says:
Docstring: Return predictions and targets on ds_type dataset.
This is new, unlabeled data, so I don't know why the returned object should contain targets. Where are they coming from? The size does not make sense either: I am expecting something with 100 values.
I found a way by passing in the dataframe row by row:
prediction = [float(learn.predict(df_test.loc[i])[0].data) for i in df_test.index]
There is also the method predict_batch available, but it does not seem to accept dataframes. Are there better ways to do this?
I use:
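from fastai.tabular import *  # assumes fastai v1; brings in TabularList and DatasetType

# Build a databunch from the test dataframe, reusing the training column lists and procs.
# Note: label_from_df needs the dep_var column to be present in DF_TEST.
data_test = (TabularList.from_df(DF_TEST, path=path, cat_names=cat_names, cont_names=cont_vars, procs=procs)
             .split_none()
             .label_from_df(cols=dep_var))
data_test.valid = data_test.train  # expose the test rows as the validation set
data_test = data_test.databunch()
# Swap the learner's validation dataloader for the test one, then predict
learn.data.valid_dl = data_test.valid_dl
pred = learn.get_preds(ds_type=DatasetType.Valid)[0]  # [0] holds predictions, [1] holds targets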
Where DF_TEST is the test dataframe, dep_var is the dependent variable, and learn is your model.
To be honest, it works most of the time. Other times it gives a weird error, and then I have to iterate over each row to get the predictions.
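The "try the batch route, fall back to row by row" routine can be wrapped in a small helper. This is only a sketch: predict_rows and predict_with_fallback are hypothetical names, not part of fastai. It assumes that learn.predict(row) returns a tuple whose first element exposes the value through .data, as in the row-by-row workaround from the question.

```python
from typing import Any, Callable, Iterable, List


def predict_rows(learn: Any, rows: Iterable[Any]) -> List[float]:
    """Predict one row at a time (the slow but reliable workaround)."""
    # Assumption: learn.predict(row)[0].data holds the predicted value.
    return [float(learn.predict(row)[0].data) for row in rows]


def predict_with_fallback(
    batch_predict: Callable[[], Iterable[Any]],
    learn: Any,
    rows: Iterable[Any],
) -> List[float]:
    """Try the batch path first; fall back to per-row prediction on any error."""
    try:
        return [float(p) for p in batch_predict()]
    except Exception as err:  # broad on purpose: the failure mode is not known
        print(f"batch prediction failed ({err!r}); predicting row by row")
        return predict_rows(learn, rows)
```

With the names used above, batch_predict could be something like lambda: learn.get_preds(ds_type=DatasetType.Valid)[0], and rows could be (DF_TEST.loc[i] for i in DF_TEST.index). Catching every exception is a deliberate shortcut here, since the kind of error is not specified; narrow it once you know what actually fails.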