The official documentation only shows the following:
>>> from pytorch_lightning.metrics import ConfusionMatrix
>>> target = torch.tensor([1, 1, 0, 0])
>>> preds = torch.tensor([0, 1, 0, 0])
>>> confmat = ConfusionMatrix(num_classes=2)
>>> confmat(preds, target)
This doesn't show how to use the metric with the framework.
My attempt (methods are not complete and only show relevant parts):
# Constructor excerpt (the signature is elided in the post): creates one
# ConfusionMatrix metric whose class count is read from self._config.n_clusters.
def __init__(...):
self.val_confusion = pl.metrics.classification.ConfusionMatrix(num_classes=self._config.n_clusters)
# Validation step excerpt: runs the model on the batch, computes the loss,
# feeds the batch into the confusion-matrix metric via update(), then passes
# the metric object to self.log under 'validation_confusion_step'.
def validation_step(self, batch, batch_index):
# (elided in the post; orig_batch and label_batch are presumably unpacked from batch here)
...
log_probs = self.forward(orig_batch)
loss = self._criterion(log_probs, label_batch)
self.val_confusion.update(log_probs, label_batch)
# NOTE(review): the traceback quoted below ends in step-metric logging with
# `_forward_cache` being None; presumably this per-step logging of the metric
# object is what triggers it. Confirm against the Lightning version in use.
self.log('validation_confusion_step', self.val_confusion, on_step=True, on_epoch=False)
# Pass-through hook: returns the step outputs unchanged.
def validation_step_end(self, outputs):
return outputs
# Epoch-end hook: logs whatever the metric's compute() returns under
# 'validation_confusion_epoch'; `outs` is not used.
def validation_epoch_end(self, outs):
self.log('validation_confusion_epoch', self.val_confusion.compute())
After the 0th epoch, this gives
Traceback (most recent call last):
File "C:\code\EPMD\Kodex\Templates\Testing\venv\lib\site-packages\pytorch_lightning\trainer\trainer.py", line 521, in train
self.train_loop.run_training_epoch()
File "C:\code\EPMD\Kodex\Templates\Testing\venv\lib\site-packages\pytorch_lightning\trainer\training_loop.py", line 588, in run_training_epoch
self.trainer.run_evaluation(test_mode=False)
File "C:\code\EPMD\Kodex\Templates\Testing\venv\lib\site-packages\pytorch_lightning\trainer\trainer.py", line 613, in run_evaluation
self.evaluation_loop.log_evaluation_step_metrics(output, batch_idx)
File "C:\code\EPMD\Kodex\Templates\Testing\venv\lib\site-packages\pytorch_lightning\trainer\evaluation_loop.py", line 346, in log_evaluation_step_metrics
self.__log_result_step_metrics(step_log_metrics, step_pbar_metrics, batch_idx)
File "C:\code\EPMD\Kodex\Templates\Testing\venv\lib\site-packages\pytorch_lightning\trainer\evaluation_loop.py", line 350, in __log_result_step_metrics
cached_batch_pbar_metrics, cached_batch_log_metrics = cached_results.update_logger_connector()
File "C:\code\EPMD\Kodex\Templates\Testing\venv\lib\site-packages\pytorch_lightning\trainer\connectors\logger_connector\epoch_result_store.py", line 378, in update_logger_connector
batch_log_metrics = self.get_latest_batch_log_metrics()
File "C:\code\EPMD\Kodex\Templates\Testing\venv\lib\site-packages\pytorch_lightning\trainer\connectors\logger_connector\epoch_result_store.py", line 418, in get_latest_batch_log_metrics
batch_log_metrics = self.run_batch_from_func_name("get_batch_log_metrics")
File "C:\code\EPMD\Kodex\Templates\Testing\venv\lib\site-packages\pytorch_lightning\trainer\connectors\logger_connector\epoch_result_store.py", line 414, in run_batch_from_func_name
results = [func(include_forked_originals=False) for func in results]
File "C:\code\EPMD\Kodex\Templates\Testing\venv\lib\site-packages\pytorch_lightning\trainer\connectors\logger_connector\epoch_result_store.py", line 414, in <listcomp>
results = [func(include_forked_originals=False) for func in results]
File "C:\code\EPMD\Kodex\Templates\Testing\venv\lib\site-packages\pytorch_lightning\trainer\connectors\logger_connector\epoch_result_store.py", line 122, in get_batch_log_metrics
return self.run_latest_batch_metrics_with_func_name("get_batch_log_metrics",
*args, **kwargs)
File "C:\code\EPMD\Kodex\Templates\Testing\venv\lib\site-packages\pytorch_lightning\trainer\connectors\logger_connector\epoch_result_store.py", line 115, in run_latest_batch_metrics_with_func_name
for dl_idx in range(self.num_dataloaders)
File "C:\code\EPMD\Kodex\Templates\Testing\venv\lib\site-packages\pytorch_lightning\trainer\connectors\logger_connector\epoch_result_store.py", line 115, in <listcomp>
for dl_idx in range(self.num_dataloaders)
File "C:\code\EPMD\Kodex\Templates\Testing\venv\lib\site-packages\pytorch_lightning\trainer\connectors\logger_connector\epoch_result_store.py", line 100, in get_latest_from_func_name
results.update(func(*args, add_dataloader_idx=add_dataloader_idx, **kwargs))
File "C:\code\EPMD\Kodex\Templates\Testing\venv\lib\site-packages\pytorch_lightning\core\step_result.py", line 298, in get_batch_log_metrics
result[dl_key] = self[k]._forward_cache.detach()
AttributeError: 'NoneType' object has no attribute 'detach'
It does pass the sanity validation check before training.
The failure happens on the return in validation_step_end, which makes little sense to me.
The exact same way of using metrics works fine with accuracy.
How to get a correct confusion matrix?
You can report the figure using self.logger.experiment.add_figure(*tag*, *figure*).
The variable self.logger.experiment is actually a SummaryWriter (from PyTorch, not Lightning). This class has the method add_figure (documentation).
You can use it as follows: (MNIST example)
# Validation step (MNIST example): unpacks the batch into inputs and targets,
# runs the model, computes the NLL loss, and returns loss, predictions and
# targets in a dict so the epoch-end hook can aggregate them.
def validation_step(self, batch, batch_idx):
x, y = batch
preds = self(x)
loss = F.nll_loss(preds, y)
return { 'loss': loss, 'preds': preds, 'target': y}
def validation_epoch_end(self, outputs):
    """Log a confusion-matrix heatmap covering the whole validation epoch.

    ``outputs`` is the list of dicts returned by ``validation_step``; the
    ``'preds'`` and ``'target'`` tensors of all batches are concatenated
    before the matrix is computed. The heatmap is written through
    ``self.logger.experiment.add_figure`` with the current epoch as step.
    """
    num_classes = 10  # MNIST example: ten digit classes

    preds = torch.cat([tmp['preds'] for tmp in outputs])
    targets = torch.cat([tmp['target'] for tmp in outputs])
    confusion_matrix = pl.metrics.functional.confusion_matrix(preds, targets, num_classes=num_classes)

    # Tensor.numpy() only works for CPU tensors that do not require grad,
    # so detach and move to the CPU first (needed when validating on a GPU).
    cm_array = confusion_matrix.detach().cpu().numpy()
    df_cm = pd.DataFrame(cm_array, index=range(num_classes), columns=range(num_classes))

    plt.figure(figsize=(10, 7))
    fig_ = sns.heatmap(df_cm, annot=True, cmap='Spectral').get_figure()
    self.logger.experiment.add_figure("Confusion matrix", fig_, self.current_epoch)
    # Release the figure only after it has been handed to the writer, so
    # figures do not pile up across epochs.
    plt.close(fig_)
See a much better version
This took a lot of time to find.
This is the most minimal code I could paste that is still readable and reproducible.
I didn't want to put the entire model dataset and parameters here, as they are of no interest to readers of this question and are just noise.
That said, here is the code required to create a confusion matrix per epoch and display it in TensorBoard.
Here is a single frame as an example:
import pytorch_lightning as pl
import seaborn as sn
import pandas as pd
import numpy as np
import io
import matplotlib.pyplot as plt
from PIL import Image
def __init__(self, config, trained_vae, latent_dim):
    """Set up the classifier's configuration and its validation metric.

    Args:
        config: configuration object; ``n_clusters`` is read here, and other
            methods read further fields through ``self._config``.
        trained_vae: not used in this excerpt.
        latent_dim: not used in this excerpt.
    """
    # NOTE(review): the class header is not shown; the other methods use
    # self.logger / self.current_epoch, so this is presumably a
    # LightningModule. Its base initialiser must run before a metric can be
    # assigned as an attribute.
    super().__init__()
    # Every method reads self._config, so it has to be stored before use.
    self._config = config
    self.val_confusion = pl.metrics.classification.ConfusionMatrix(num_classes=self._config.n_clusters)
    # Assigned only to give `logger` a precise type for editors/type checkers.
    self.logger: Optional[TensorBoardLogger] = None
# Forward pass excerpt: the computation is elided in the post; the method
# returns `log_probs`.
def forward(self, x):
...
return log_probs
# Validation step: unpacks the batch, reshapes the inputs to rows of 28 * 28
# values, runs the model, computes the loss, accumulates the batch into the
# confusion-matrix metric, and returns the loss together with the labels.
def validation_step(self, batch, batch_index):
# NOTE(review): indentation was lost in this paste; presumably only the two
# lines that unpack and reshape the batch belong to this `if`. If so, any
# other dataset value leaves orig_batch / label_batch undefined. Confirm.
if self._config.dataset == "mnist":
orig_batch, label_batch = batch
orig_batch = orig_batch.reshape(-1, 28 * 28)
log_probs = self.forward(orig_batch)
loss = self._criterion(log_probs, label_batch)
# Only update() is called here; the matrix is read once per epoch in
# validation_epoch_end.
self.val_confusion.update(log_probs, label_batch)
return {"loss": loss, "labels": label_batch}
# Pass-through hook: returns the step outputs unchanged.
def validation_step_end(self, outputs):
return outputs
def validation_epoch_end(self, outs):
    """Render the epoch's confusion matrix as a heatmap and log it as an image.

    The matrix is taken from ``self.val_confusion``, which ``validation_step``
    updates batch by batch; ``outs`` is not used. The heatmap is rendered to
    an in-memory JPEG, converted to a tensor, and written with ``add_image``
    under the tag ``"val_confusion_matrix"`` at the current epoch.
    """
    tb = self.logger.experiment
    n_clusters = self._config.n_clusters

    # confusion matrix
    # The builtin `int` replaces `np.int`, which was only an alias for it and
    # no longer exists in current NumPy releases.
    conf_mat = self.val_confusion.compute().detach().cpu().numpy().astype(int)
    # Start the next epoch from an empty matrix. Depending on the metrics
    # version, compute() may already have reset the state; an explicit
    # reset is harmless either way.
    self.val_confusion.reset()

    df_cm = pd.DataFrame(
        conf_mat,
        index=np.arange(n_clusters),
        columns=np.arange(n_clusters))

    buf = io.BytesIO()
    fig = plt.figure()
    try:
        sn.set(font_scale=1.2)
        sn.heatmap(df_cm, annot=True, annot_kws={"size": 16}, fmt='d')
        fig.savefig(buf, format='jpeg')
    finally:
        # Without this, one figure per epoch stays registered with pyplot.
        plt.close(fig)

    buf.seek(0)
    im = Image.open(buf)
    # NOTE(review): torchvision is not among the imports shown with this
    # snippet; make sure it is imported at module level.
    im = torchvision.transforms.ToTensor()(im)
    tb.add_image("val_confusion_matrix", im, global_step=self.current_epoch)
and the call to trainer
# Create a TensorBoard logger named 'Classifier' under tb_logs_folder and pass
# it to the Trainer; the model's hooks reach it through self.logger.
logger = TensorBoardLogger(save_dir=tb_logs_folder, name='Classifier')
# Trainer settings used for this run: deterministic mode, 10 epochs, one GPU.
trainer = Trainer(
deterministic=True,
max_epochs=10,
default_root_dir=classifier_checkpoints_path,
logger=logger,
gpus=1
)
If you love us? You can donate to us via Paypal or buy me a coffee so we can maintain and grow! Thank you!
Donate Us With