我有一个关于通过tensorflow python api更新张量值的基本问题.
考虑代码段:
x = tf.placeholder(shape=(None,10), ... ) y = tf.placeholder(shape=(None,), ... ) W = tf.Variable( randn(10,10), dtype=tf.float32 ) yhat = tf.matmul(x, W)
现在让我们假设我想实现某种迭代更新W值的算法(例如一些优化算法).这将包括以下步骤:
for i in range(max_its): resid = y_hat - y W = f(W , resid) # some update
这里的问题是W
在LHS上是一个新的张量,而不是W
用于yhat = tf.matmul(x, W)
!也就是说,创建了一个新变量,并且W
我的"模型"中使用的值不会更新.
现在解决这个问题的方法就是这样
for i in range(max_its): resid = y_hat - y W = f(W , resid) # some update yhat = tf.matmul( x, W)
这导致为我的循环的每次迭代创建一个新的"模型"!
是否有更好的方法来实现这个(在python中)而不为循环的每次迭代创建一大堆新模型 - 而是更新原始张量W
"就地"可以这么说?
变量有一个assign方法.尝试:W.assign(f(W,resid))