我正在通过TensorFlow的CIFAR-10示例开始CNN入门指南
现在在cifar10_train.py的火车功能中我们得到的图像为
images,labels = cifar10.distorted_inputs()
在distorted_inputs()
函数中,我们在队列中生成文件名,然后将单个记录读取为
# Create a queue that produces the filenames to read.
filename_queue = tf.train.string_input_producer(filenames)
# Read examples from files in the filename queue.
read_input = cifar10_input.read_cifar10(filename_queue)
reshaped_image = tf.cast(read_input.uint8image, tf.float32)
当我添加调试代码时,该read_input
变量只包含一条带有图像的记录及其高度,宽度和标签名称.
然后,该示例将一些失真应用于读取的图像/记录,然后将其传递给_generate_image_and_label_batch()
函数.
该函数然后返回形状的4D张量[batch_size, 32, 32, 3]
,其中batch_size = 128
.
tf.train.shuffle_batch()
返回批处理时,上述功能使用该功能.
我的问题是功能中额外的记录来自tf.train.shuffle_batch()
哪里?我们没有传递任何文件名或读者对象.
有人可以说明我们如何从1条记录转到128条记录吗?我查看了文档,但不明白.
该tf.train.shuffle_batch()
函数可用于生成包含一批输入的(一个或多个)张量.在内部,tf.train.shuffle_batch()
创建一个tf.RandomShuffleQueue
,它q.enqueue()
使用图像和标签张量调用,以排列单个元素(图像标签对).然后返回结果q.dequeue_many(batch_size)
,将batch_size
随机选择的元素(图像 - 标签对)连接成一批图像和一批标签.
请注意,虽然它从代码看起来像read_input
并filename_queue
具有功能关系,但还有一个额外的皱纹.简单地评估结果tf.train.shuffle_batch()
将永远阻止,因为没有元素添加到内部队列.为了简化这一点,当您调用时tf.train.shuffle_batch()
,TensorFlow将QueueRunner
在图形中添加一个内部集合.稍后调用tf.train.start_queue_runners()
(例如此处cifar10_train.py
)将启动一个向队列添加元素的线程,并使训练继续进行.该线程和队列HOWTO对如何工作的更多信息.