How do I extract features from a torchvision VisionTransformer (ViT)?

I'd like to extract features from a pretrained VisionTransformer to use them for a downstream task. How do I do this, for example with a vit_b_16 from torchvision? The output should be a 768-dimensional feature vector for each image.

As is commonly done with CNNs, I tried to simply remove the output layer and pass the input through the remaining layers:

    from torch import nn

    from torchvision.models.vision_transformer import vit_b_16
    from torchvision.models import ViT_B_16_Weights
    
    from PIL import Image as PIL_Image

    vit = vit_b_16(weights=ViT_B_16_Weights.DEFAULT)
    modules = list(vit.children())[:-1]
    feature_extractor = nn.Sequential(*modules)

    preprocessing = ViT_B_16_Weights.DEFAULT.transforms()

    img = PIL_Image.open("example.png")
    img = preprocessing(img)

    feature_extractor(img)

However, this leads to an exception:

RuntimeError: The size of tensor a (14) must match the size of tensor b (768) at non-singleton dimension 2
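
A quick shape check shows where the mismatch seems to come from (module names taken from torchvision's VisionTransformer source): the model's forward pass reshapes the conv_proj output into a token sequence and prepends a class token before the encoder, which nn.Sequential(*children) skips, so the encoder's positional embedding cannot broadcast against the raw conv output:

    import torch
    from torchvision.models.vision_transformer import vit_b_16
    from torchvision.models import ViT_B_16_Weights

    vit = vit_b_16(weights=ViT_B_16_Weights.DEFAULT)
    x = torch.randn(1, 3, 224, 224)

    # Raw patch-embedding output, as nn.Sequential feeds it to the encoder
    print(vit.conv_proj(x).shape)           # torch.Size([1, 768, 14, 14])
    # What forward() actually builds before the encoder (196 patch tokens)
    print(vit._process_input(x).shape)      # torch.Size([1, 196, 768])
    # Positional embedding the encoder adds (196 patch tokens + 1 class token)
    print(vit.encoder.pos_embedding.shape)  # torch.Size([1, 197, 768])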
asked Oct 20 '25 by mitja


1 Answer

Looking at the forward function in the source code of VisionTransformer and this helpful forum post, I managed to extract the features in the following way:


    import torch

    from torchvision.models.vision_transformer import vit_b_16
    from torchvision.models import ViT_B_16_Weights

    from PIL import Image as PIL_Image

    vit = vit_b_16(weights=ViT_B_16_Weights.DEFAULT)

    preprocessing = ViT_B_16_Weights.DEFAULT.transforms()

    img = PIL_Image.open("example.png")
    img = preprocessing(img)

    # Add batch dimension
    img = img.unsqueeze(0)

    feats = vit._process_input(img)

    # Expand the CLS token to the full batch
    batch_class_token = vit.class_token.expand(img.shape[0], -1, -1)
    feats = torch.cat([batch_class_token, feats], dim=1)

    feats = vit.encoder(feats)

    # We're only interested in the representation of the CLS token that we appended at position 0
    feats = feats[:, 0]

    print(feats.shape)

This correctly returns:

torch.Size([1, 768])
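
An alternative that avoids calling the internal _process_input directly is to register a forward hook on vit.encoder, which captures the same token representations during a normal forward pass. This is only a minimal sketch, not part of the approach above; the random tensor stands in for a preprocessed image batch:

    import torch
    from torchvision.models.vision_transformer import vit_b_16
    from torchvision.models import ViT_B_16_Weights

    vit = vit_b_16(weights=ViT_B_16_Weights.DEFAULT).eval()

    captured = {}

    def save_encoder_output(module, inputs, output):
        # The encoder output has shape [batch, 197, 768]:
        # the class token followed by 14 x 14 = 196 patch tokens.
        captured["tokens"] = output

    handle = vit.encoder.register_forward_hook(save_encoder_output)

    with torch.no_grad():
        vit(torch.randn(1, 3, 224, 224))  # stand-in for a preprocessed image batch

    handle.remove()

    cls_feats = captured["tokens"][:, 0]  # torch.Size([1, 768])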

Edit: Depending on the downstream task, it might be better to average the features of all patch tokens instead of taking the CLS token representation, i.e. apply feats_avg = feats[:, 1:].mean(dim=1) to the encoder output before selecting position 0.
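
Putting the steps together, the extraction could be wrapped in a small helper that supports both pooling options. This is a rough sketch (extract_vit_features is a made-up name, not a torchvision API), using eval() and torch.no_grad() since the features are only read out, not trained:

    import torch
    from torchvision.models.vision_transformer import vit_b_16
    from torchvision.models import ViT_B_16_Weights

    def extract_vit_features(vit, images, pool="cls"):
        """Return per-image features from a torchvision VisionTransformer.

        images: preprocessed tensor of shape [batch, 3, 224, 224]
        pool:   "cls" for the class-token representation,
                "mean" to average the patch tokens instead.
        """
        feats = vit._process_input(images)
        batch_class_token = vit.class_token.expand(images.shape[0], -1, -1)
        feats = torch.cat([batch_class_token, feats], dim=1)
        feats = vit.encoder(feats)
        if pool == "cls":
            return feats[:, 0]             # [batch, 768]
        return feats[:, 1:].mean(dim=1)    # [batch, 768]

    vit = vit_b_16(weights=ViT_B_16_Weights.DEFAULT).eval()
    with torch.no_grad():
        out = extract_vit_features(vit, torch.randn(2, 3, 224, 224), pool="mean")
    print(out.shape)  # torch.Size([2, 768])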

answered Oct 23 '25 by mitja