
Getting the classification labels for torchvision's pretrained networks

PyTorch's torchvision package provides pretrained neural networks for image classification. I've been using the following code to classify an image with AlexNet (note: some of this code is from this webpage):

from PIL import Image
import torch
from torchvision import transforms
from torchvision import models

# preprocessing pipeline: resize, center-crop, convert to tensor, normalize with ImageNet statistics
transform = transforms.Compose([
    transforms.Resize(224),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225])])

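# load the image, force 3 channels (handles grayscale/RGBA files), and add a batch dimension
img = Image.open('/path/to/image.jpg').convert('RGB')
img = transform(img)
img = torch.unsqueeze(img, 0)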

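# AlexNet pretrained on ImageNet (newer torchvision versions prefer the weights= argument)
alexnet = models.alexnet(pretrained=True)
alexnet.eval()  # inference mode: disables dropout
with torch.no_grad():  # gradients are not needed for inference
    out = alexnet(img)
# convert logits to percentages and take the five most likely classes
percents = torch.nn.functional.softmax(out, dim=1)[0] * 100
top5_vals, top5_inds = percents.topk(5)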

There are 1,000 classes in total, and the top5_inds variable gives me the indices of the top 5. But how do I get the associated labels (e.g. snail, basketball, banana)? I can't find any such list in PyTorch's documentation or on the alexnet variable.

asked Dec 11 '25 22:12 by Trevor


1 Answer

Torchvision's classification models are pretrained on the ImageNet dataset. Because of its size and breadth, ImageNet is the most commonly used dataset for pretraining and transfer learning. As you noted, it has 1000 classes, and the model's output indices follow that class order. The complete class list can be found by searching online, or you can refer to this listing on GitHub, which maps each class index to a human-readable label: https://gist.github.com/yrevar/942d3a0ac09ec9e5eb3a
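If you would rather not copy a list by hand, here is a minimal sketch of turning the indices into labels. It assumes torchvision 0.13 or newer, where the pretrained weight enums carry a `meta["categories"]` list in model-output order; older versions do not have this, and the GitHub listing above would be the fallback. The helper names `load_imagenet_categories` and `top_k_labels` are made up for illustration, and the three-item list in the demo is only a stand-in for the full 1000-entry list.

```python
def load_imagenet_categories():
    """Return the ImageNet class names in model-output order.

    Assumption: torchvision >= 0.13 is installed, where pretrained weight
    enums expose a `meta` dict that contains a "categories" list.
    """
    from torchvision.models import AlexNet_Weights
    return AlexNet_Weights.IMAGENET1K_V1.meta["categories"]


def top_k_labels(indices, scores, categories):
    """Pair each predicted class index with its label and its score.

    `indices` and `scores` may be plain Python sequences or 1-D tensors,
    such as top5_inds and top5_vals from the question.
    """
    return [(categories[int(i)], float(s)) for i, s in zip(indices, scores)]


if __name__ == "__main__":
    # Stand-in label list so the demo runs without torchvision installed.
    demo_categories = ["tench", "goldfish", "great white shark"]
    for label, pct in top_k_labels([2, 0], [91.5, 3.25], demo_categories):
        print(f"{label}: {pct:.2f}%")
```

With the variables from the question, the call would look like `top_k_labels(top5_inds, top5_vals, load_imagenet_categories())`.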

answered Dec 14 '25 00:12 by ccl


