当前位置:  开发笔记 > 编程语言 > 正文

Tensorflow:在cpu上的多个线程中加载数据

如何解决《Tensorflow:在cpu上的多个线程中加载数据》经验,为你挑选了1个好方法。

我有一个python类SceneGenerator,它有多个成员函数用于预处理和生成器函数generate_data().基本结构是这样的:

class SceneGenerator(object):
    def __init__(self):
       # some inits

    def generate_data(self):
        """
        Generator. Yield data X and labels y after some preprocessing
        """
        while True:
            # opening files, selecting data
            X,y = self.preprocess(some_params, filenames, ...)            

            yield X, y

我在keras model.fit_generator()函数中使用了类成员函数sceneGenerator.generate_data()来从磁盘读取数据,对其进行预处理并将其生成.在keras中,如果workers参数of model.fit_generator()设置为> 1 ,则在多个CPU线程上完成.

我现在想SceneGenerator在tensorflow中使用相同的类.我目前的做法是:

sceneGenerator = SceneGenerator(some_params)
for X, y in sceneGenerator.generate_data():

    feed_dict = {ops['data']: X,
                 ops['labels']: y,
                 ops['is_training_pl']: True
                 }
    summary, step, _, loss, prediction = sess.run([optimization_op, loss_op, pred_op],
                                                  feed_dict=feed_dict)

但是,这很慢并且不使用多个线程.我发现tf.data.Datasetapi有一些文档,但我没有实现这些方法.

编辑:请注意,我不使用图像,因此带有文件路径等的图像加载机制在此处不起作用.我SceneGenerator从hdf5文件加载数据.但不是完整的数据集,而是 - 取决于初始化参数 - 只有数据集的一部分.我希望保持生成器功能不变,并了解如何将此生成器直接用作tensorflow的输入并在CPU上的多个线程上运行.将数据从hdf5文件重写为csv不是一个好选择,因为它复制了大量数据.

编辑2 ::我认为类似的东西可以帮助:并行化tf.data.Dataset.from_generator



1> GPhilo..:

假设您正在使用最新的Tensorflow(在撰写本文时为1.4),则可以保留生成器并按以下方式使用tf.data.*API(我为线程号,预取缓冲区大小,批处理大小和输出数据类型选择了任意值) :

NUM_THREADS = 5
sceneGen = SceneGenerator()
dataset = tf.data.Dataset.from_generator(sceneGen.generate_data, output_types=(tf.float32, tf.int32))
dataset = dataset.map(lambda x,y : (x,y), num_parallel_calls=NUM_THREADS).prefetch(buffer_size=1000)
dataset = dataset.batch(42)
X, y = dataset.make_one_shot_iterator().get_next()

为了表明实际上是从生成器中提取的多个线程,我对类进行了如下修改:

import threading    
class SceneGenerator(object):
  def __init__(self):
    # some inits
    pass

  def generate_data(self):
    """
    Generator. Yield data X and labels y after some preprocessing
    """
    while True:
      # opening files, selecting data
      X,y = threading.get_ident(), 2 #self.preprocess(some_params, filenames, ...)            
      yield X, y

这样,创建一个Tensorflow会话并获取一批将显示获取数据的线程的线程ID。在我的电脑上,运行:

sess = tf.Session()
print(sess.run([X, y]))

版画

[array([  8460.,   8460.,   8460.,  15912.,  16200.,  16200.,   8460.,
         15912.,  16200.,   8460.,  15912.,  16200.,  16200.,   8460.,
         15912.,  15912.,   8460.,   8460.,   6552.,  15912.,  15912.,
          8460.,   8460.,  15912.,   9956.,  16200.,   9956.,  16200.,
         15912.,  15912.,   9956.,  16200.,  15912.,  16200.,  16200.,
         16200.,   6552.,  16200.,  16200.,   9956.,   6552.,   6552.], dtype=float32),
 array([2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
        2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2])]

注意:您可能想尝试删除map调用(我们仅使用它来拥有多个线程),并检查prefetch的缓冲区是否足以消除输入管道中的瓶颈(即使只有一个线程,通常也需要对输入进行预处理)比实际的图形执行速度更快,因此缓冲区足以使预处理尽可能快地进行。

推荐阅读
wangtao
这个屌丝很懒,什么也没留下!
DevBox开发工具箱 | 专业的在线开发工具网站    京公网安备 11010802040832号  |  京ICP备19059560号-6
Copyright © 1998 - 2020 DevBox.CN. All Rights Reserved devBox.cn 开发工具箱 版权所有