How do I display a single image in PyTorch?
Problem Overview
I want to display a single image loaded using ImageLoader and stored in a PyTorch Tensor. When I try to display it via plt.imshow(image) I get:
TypeError: Invalid dimensions for image data
The .shape of the tensor is:
torch.Size([3, 244, 244])
How do I display a PyTorch tensor as an image?
Python Solutions
Solution 1 - Python
Given a Tensor representing the image, use .permute() to move the channels to the last dimension:
plt.imshow(tensor_image.permute(1, 2, 0))
Note: permute returns a view of the same data, so it does not copy or allocate memory; neither does from_numpy(), which shares memory with the source NumPy array.
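A minimal runnable sketch of this approach, assuming a random tensor as a stand-in for the loaded image (the random data, the requires_grad flag and the non-interactive Agg backend are illustrative assumptions, not part of the answer). It also shows the .detach().cpu() step that Matplotlib needs when the tensor tracks gradients or lives on a GPU:

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch also runs headless
import matplotlib.pyplot as plt
import torch

# Hypothetical stand-in for the loaded image: 3 channels, 244 x 244, values in [0, 1)
tensor_image = torch.rand(3, 244, 244, requires_grad=True)

# .detach() drops autograd tracking (needed here because requires_grad=True);
# .cpu() moves a GPU tensor to host memory (a no-op for this CPU tensor)
hwc = tensor_image.detach().cpu().permute(1, 2, 0)
print(hwc.shape)  # torch.Size([244, 244, 3])

# permute only changes strides, so the result still points at the original storage
print(hwc.data_ptr() == tensor_image.data_ptr())  # True

plt.imshow(hwc.numpy())
```

In an interactive session you would follow this with plt.show() instead of selecting the Agg backend.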
Solution 2 - Python
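Matplotlib accepts a CPU tensor directly, without converting it to a NumPy array first. However, PyTorch image tensors are channel-first (C, H, W) while Matplotlib expects channel-last (H, W, C), so the axes have to be permuted. Note that view() (or reshape()) is not a substitute: it only reinterprets the same memory with a new shape and does not move the channel data, so the intermediate "channel-first" tensor would be scrambled.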
Code:
try:
    from scipy.datasets import face  # SciPy >= 1.10 (face() needs the pooch package)
except ImportError:
    from scipy.misc import face  # older SciPy releases
import matplotlib.pyplot as plt
import torch
np_image = face()
print(type(np_image), np_image.shape)
tensor_image = torch.from_numpy(np_image)
print(type(tensor_image), tensor_image.shape)
# permute to channel first (C, H, W); view() would keep the memory order and scramble the pixels
tensor_image = tensor_image.permute(2, 0, 1)
print(type(tensor_image), tensor_image.shape)
# If you try to plot an image with shape (C, H, W)
# you will get a TypeError:
# plt.imshow(tensor_image)
# So we need to permute it back to (H, W, C):
tensor_image = tensor_image.permute(1, 2, 0)
print(type(tensor_image), tensor_image.shape)
plt.imshow(tensor_image)
plt.show()
Output:
<class 'numpy.ndarray'> (768, 1024, 3)
<class 'torch.Tensor'> torch.Size([768, 1024, 3])
<class 'torch.Tensor'> torch.Size([3, 768, 1024])
<class 'torch.Tensor'> torch.Size([768, 1024, 3])
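To see why view() cannot replace permute() here, a small sketch with a hypothetical 2 × 2 RGB image whose pixel values are just 0..11: after permute(2, 0, 1) the first slice really is the red channel, whereas reshaping to (3, 2, 2) just takes the first four numbers in memory, which mixes channels.

```python
import torch

# Hypothetical 2x2 RGB image in (H, W, C) layout; hwc[h, w, c] == h*6 + w*3 + c
hwc = torch.arange(12).reshape(2, 2, 3)

permuted = hwc.permute(2, 0, 1)  # true (C, H, W): permuted[0] is the red channel
viewed = hwc.reshape(3, 2, 2)    # same memory, new shape: channels get mixed

print(permuted[0].tolist())  # [[0, 3], [6, 9]]  (red value of each pixel)
print(viewed[0].tolist())    # [[0, 1], [2, 3]]  (R, G, B of pixel 0, then R of pixel 1)
```

The two tensors have the same shape, which is why the bug in a view()-based conversion is easy to miss: only the pixel contents reveal it.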
Solution 3 - Python
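Given the image is loaded as described and stored in the variable image:
import matplotlib.pyplot as plt
from torchvision import transforms
plt.imshow(transforms.ToPILImage()(image), interpolation="bicubic")
# Alternatively, open it in the default image viewer:
# transforms.ToPILImage()(image).show()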
Or, as Soumith suggested:
import numpy as np
import matplotlib.pyplot as plt

def show(img):
    npimg = img.numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)), interpolation='nearest')
Solution 4 - Python
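A complete example, given an image path img_path:
import matplotlib.pyplot as plt
from PIL import Image
from torchvision import transforms
image = Image.open(img_path)
plt.imshow(transforms.ToPILImage()(transforms.ToTensor()(image)), interpolation="bicubic")
Note that transforms.ToTensor and transforms.ToPILImage are classes: the first pair of parentheses creates a transform object and the second applies it to the image, which is why the bracketing looks funky.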
Solution 5 - Python
PyTorch modules processing image data expect tensors in the format C × H × W.[1]
Pillow and Matplotlib, by contrast, expect image arrays in the format H × W × C.[2]
You can easily convert tensors to and from this format with TorchVision's functional transforms (F.to_pil_image goes to a PIL image, F.to_tensor goes back):
import torchvision.transforms.functional as F
F.to_pil_image(image_tensor)
Or by directly permuting the axes:
image_tensor.permute(1, 2, 0)
[1] "PyTorch modules dealing with image data require tensors to be laid out as C × H × W: channels, height, and width, respectively."
[2] "Note how we have to use permute to change the order of the axes from C × H × W to H × W × C to match what Matplotlib expects."
- Deep Learning with PyTorch
Solution 7 - Python
I've written a simple function to visualize PyTorch tensors using Matplotlib.
import numpy as np
import matplotlib.pyplot as plt
import torch


def show(*imgs):
    '''
    input imgs can be single or multiple tensor(s), this function uses matplotlib to visualize.
    Single input example:
    show(x) gives the visualization of x, where x should be a torch.Tensor
        if x is a 4D tensor (like an image batch with the size of b(atch)*c(hannel)*h(eight)*w(idth)), this function splits x in the batch dimension, showing b subplots in total, where each subplot displays the first 3 channels (3*h*w) at most.
        if x is a 3D tensor, this function shows the first 3 channels at most (in RGB format)
        if x is a 2D tensor, it will be shown as a grayscale map
    Multiple input example:
    show(x,y,z) produces three windows, displaying x, y, z respectively, where x,y,z can be in any form described above.
    Call plt.show() afterwards to display the figures.
    '''
    img_idx = 0
    for img in imgs:
        img_idx += 1
        plt.figure(img_idx)
        if isinstance(img, torch.Tensor):
            img = img.detach().cpu()
            bz = 1  # batch size; only a 4D input can change it
            if img.dim() == 4:  # 4D tensor
                bz = img.shape[0]
                c = img.shape[1]
                if bz == 1 and c == 1:  # single grayscale image
                    img = img.squeeze()
                elif bz == 1 and c == 3:  # single RGB image
                    img = img.squeeze(0)
                    img = img.permute(1, 2, 0)
                elif bz == 1 and c > 3:  # multiple feature maps
                    img = img[:, 0:3, :, :]
                    img = img.permute(0, 2, 3, 1)
                    print('warning: more than 3 channels! only channels 0,1,2 are preserved!')
                elif bz > 1 and c == 1:  # multiple grayscale images
                    img = img.squeeze()
                elif bz > 1 and c == 3:  # multiple RGB images
                    img = img.permute(0, 2, 3, 1)
                elif bz > 1 and c > 3:  # multiple feature maps
                    img = img[:, 0:3, :, :]
                    img = img.permute(0, 2, 3, 1)
                    print('warning: more than 3 channels! only channels 0,1,2 are preserved!')
                else:
                    raise ValueError("unsupported type! " + str(img.size()))
            elif img.dim() == 3:  # 3D tensor
                c = img.shape[0]
                if c == 1:  # grayscale
                    img = img.squeeze()
                elif c == 3:  # RGB
                    img = img.permute(1, 2, 0)
                elif c > 3:  # feature maps: keep the first 3 channels, as the docstring says
                    img = img[0:3].permute(1, 2, 0)
                    print('warning: more than 3 channels! only channels 0,1,2 are preserved!')
                else:
                    raise ValueError("unsupported type! " + str(img.size()))
            elif img.dim() == 2:  # already a single grayscale image
                pass
            else:
                raise ValueError("unsupported type! " + str(img.size()))
            img = img.numpy()  # convert to numpy
            img = img.squeeze()
            if bz == 1:
                plt.imshow(img, cmap='gray')  # cmap is ignored for RGB images
                # plt.colorbar()
                # plt.show()
            else:
                rows = int(bz ** 0.5)
                cols = int(np.ceil(bz / rows))
                for idx in range(bz):
                    plt.subplot(rows, cols, idx + 1)
                    plt.imshow(img[idx], cmap='gray')
        else:
            raise TypeError("unsupported type: " + str(type(img)))