我是pytorch的新手,想了解一些东西。
我正在按以下方式加载MNIST:
transform_train = transforms.Compose( [transforms.ToTensor(), transforms.Resize(size, interpolation=2), # transforms.Grayscale(num_output_channels=1), transforms.RandomHorizontalFlip(p=0.5), transforms.Normalize((mean), (std))]) trainset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform_train) trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=2)
但是,当我探索数据集时,即trainloader.dataset.train_data[0]
,我得到的张量为[0,255],形状为(28,28)。
我想念什么?这是因为转换没有直接应用于数据加载器,而是仅在运行时?否则我该如何浏览我的数据?
调用的__getitem__
方法时应用转换Dataset
。例如,查看数据集类的__getitem__
方法MNIST
:https : //github.com/pytorch/vision/blob/master/torchvision/datasets/mnist.py#L62
def __getitem__(self, index): """ Args: index (int): Index Returns: tuple: (image, target) where target is index of the target class. """ img, target = self.data[index], self.targets[index] # doing this so that it is consistent with all other datasets # to return a PIL Image img = Image.fromarray(img.numpy(), mode='L') if self.transform is not None: img = self.transform(img) if self.target_transform is not None: target = self.target_transform(target) return img, target
__getitem__
当您MNIST
为训练集索引实例时,将调用该方法,例如:
trainset[0]
有关更多信息__getitem__
:https : //docs.python.org/3.6/reference/datamodel.html#object。getitem
为什么原因Resize
和RandomHorizontalFlip
前应ToTensor
是它们作用于PIL图像和所有数据集中在Pytorch一致性负载数据作为PIL Image
第一秒。实际上,您可以在这里看到他们通过以下方式强制执行该行为:
img = Image.fromarray(img.numpy(), mode='L')
一旦你有PIL Image
相应的指数,在变换被施加
if self.transform is not None: img = self.transform(img)
ToTensor
将PIL Image
a 转换为a torch.Tensor
并Normalize
减去平均值,然后除以您提供的标准差。
最终,一些转换将应用于
if self.target_transform is not None: target = self.target_transform(target)
最后,返回处理后的图像和处理后的标签。所有这些都在一个trainset[key]
电话中发生。
import torch from torchvision.transforms import * from torchvision.datasets import MNIST from torch.utils.data import DataLoader transform_train = Compose([Resize(28, interpolation=2), RandomHorizontalFlip(p=0.5), ToTensor(), Normalize([0.], [1.])]) trainset = MNIST(root='./data', train=True, download=True, transform=transform_train) trainloader = DataLoader(trainset, batch_size=32, shuffle=True, num_workers=2) print(trainset[0][0].size(), trainset[0][0].min(), trainset[0][0].max())
表演
(torch.Size([1, 28, 28]), tensor(0.), tensor(1.))