当前位置:  开发笔记 > 人工智能 > 正文

TensorFlow:当批次完成培训时,tf.train.batch会自动加载下一批吗?

如何解决《TensorFlow:当批次完成培训时,tf.train.batch会自动加载下一批吗?》经验,为你挑选了1个好方法。

例如,在我创建操作后,通过操作提供批处理数据并运行操作,tf.train.batch是否自动将另一批数据输入到会话中?

我问这个是因为tf.train.batch的属性allow_smaller_final_batch使得最终批次的加载大小小于指定的批量大小.这是否意味着即使没有循环,下一批可以自动进给?从教程代码我很困惑.当我加载一个批处理时,我实际上只有一个批量大小的形状[batch_size,height,width,num_channels],但是文档说它Creates batches of tensors in tensors.也是,当我在tf-slim演练教程中阅读教程代码时,一个名为load_batch的函数,只返回了3个张量:images, images_raw, labels.如文档中所述,"批量"数据在哪里?

谢谢您的帮助.



1> bodokaiser..:

... tf.train.batch会自动将另一批数据输入会话吗?

不会.没有什 您必须sess.run(...)再次呼叫才能加载新批次.

这是否意味着即使没有循环,下一批可以自动进给?

tf.train.batch(..)将永远载入batch_size张量.如果您有100个图像,batch_size=30那么您将有3*30个批次,因为您可以sess.run(batch)在输入队列从头开始(或停止if epoch=1)之前调用三次.这意味着您错过了100-3*30=10培训样本.如果你不想错过它们,你可以这样做tf.train.batch(..., allow_smaller_final_batch=True),现在你将有3x 30样本批次和1x 10样本批处理,然后输入队列将重新启动.

我还要详细说明代码示例:

queue = tf.train.string_input_producer(filenames,
        num_epochs=1) # only iterate through all samples in dataset once

reader = tf.TFRecordReader() # or any reader you need
_, example = reader.read(queue)

image, label = your_conversion_fn(example)

# batch will now load up to 100 image-label-pairs on sess.run(...)
# most tf ops are tuned to work on batches
# this is faster and also gives better result on e.g. gradient calculation
batch = tf.train.batch([image, label], batch_size=100)

with tf.Session() as sess:
    # "boilerplate" code
    sess.run([
        tf.local_variables_initializer(),
        tf.global_variables_initializer(),
    ])
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)

    try:
        # in most cases coord.should_stop() will return True
        # when there are no more samples to read
        # if num_epochs=0 then it will run for ever
        while not coord.should_stop():
            # will start reading, working data from input queue
            # and "fetch" the results of the computation graph
            # into raw_images and raw_labels
            raw_images, raw_labels = sess.run([images, labels])
    finally:
        coord.request_stop()
        coord.join(threads)

推荐阅读
殉情放开那只小兔子
这个屌丝很懒,什么也没留下!
DevBox开发工具箱 | 专业的在线开发工具网站    京公网安备 11010802040832号  |  京ICP备19059560号-6
Copyright © 1998 - 2020 DevBox.CN. All Rights Reserved devBox.cn 开发工具箱 版权所有