PyTorch's torchvision package provides pre-trained neural networks for image classification. I've been using the following code to classify an image with AlexNet (note: some of this code is from this webpage):
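```python
from PIL import Image
import torch
from torchvision import models, transforms

# Preprocessing matching the ImageNet evaluation pipeline
transform = transforms.Compose([
    transforms.Resize(256),          # resize the shorter side to 256 px
    transforms.CenterCrop(224),      # crop the central 224x224 region
    transforms.ToTensor(),           # HWC uint8 [0, 255] -> CHW float [0, 1]
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],  # ImageNet per-channel means
        std=[0.229, 0.224, 0.225])]) # ImageNet per-channel std devs

# Image: convert('RGB') also handles grayscale or RGBA files
img = Image.open('/path/to/image.jpg').convert('RGB')
img = transform(img)
img = torch.unsqueeze(img, 0)  # add a batch dimension: [1, 3, 224, 224]

# AlexNet with ImageNet weights (torchvision >= 0.13; older versions use pretrained=True)
alexnet = models.alexnet(weights=models.AlexNet_Weights.IMAGENET1K_V1)
alexnet.eval()
with torch.no_grad():  # inference only, no gradients needed
    out = alexnet(img)
percents = torch.nn.functional.softmax(out, dim=1)[0] * 100
top5_vals, top5_inds = percents.topk(5)
```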
There are 1,000 classes in total, and the top5_inds variable gives me the indices of the top 5. But how do I get the associated labels (e.g. snail, basketball, banana)? I can't find any such list in PyTorch's documentation or on the alexnet object.
Torchvision's pretrained classification models are trained on ImageNet, specifically the 1,000-class ILSVRC 2012 subset (ImageNet-1k). Because of its size and breadth, ImageNet is one of the most commonly used datasets for pretraining and transfer learning. As you noted, it has 1,000 classes. The complete class list can be searched online, or you can refer to this index-to-label listing on GitHub: https://gist.github.com/yrevar/942d3a0ac09ec9e5eb3a
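A minimal sketch of turning indices into labels with that listing, assuming the gist's file is written as a Python dict literal mapping each class index to its label string (e.g. {0: 'tench, Tinca tinca', ...}) and that its index order matches the pretrained weights. The function names parse_imagenet_labels and top_labels are hypothetical helpers, and the scores in the demo are illustrative, not real model output.

```python
import ast
from typing import Dict, List, Sequence, Tuple


def parse_imagenet_labels(text: str) -> Dict[int, str]:
    """Parse a listing written as a Python dict literal: {0: 'tench, ...', ...}."""
    # ast.literal_eval only evaluates literals, so unlike eval() it cannot run code
    labels = ast.literal_eval(text)
    if not isinstance(labels, dict):
        raise ValueError("expected a dict literal mapping class index -> label")
    return {int(k): str(v) for k, v in labels.items()}


def top_labels(indices: Sequence[int], scores: Sequence[float],
               labels: Dict[int, str]) -> List[Tuple[str, float]]:
    """Pair each predicted class index with its human-readable label and score."""
    return [(labels[int(i)], float(s)) for i, s in zip(indices, scores)]


# Shortened excerpt in the assumed file format (the real listing has 1,000 entries)
sample = "{0: 'tench, Tinca tinca',\n 1: 'goldfish, Carassius auratus'}"
labels = parse_imagenet_labels(sample)

# Illustrative indices and percentages standing in for top5_inds / top5_vals
for name, pct in top_labels([1, 0], [92.5, 3.1], labels):
    print(f"{name}: {pct:.1f}%")
```

With the question's variables, read the saved gist file's contents into parse_imagenet_labels and pass top5_inds.tolist() and top5_vals.tolist() to top_labels. On torchvision 0.13 or newer, class names also ship with the weights themselves: models.AlexNet_Weights.IMAGENET1K_V1.meta["categories"] is a list of 1,000 names indexed by class, so labels = that list and labels[i] for each i in top5_inds.tolist() avoids the external file.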