当前位置:  开发笔记 > 编程语言 > 正文

在回归模型中使用Keras ImageDataGenerator

如何解决《在回归模型中使用KerasImageDataGenerator》经验,为你挑选了2个好方法。

我想用

flow_from_directory

的方法

ImageDataGenerator

生成回归模型的训练数据,其中目标值可以是介于1和-1之间的任何浮点值.

flow_from_directory

有一个带有descripton的"class_mode"参数

class_mode:"分类","二进制","稀疏"或"无"之一.默认值:"分类".确定返回的标签数组的类型:"分类"将是2D单热编码标签,"二进制"将是1D二进制标签,"稀疏"将是1D整数标签.

我应该选择以下哪些值?他们似乎都不适合......



1> Marcin Możej..:

目前(2017年1月21日发布的最新版本的Keras)flow_from_directory只能以以下方式工作:

    您需要以以下方式构造目录:

    directory with images\
        1st label\
            1st picture from 1st label
            2nd picture from 1st label
            3rd picture from 1st label
            ...
        2nd label\
            1st picture from 2nd label
            2nd picture from 2nd label
            3rd picture from 2nd label
            ...
        ...
    

    flow_from_directory以的格式返回固定大小的批次(picture, label)

因此,如您所见,它只能用于分类案例,并且文档中提供的所有选项仅指定将类提供给分类器的方式。但是,有一个简洁的技巧可以flow_from_directory对回归任务有用:

    您需要按照以下方式构建目录:

    directory with images\
        1st value (e.g. -0.95423)\
            1st picture from 1st value
            2nd picture from 1st value
            3rd picture from 1st value
            ...
        2nd value (e.g. - 0.9143242)\
            1st picture from 2nd value
            2nd picture from 2nd value
            3rd picture from 2nd value
            ...
       ...
    

    您还需要一个列表list_of_values = [1st value, 2nd value, ...]。然后按照以下方式定义生成器:

    def regression_flow_from_directory(flow_from_directory_gen, list_of_values):
        for x, y in flow_from_directory_gen:
            yield x, list_of_values[y]
    

而对于flow_from_directory_gena来说class_mode='sparse',使这项工作至关重要。当然,这有点麻烦,但它可以工作(我使用了此解决方案:))



2> 小智..:

在Keras 2.2.4中,您可以使用“ .flow_from_dataframe”解决您想做的事情,允许您从目录中流出图像以解决回归问题。您应该将所有图像存储在一个文件夹中,并加载一个数据帧,该数据帧的一列包含图像ID,另一列包含回归分数(标签),并在“ .flow_from_dataframe”中设置“ class_mode ='other'”。

在这里,您可以找到一个示例,其中图像位于“ image_dir”中,带有图像ID和回归分数的数据框加载有来自“训练文件”中的熊猫

train_label_df = pd.read_csv(train_file, delimiter=' ', header=None, names=['id', 'score'])

train_datagen = ImageDataGenerator(rescale = 1./255, horizontal_flip = True,
                                   fill_mode = "nearest", zoom_range = 0.2,
                                   width_shift_range = 0.2, height_shift_range=0.2,
                                   rotation_range=30) 

train_generator = train_datagen.flow_from_dataframe(dataframe=train_label_df, directory=image_dir, 
                                              x_col="id", y_col="score", has_ext=True, 
                                              class_mode="other", target_size=(img_width, img_height), 
                                              batch_size=bs)

推荐阅读
TXCWB_523
这个屌丝很懒,什么也没留下!
DevBox开发工具箱 | 专业的在线开发工具网站    京公网安备 11010802040832号  |  京ICP备19059560号-6
Copyright © 1998 - 2020 DevBox.CN. All Rights Reserved devBox.cn 开发工具箱 版权所有