 

One hot encoding a segmented image using pytorch

I have a segmented image as a tensor of size [1,1,256,256]. The image is a binary segmented image. I want to one-hot encode it to get an image of size [1,2,256,256]. I tried torch.nn.functional.one_hot(img, 2), but it gave me an image of size [1,256,256,2]. How do I get the desired tensor?
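A minimal sketch of the behaviour described, assuming a random binary mask named `mask` (hypothetical) in place of the real segmented image. `torch.nn.functional.one_hot` expects an int64 index tensor and appends the class dimension as the last axis, so a [1, 1, 256, 256] input yields [1, 1, 256, 256, 2]. A [1, 256, 256] input would give the [1, 256, 256, 2] shape mentioned above.

```python
import torch
import torch.nn.functional as F

# Hypothetical stand-in for the segmented image: a random binary mask.
mask = torch.randint(0, 2, (1, 1, 256, 256))  # randint returns int64 by default

# one_hot adds the class dimension as the LAST axis.
encoded = F.one_hot(mask, num_classes=2)
print(encoded.shape)  # torch.Size([1, 1, 256, 256, 2])

# A float mask (e.g. produced by image transforms) must be cast to int64 first,
# because one_hot only accepts integer index tensors of dtype int64.
float_mask = mask.float()
encoded_from_float = F.one_hot(float_mask.long(), num_classes=2)
```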

BBloggsbott asked Mar 21 '26 02:03


1 Answer

Try using transpose():

import torch
img_one_hot = torch.nn.functional.one_hot(img.long(), 2).transpose(1, 4).squeeze(-1)

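one_hot appends the class dimension last, so a [1, 1, 256, 256] input becomes a [1, 1, 256, 256, 2] tensor. transpose(1, 4) swaps dimensions 1 and 4 (zero-indexed), returning a tensor of shape [1, 2, 256, 256, 1], and squeeze(-1) removes the trailing singleton dimension, resulting in a [1, 2, 256, 256] tensor. Note that one_hot only accepts int64 (long) input and also returns int64, so cast with .float() if a loss function expects floating-point targets.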
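The following sketch checks this shape arithmetic on a hypothetical random mask and compares it with an equivalent formulation that drops the singleton channel first and then uses `permute` to move the class axis to position 1. The helper names `one_hot_transpose` and `one_hot_permute` are illustrative only.

```python
import torch
import torch.nn.functional as F

def one_hot_transpose(img: torch.Tensor, num_classes: int = 2) -> torch.Tensor:
    """Transpose approach: [N, 1, H, W] -> [N, 1, H, W, C] -> [N, C, H, W, 1] -> [N, C, H, W]."""
    return F.one_hot(img.long(), num_classes).transpose(1, 4).squeeze(-1)

def one_hot_permute(img: torch.Tensor, num_classes: int = 2) -> torch.Tensor:
    """Alternative: drop the singleton channel, encode, then move classes to dim 1."""
    encoded = F.one_hot(img.squeeze(1).long(), num_classes)  # [N, H, W, C]
    return encoded.permute(0, 3, 1, 2)                       # [N, C, H, W]

img = torch.randint(0, 2, (1, 1, 256, 256))  # hypothetical binary segmentation mask

a = one_hot_transpose(img)
b = one_hot_permute(img)
print(a.shape)  # torch.Size([1, 2, 256, 256])

# Both results hold the same values; channel c is 1 exactly where the mask equals c.
assert torch.equal(a, b)
assert torch.equal(a[:, 1:2], img)

# The result is int64 and non-contiguous (a strided view); cast and copy if needed.
a_float = a.float().contiguous()
```

Both variants return views rather than copies, which is why the final `.contiguous()` call is included for code that requires a contiguous float tensor.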

trsvchn answered Mar 22 '26 15:03


