当前位置:  开发笔记 > 编程语言 > 正文

Pytorch:图片标签

如何解决《Pytorch:图片标签》经验,为你挑选了1个好方法。

我正在使用31个类(Office数据集)进行图像分类。每个类都有一个文件夹。我有一个使用PyTorch编写的python脚本,该脚本使用加载数据集datasets.ImageFolder并为每个图像分配标签,然后进行训练。这是我用于加载数据的代码片段:

from torchvision import datasets, transforms
import torch

def load_training(root_path, dir, batch_size, kwargs):
    transform = transforms.Compose(
        [transforms.Resize([256, 256]),
         transforms.RandomCrop(224),
         transforms.RandomHorizontalFlip(),
         transforms.ToTensor()])
    data = datasets.ImageFolder(root=root_path + dir, transform=transform)
    train_loader = torch.utils.data.DataLoader(data, batch_size=batch_size, shuffle=True, drop_last=True, **kwargs)
    return train_loader

该代码将占用每个文件夹,并为该文件夹中的所有图像分配相同的标签。有什么方法可以找到将哪个标签分配给哪个图像/图像文件夹?



1> Jan..:

ImageFolder类具有一个属性class_to_idx,该属性是将类的名称映射到索引(标签)的字典。因此,您可以使用访问类,data.classes对于每个类,请使用获取标签data.class_to_idx

供参考:https : //github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py

推荐阅读
手机用户2402852387
这个屌丝很懒,什么也没留下!
DevBox开发工具箱 | 专业的在线开发工具网站    京公网安备 11010802040832号  |  京ICP备19059560号-6
Copyright © 1998 - 2020 DevBox.CN. All Rights Reserved devBox.cn 开发工具箱 版权所有