当前位置:  开发笔记 > 编程语言 > 正文

转换不适用于数据集

如何解决《转换不适用于数据集》经验,为你挑选了1个好方法。

我是pytorch的新手,想了解一些东西。

我正在按以下方式加载MNIST:

transform_train = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Resize(size, interpolation=2),
     # transforms.Grayscale(num_output_channels=1),
     transforms.RandomHorizontalFlip(p=0.5),
     transforms.Normalize((mean), (std))])


trainset = torchvision.datasets.MNIST(root='./data', train=True,
                                      download=True, transform=transform_train)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                          shuffle=True, num_workers=2)

但是,当我探索数据集时,即trainloader.dataset.train_data[0],我得到的张量为[0,255],形状为(28,28)。

我想念什么?这是因为转换没有直接应用于数据加载器,而是仅在运行时?否则我该如何浏览我的数据?



1> iacolippo..:

调用的__getitem__方法时应用转换Dataset。例如,查看数据集类的__getitem__方法MNIST:https : //github.com/pytorch/vision/blob/master/torchvision/datasets/mnist.py#L62

def __getitem__(self, index):
    """
    Args:
        index (int): Index
    Returns:
        tuple: (image, target) where target is index of the target class.
    """
    img, target = self.data[index], self.targets[index]

    # doing this so that it is consistent with all other datasets
    # to return a PIL Image
    img = Image.fromarray(img.numpy(), mode='L')

    if self.transform is not None:
        img = self.transform(img)

    if self.target_transform is not None:
        target = self.target_transform(target)

    return img, target

__getitem__当您MNIST为训练集索引实例时,将调用该方法,例如:

trainset[0]

有关更多信息__getitem__:https : //docs.python.org/3.6/reference/datamodel.html#object。getitem

为什么原因ResizeRandomHorizontalFlip前应ToTensor是它们作用于PIL图像和所有数据集中在Pytorch一致性负载数据作为PIL Image第一秒。实际上,您可以在这里看到他们通过以下方式强制执行该行为:

img = Image.fromarray(img.numpy(), mode='L')

一旦你有PIL Image相应的指数,在变换被施加

if self.transform is not None:
    img = self.transform(img)

ToTensorPIL Imagea 转换为a torch.TensorNormalize减去平均值,然后除以您提供的标准差。

最终,一些转换将应用于

if self.target_transform is not None:
    target = self.target_transform(target)

最后,返回处理后的图像和处理后的标签。所有这些都在一个trainset[key]电话中发生。

import torch
from torchvision.transforms import *
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader

transform_train = Compose([Resize(28, interpolation=2),
                           RandomHorizontalFlip(p=0.5),
                           ToTensor(),
                           Normalize([0.], [1.])])

trainset = MNIST(root='./data', train=True, download=True,
                 transform=transform_train)
trainloader = DataLoader(trainset, batch_size=32, shuffle=True, num_workers=2)
print(trainset[0][0].size(), trainset[0][0].min(), trainset[0][0].max())

表演

(torch.Size([1, 28, 28]), tensor(0.), tensor(1.))

推荐阅读
雯颜哥_135
这个屌丝很懒,什么也没留下!
DevBox开发工具箱 | 专业的在线开发工具网站    京公网安备 11010802040832号  |  京ICP备19059560号-6
Copyright © 1998 - 2020 DevBox.CN. All Rights Reserved devBox.cn 开发工具箱 版权所有