 

Trouble with PyTorch's 'ToPILImage'

Why does this not work?

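import numpy as np
import matplotlib.pyplot as plt
import torchvision.transforms.functional as tf
from torchvision import transforms
pic = np.random.randint(0, 255+1, size=28*28).reshape(28, 28)
pic = pic.astype(int)  # on Linux, Python int becomes NumPy int64
plt.imshow(pic)
t = transforms.ToPILImage()
t(pic.reshape(28, 28, 1))  # TypeError: Input type int64 is not supported
# tf.to_pil_image(pic.reshape(28, 28, 1))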

Matplotlib plots a nice random picture, but no matter which datatype I choose for my NumPy ndarray, neither to_pil_image nor ToPILImage works as expected.

The docs have this to say:

Converts a tensor ... or a numpy ndarray of shape H x W x C to a PIL Image while preserving the value range. ... If the input has 1 channel, the mode is determined by the data type (i.e. int, float, short).

None of these datatypes work except for "short".

Everything else results in:

TypeError: Input type int64/float64 is not supported

thrown from torchvision/transforms/functional.py in to_pil_image().

Further, even though the short datatype works for the standalone snippet above, it breaks down when used inside a transforms.Compose() called from a Dataset object's __getitem__:

from torch.utils.data import DataLoader
from torchvision import transforms

choices = transforms.RandomChoice([transforms.RandomAffine(30),
                                   transforms.RandomPerspective()])

transform = transforms.Compose([transforms.ToPILImage(),
                                transforms.RandomApply([choices], 0.5),
                                transforms.ToTensor(),
                                transforms.Normalize((0.5,), (0.5,))])

# MNIST is presumably the poster's own Dataset class (not shown), not torchvision.datasets.MNIST
trainset = MNIST('data/train.csv', transform=transform)
trainloader = DataLoader(trainset, batch_size=32, shuffle=True, num_workers=4)


RuntimeError: DataLoader worker (pid 12917) is killed by signal: Floating point exception.
RuntimeError: DataLoader worker (pid(s) 12917) exited unexpectedly
asked Jun 28 '20 02:06 by rocksNwaves




1 Answer

The above answer worked for me only with the following change:

pic = pic.astype('uint8')

Hope it works for you.

answered Oct 31 '22 07:10 by Neela