当前位置:  开发笔记 > 编程语言 > 正文

如何在Tensorflow图中(在GPU上)保留计算值?

如何解决《如何在Tensorflow图中(在GPU上)保留计算值?》经验,为你挑选了1个好方法。

我们如何确保计算值不会被复制回CPU/python内存,但仍可用于下一步的计算?

以下代码显然不会这样做:

import tensorflow as tf

a = tf.Variable(tf.constant(1.),name="a")
b = tf.Variable(tf.constant(2.),name="b")
result = a + b
stored = result

with tf.Session() as s:
    val = s.run([result,stored],{a:1.,b:2.})
    print(val) # 3
    val=s.run([result],{a:4.,b:5.})
    print(val) # 9
    print(stored.eval()) # 3  NOPE:

错误:尝试使用未初始化的值_recv_b_0



1> Anona112..:

答案是tf.Variable通过使用assign操作将值存储到a中来存储:

工作代码:

import tensorflow as tf
with tf.Session() as s:
    a = tf.Variable(tf.constant(1.),name="a")
    b = tf.Variable(tf.constant(2.),name="b")
    result = a + b
    stored  = tf.Variable(tf.constant(0.),name="stored_sum")
    assign_op=stored.assign(result)
    val,_ = s.run([result,assign_op],{a:1.,b:2.})
    print(val) # 3
    val=s.run(result,{a:4.,b:5.})
    print(val[0]) # 9
    print(stored.eval()) # ok, still 3 

推荐阅读
360691894_8a5c48
这个屌丝很懒,什么也没留下!
DevBox开发工具箱 | 专业的在线开发工具网站    京公网安备 11010802040832号  |  京ICP备19059560号-6
Copyright © 1998 - 2020 DevBox.CN. All Rights Reserved devBox.cn 开发工具箱 版权所有