当前位置:  开发笔记 > 编程语言 > 正文

使用数据集API在Tensorflow中滑动批处理窗口

如何解决《使用数据集API在Tensorflow中滑动批处理窗口》经验,为你挑选了1个好方法。



1> vijay m..:

可以使用sliding window批处理操作实现tf.data.Dataset:

例:

from tensorflow.contrib.data.python.ops import sliding

imgs = tf.constant(['img0','img1', 'img2','img3', 'img4','img5', 'img6', 'img7'])
labels = tf.constant([0, 0, 0, 1, 1, 1, 0, 0])

# create TensorFlow Dataset object
data = tf.data.Dataset.from_tensor_slices((imgs, labels))

# sliding window batch
window = 4
stride = 1
data = data.apply(sliding.sliding_window_batch(window, stride))

# create TensorFlow Iterator object
iterator =  tf.data.Iterator.from_structure(data.output_types,data.output_shapes)
next_element = iterator.get_next()

# create initialization ops 
init_op = iterator.make_initializer(data)

with tf.Session() as sess:
   # initialize the iterator on the data
   sess.run(init_op)
   while True:
      try:
         elem = sess.run(next_element)
         print(elem)
      except tf.errors.OutOfRangeError:
         print("End of dataset.")
         break

输出:

 (array([b'img0', b'img1', b'img2', b'img3'], dtype=object), array([0, 0, 0, 1], dtype=int32))
 (array([b'img1', b'img2', b'img3', b'img4'], dtype=object), array([0, 0, 1, 1], dtype=int32))
 (array([b'img2', b'img3', b'img4', b'img5'], dtype=object), array([0, 1, 1, 1], dtype=int32))
 (array([b'img3', b'img4', b'img5', b'img6'], dtype=object), array([1, 1, 1, 0], dtype=int32))
 (array([b'img4', b'img5', b'img6', b'img7'], dtype=object), array([1, 1, 0, 0], dtype=int32))


在tensorflow> = 1.12中,你应该使用`data.window(size = window,shift = 1,stride = stride).flat_map(lambda x:x.batch(window))`代替不推荐使用的`data.apply(滑动) .sliding_window_batch(window,stride))`
推荐阅读
手机用户2402852387
这个屌丝很懒,什么也没留下!
DevBox开发工具箱 | 专业的在线开发工具网站    京公网安备 11010802040832号  |  京ICP备19059560号-6
Copyright © 1998 - 2020 DevBox.CN. All Rights Reserved devBox.cn 开发工具箱 版权所有