我想显示一个图像。它使用a加载ImageLoader
并存储在PyTorch中Tensor
。
当我尝试通过显示它时plt.imshow(image)
,我得到:
TypeError: Invalid dimensions for image data
该.shape
张量是:
torch.Size([3, 244, 244])
如何显示此PyTorch张量中包含的图像?
给定一个Tensor
代表图像的图像,请使用.permute()
:
plt.imshow( tensor_image.permute(1, 2, 0) )
注意:permute
不会复制或分配内存,也不会。 from_numpy()
如您所见,matplotlib
即使不转换为numpy
数组也可以正常工作。但是PyTorch张量(“图像张量”)是第一个通道,因此要与它们一起使用,matplotlib
您需要对其进行重塑:
码:
from scipy.misc import face import matplotlib.pyplot as plt import torch np_image = face() print(type(np_image), np_image.shape) tensor_image = torch.from_numpy(np_image) print(type(tensor_image), tensor_image.shape) # reshape to channel first: tensor_image = tensor_image.view(tensor_image.shape[2], tensor_image.shape[0], tensor_image.shape[1]) print(type(tensor_image), tensor_image.shape) # If you try to plot image with shape (C, H, W) # You will get TypeError: # plt.imshow(tensor_image) # So we need to reshape it to (H, W, C): tensor_image = tensor_image.view(tensor_image.shape[1], tensor_image.shape[2], tensor_image.shape[0]) print(type(tensor_image), tensor_image.shape) plt.imshow(tensor_image) plt.show()
输出:
(768, 1024, 3) torch.Size([768, 1024, 3]) torch.Size([3, 768, 1024]) torch.Size([768, 1024, 3])