当前位置:  开发笔记 > 编程语言 > 正文

如何在tensorflow中使用tf.while_loop()进行可变长度输入?

如何解决《如何在tensorflow中使用tf.while_loop()进行可变长度输入?》经验,为你挑选了1个好方法。



1> Yaroslav Bul..:

如果从所有变量中删除形状,它都有效:

import tensorflow as tf
import numpy as np

config = tf.ConfigProto(graph_options=tf.GraphOptions(
  optimizer_options=tf.OptimizerOptions(opt_level=tf.OptimizerOptions.L0)))
tf.reset_default_graph()
sess = tf.Session("", config=config)
#initial_m = tf.Variable(0.0, name='m')

#The code no longer works after I change shape=(4) to shape=(None)
inputs = tf.placeholder(dtype='float32', shape=(None)) 
time_steps = tf.shape(inputs)[0]
initial_outputs = tf.TensorArray(dtype=tf.float32, size=time_steps)
initial_t = tf.placeholder(dtype='int32')
initial_m = tf.placeholder(dtype=tf.float32)

def should_continue(t, *args):
    return t < time_steps

def iteration(t, m, outputs_):
    cur = tf.gather(inputs, t)
    m  = m * 0.5 + cur * 0.5
    outputs_ = outputs_.write(t, m)
    return t + 1, m, outputs_

t, m, outputs = tf.while_loop(should_continue, iteration,
                              [initial_t, initial_m, initial_outputs])

outputs = outputs.stack()
init = tf.global_variables_initializer()
sess.run([init])
print(sess.run([outputs],
               feed_dict={inputs: np.asarray([1, 1, 1, 1]), initial_t: 0,
                          initial_m: 0.}))

推荐阅读
放ch养奶牛
这个屌丝很懒,什么也没留下!
DevBox开发工具箱 | 专业的在线开发工具网站    京公网安备 11010802040832号  |  京ICP备19059560号-6
Copyright © 1998 - 2020 DevBox.CN. All Rights Reserved devBox.cn 开发工具箱 版权所有