当前位置:  开发笔记 > 开发工具 > 正文

tf.Session()中的图参数有什么作用?

如何解决《tf.Session()中的图参数有什么作用?》经验,为你挑选了1个好方法。

我无法理解中的图论点tf.Session()。我尝试查找TensorFlow网站:链接,但了解得不多。

我试图找出tf.Session()和之间的区别tf.Session(graph=some_graph_inserted_here)

问题背景

代码A(无效):

def predict():
    with tf.name_scope("predict"):
        with tf.Session() as sess:
            saver = tf.train.import_meta_graph("saved_models/testing.meta")
            saver.restore(sess, "saved_models/testing")
            loaded_graph = tf.get_default_graph()
            output_ = loaded_graph.get_tensor_by_name('loss/network/output_layer/BiasAdd:0')
            _x = loaded_graph.get_tensor_by_name('x:0')
            print sess.run(output_, feed_dict={_x: np.array([12003]).reshape([-1, 1])})

此代码给出以下错误:ValueError: cannot add op with name hidden_layer1/kernel/Adam as that name is already used尝试在以下位置加载图形时saver = tf.train.import_meta_graph("saved_models/testing.meta")

代码B(工作中):

def predict():
    with tf.name_scope("predict"):
        loaded_graph = tf.Graph()
        with tf.Session(graph=loaded_graph) as sess:
            saver = tf.train.import_meta_graph("saved_models/testing.meta")
            saver.restore(sess, "saved_models/testing")
            output_ = loaded_graph.get_tensor_by_name('loss/network/output_layer/BiasAdd:0')
            _x = loaded_graph.get_tensor_by_name('x:0')
            print sess.run(output_, feed_dict={_x: np.array([12003]).reshape([-1, 1])})

如果替换为loaded_graph = tf.Graph(),则代码不起作用loaded_graph = tf.get_default_graph()。为什么?

完整的代码是否有帮助:(https://gist.github.com/duemaster/f8cf05c0923ebabae476b83e895619ab)



1> pfm..:

TensorFlow Graph是一个对象,其中包含您的各种tf.Tensortf.Operation

当您创建这些张量(例如使用tf.Variabletf.constant)或操作(例如tf.matmul)时,它们将被添加到默认图(查看graph这些对象的成员以获取它们所属的图)。如果您未指定任何内容,它将是您在调用tf.get_default_graph方法时获得的图形。

但是您也可以使用上下文管理器来处理多个图形:

g = tf.Graph()
with g.as_default():
    [your code]

假设您在代码中创建了多个图形,那么您需要将图形放置并作为tf.Session方法的参数运行以指定要运行的TensorFlow。

在代码A中,您

使用默认图形,

尝试将元图导入其中(失败,因为它已经包含一些节点),并且,

会将模型还原到其中,

在使用代码B时,您

创建一个新的新图,

将元图导入其中(之所以成功,是因为它是一个空图),然后

恢复它。

有用的链接:

tf.Graph API

编辑:

这段代码使代码A可以工作(我将默认图形重置为新的图形,并删除了Forecast name_scope)。

def predict():
    tf.reset_default_graph()
    with tf.Session() as sess:
        saver = tf.train.import_meta_graph("saved_models/testing.meta")
        saver.restore(sess, "saved_models/testing")
        loaded_graph = tf.get_default_graph()
        output_ = loaded_graph.get_tensor_by_name('loss/network/output_layer/BiasAdd:0')
        _x = loaded_graph.get_tensor_by_name('x:0')
        print(sess.run(output_, feed_dict={_x: np.array([12003]).reshape([-1, 1])}))

推荐阅读
LEEstarmmmmm
这个屌丝很懒,什么也没留下!
DevBox开发工具箱 | 专业的在线开发工具网站    京公网安备 11010802040832号  |  京ICP备19059560号-6
Copyright © 1998 - 2020 DevBox.CN. All Rights Reserved devBox.cn 开发工具箱 版权所有