Output prediction of a PyTorch Lightning model

This is probably a very easy question. I just started with PyTorch Lightning and can't figure out how to get the output of my model after training.

I am interested in the predictions for both y_train and y_test as an array of some sort (a PyTorch tensor, or a NumPy array in a later step) so that I can plot them next to the labels using different scripts.

dataset = Dataset(train_tensor)
val_dataset = Dataset(val_tensor)
training_generator = torch.utils.data.DataLoader(dataset, **train_params)
val_generator = torch.utils.data.DataLoader(val_dataset, **val_params)
mynet = Net(feature_len)
trainer = pl.Trainer(gpus=0, max_epochs=max_epochs, logger=logger, progress_bar_refresh_rate=20, callbacks=[early_stop_callback], num_sanity_val_steps=0)
trainer.fit(mynet)

In my lightning module I have the functions:

def __init__(self, random_inputs):

def forward(self, x):

def train_dataloader(self):
    
def val_dataloader(self):

def training_step(self, batch, batch_nb):

def training_epoch_end(self, outputs):

def validation_step(self, batch, batch_nb):

def validation_epoch_end(self, outputs):

def configure_optimizers(self):

Do I need a specific predict function or is there any already implemented way I don't see?

Tom S asked Jan 25 '23 12:01



2 Answers

I disagree with the other answers: the OP's question appears to be about how to use a model trained in Lightning to get predictions in general, rather than at a specific step in the training pipeline. In that case, a user shouldn't need to go anywhere near a Trainer object. Trainers are not intended for general prediction, so those answers encourage an anti-pattern (carrying a trainer object around every time we want to make a prediction) for anyone who reads them in the future.

Instead of using the trainer, we can get predictions straight from the Lightning module itself: given a (trained) instance of the Lightning module, model = Net(...), predictions on inputs x are obtained simply by calling model(x), so long as the forward method has been implemented/overridden on the Lightning module (which is required for this to work). For inference, call model.eval() first and wrap the call in torch.no_grad(), so that layers such as dropout and batch norm use their inference behaviour and no gradients are tracked.
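To make that concrete, here is a minimal sketch of collecting predictions over a whole DataLoader by calling the model directly. It uses a plain torch.nn.Module as a stand-in for the OP's Net (a LightningModule is called the same way), and the layer, the random data and the predict_all helper are all made up for illustration rather than taken from the question.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset


class Net(nn.Module):
    """Hypothetical stand-in for the trained module from the question."""

    def __init__(self, feature_len: int):
        super().__init__()
        self.linear = nn.Linear(feature_len, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear(x)


def predict_all(model: nn.Module, loader: DataLoader) -> torch.Tensor:
    """Run the model over every batch and return one tensor of predictions."""
    model.eval()  # switch dropout / batch norm layers to inference behaviour
    outputs = []
    with torch.no_grad():  # gradients are not needed for prediction
        for (x,) in loader:
            outputs.append(model(x))  # model(x), not model.forward(x)
    return torch.cat(outputs, dim=0)


feature_len = 8
model = Net(feature_len)  # in practice: the instance that was passed to trainer.fit
x_train = torch.randn(100, feature_len)  # made-up inputs for illustration
loader = DataLoader(TensorDataset(x_train), batch_size=32, shuffle=False)

y_pred = predict_all(model, loader)  # PyTorch tensor, one row per sample
y_pred_np = y_pred.cpu().numpy()     # NumPy array for plotting elsewhere
```

Keeping shuffle=False matters here: the predictions then come back in the same order as the labels, so they can be plotted side by side. The same helper can be pointed at a loader over the test data.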

In contrast, Trainer.predict() is not, in my view, the intended means of obtaining predictions from your trained model in general. The Trainer API provides methods to tune, fit and test your LightningModule as part of your training pipeline, and it looks to me as though the predict method is provided for ad-hoc predictions on separate dataloaders as part of less 'standard' training steps.

The OP's question (Do I need a specific predict function or is there any already implemented way I don't see?) implies that they're not familiar with the way that the forward() method works in PyTorch, but asks whether there's already a method for prediction that they can't see. A full answer therefore requires a further explanation of where the forward() method fits into the prediction process:

model(x) works because Lightning modules are subclasses of torch.nn.Module, which implements a magic method called __call__(). This lets us call the class instance as if it were a function. __call__() in turn calls forward(), which is why we need to override that method in our Lightning module.

NB: because forward() is only one piece of the logic that runs when we call model(x) (__call__() also runs any registered hooks), it is recommended to use model(x) rather than model.forward(x) for prediction unless you have a specific reason to deviate.
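The dispatch described above can be illustrated with a toy class. This is a deliberately simplified stand-in and not PyTorch's actual implementation: TinyModule and Doubler are hypothetical names, and only the forward-hook part of what torch.nn.Module does in __call__() is mimicked.

```python
class TinyModule:
    """Simplified stand-in for torch.nn.Module's call dispatch (not the real code)."""

    def __init__(self):
        self._forward_hooks = []

    def register_forward_hook(self, hook):
        self._forward_hooks.append(hook)

    def __call__(self, *args, **kwargs):
        # Calling the instance runs forward() and then any registered hooks.
        result = self.forward(*args, **kwargs)
        for hook in self._forward_hooks:
            hook_result = hook(self, args, result)
            if hook_result is not None:  # a hook may replace the output
                result = hook_result
        return result

    def forward(self, *args, **kwargs):
        raise NotImplementedError("subclasses must override forward()")


class Doubler(TinyModule):
    def forward(self, x):
        return 2 * x


seen = []
model = Doubler()
model.register_forward_hook(lambda module, inputs, output: seen.append(output))

via_call = model(3)             # runs forward() and then the hook
via_forward = model.forward(3)  # runs forward() only; the hook is skipped
```

Both calls return the same value here, but only model(3) triggers the hook, so seen ends up with a single entry. That difference is why calling the instance is the safer default.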

UpstatePedro answered Jan 27 '23 02:01



You can use the predict method as well. Here is the example from the documentation: https://pytorch-lightning.readthedocs.io/en/latest/starter/introduction_guide.html

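from typing import Optional

from pytorch_lightning import LightningModule, Trainer


class LitMNISTDreamer(LightningModule):
    # self.decoder is assumed to be defined in __init__ (omitted in this excerpt)

    def forward(self, z):
        imgs = self.decoder(z)
        return imgs

    def predict_step(self, batch, batch_idx: int, dataloader_idx: Optional[int] = None):
        # Reuse forward() (via __call__) for each prediction batch
        return self(batch)


model = LitMNISTDreamer()
trainer = Trainer()
# datamodule is assumed to be a LightningDataModule defined elsewhere
predictions = trainer.predict(model, datamodule=datamodule)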
sushmit answered Jan 27 '23 03:01
