我正在使用Open Images Dataset上的TensorFlow对象检测API对SSD对象检测器进行微调。我的训练数据包含不平衡的课程,例如
顶部(5K图像)
连衣裙(50K图像)
等等...
我想为分类损失增加类权重以提高性能。我怎么做?配置文件的以下部分似乎相关:
loss { classification_loss { weighted_sigmoid { } } localization_loss { weighted_smooth_l1 { } } ... classification_weight: 1.0 localization_weight: 1.0 }
如何更改配置文件以添加每个类的分类损失权重?如果不通过配置文件,建议采取哪种方式?
API期望直接在注释文件中为每个对象(bbox)分配权重。由于这一要求,使用类权重的解决方案似乎是:
1)如果您有自定义数据集,则可以修改每个对象(bbox)的注释,以将权重字段包括为“对象/权重”。
2)如果您不想修改注释,则可以仅重新创建tf_records文件,以包括bbox的权重。
3)修改API的代码(对我来说似乎很棘手)
我决定去#2,所以我在这里放了代码,为具有权重(1.0,0.1)的两个类(“ top”,“ dress”)的自定义数据集生成了此类加权 tf记录文件,并给出了xml注释文件夹如:
import os import io import glob import hashlib import pandas as pd import xml.etree.ElementTree as ET import tensorflow as tf import random from PIL import Image from object_detection.utils import dataset_util # Define the class names and their weight class_names = ['top', 'dress', ...] class_weights = [1.0, 0.1, ...] def create_example(xml_file): tree = ET.parse(xml_file) root = tree.getroot() image_name = root.find('filename').text image_path = root.find('path').text file_name = image_name.encode('utf8') size=root.find('size') width = int(size[0].text) height = int(size[1].text) xmin = [] ymin = [] xmax = [] ymax = [] classes = [] classes_text = [] truncated = [] poses = [] difficult_obj = [] weights = [] # Important line for member in root.findall('object'): xmin.append(float(member[4][0].text) / width) ymin.append(float(member[4][1].text) / height) xmax.append(float(member[4][2].text) / width) ymax.append(float(member[4][3].text) / height) difficult_obj.append(0) class_name = member[0].text class_id = class_names.index(class_name) weights.append(class_weights[class_id]) if class_name == 'top': classes_text.append('top'.encode('utf8')) classes.append(1) elif class_name == 'dress': classes_text.append('dress'.encode('utf8')) classes.append(2) else: print('E: class not recognized!') truncated.append(0) poses.append('Unspecified'.encode('utf8')) full_path = image_path with tf.gfile.GFile(full_path, 'rb') as fid: encoded_jpg = fid.read() encoded_jpg_io = io.BytesIO(encoded_jpg) image = Image.open(encoded_jpg_io) if image.format != 'JPEG': raise ValueError('Image format not JPEG') key = hashlib.sha256(encoded_jpg).hexdigest() #create TFRecord Example example = tf.train.Example(features=tf.train.Features(feature={ 'image/height': dataset_util.int64_feature(height), 'image/width': dataset_util.int64_feature(width), 'image/filename': dataset_util.bytes_feature(file_name), 'image/source_id': dataset_util.bytes_feature(file_name), 'image/key/sha256': dataset_util.bytes_feature(key.encode('utf8')), 'image/encoded': dataset_util.bytes_feature(encoded_jpg), 'image/format': dataset_util.bytes_feature('jpeg'.encode('utf8')), 'image/object/bbox/xmin': dataset_util.float_list_feature(xmin), 'image/object/bbox/xmax': dataset_util.float_list_feature(xmax), 'image/object/bbox/ymin': dataset_util.float_list_feature(ymin), 'image/object/bbox/ymax': dataset_util.float_list_feature(ymax), 'image/object/class/text': dataset_util.bytes_list_feature(classes_text), 'image/object/class/label': dataset_util.int64_list_feature(classes), 'image/object/difficult': dataset_util.int64_list_feature(difficult_obj), 'image/object/truncated': dataset_util.int64_list_feature(truncated), 'image/object/view': dataset_util.bytes_list_feature(poses), 'image/object/weight': dataset_util.float_list_feature(weights) # Important line })) return example def main(_): weighted_tf_records_output = 'name_of_records_file.record' # output file annotations_path = '/path/to/annotations/folder/*.xml' # input annotations writer_train = tf.python_io.TFRecordWriter(weighted_tf_records_output) filename_list=tf.train.match_filenames_once(annotations_path) init = (tf.global_variables_initializer(), tf.local_variables_initializer()) sess=tf.Session() sess.run(init) list = sess.run(filename_list) random.shuffle(list) for xml_file in list: print('-> Processing {}'.format(xml_file)) example = create_example(xml_file) writer_train.write(example.SerializeToString()) writer_train.close() print('-> Successfully converted dataset to TFRecord.') if __name__ == '__main__': tf.app.run()
如果您有其他类型的注释,则代码将非常相似,但不幸的是,此代码将无法工作。