我无法理解中的图论点tf.Session()
。我尝试查找TensorFlow网站:链接,但了解得不多。
我试图找出tf.Session()
和之间的区别tf.Session(graph=some_graph_inserted_here)
。
def predict(): with tf.name_scope("predict"): with tf.Session() as sess: saver = tf.train.import_meta_graph("saved_models/testing.meta") saver.restore(sess, "saved_models/testing") loaded_graph = tf.get_default_graph() output_ = loaded_graph.get_tensor_by_name('loss/network/output_layer/BiasAdd:0') _x = loaded_graph.get_tensor_by_name('x:0') print sess.run(output_, feed_dict={_x: np.array([12003]).reshape([-1, 1])})
此代码给出以下错误:ValueError: cannot add op with name hidden_layer1/kernel/Adam as that name is already used
尝试在以下位置加载图形时saver = tf.train.import_meta_graph("saved_models/testing.meta")
def predict(): with tf.name_scope("predict"): loaded_graph = tf.Graph() with tf.Session(graph=loaded_graph) as sess: saver = tf.train.import_meta_graph("saved_models/testing.meta") saver.restore(sess, "saved_models/testing") output_ = loaded_graph.get_tensor_by_name('loss/network/output_layer/BiasAdd:0') _x = loaded_graph.get_tensor_by_name('x:0') print sess.run(output_, feed_dict={_x: np.array([12003]).reshape([-1, 1])})
如果替换为loaded_graph = tf.Graph()
,则代码不起作用loaded_graph = tf.get_default_graph()
。为什么?
完整的代码是否有帮助:(https://gist.github.com/duemaster/f8cf05c0923ebabae476b83e895619ab)
TensorFlow Graph
是一个对象,其中包含您的各种tf.Tensor
和tf.Operation
。
当您创建这些张量(例如使用tf.Variable
或tf.constant
)或操作(例如tf.matmul
)时,它们将被添加到默认图(查看graph
这些对象的成员以获取它们所属的图)。如果您未指定任何内容,它将是您在调用tf.get_default_graph
方法时获得的图形。
但是您也可以使用上下文管理器来处理多个图形:
g = tf.Graph() with g.as_default(): [your code]
假设您在代码中创建了多个图形,那么您需要将图形放置并作为tf.Session
方法的参数运行以指定要运行的TensorFlow。
在代码A中,您
使用默认图形,
尝试将元图导入其中(失败,因为它已经包含一些节点),并且,
会将模型还原到其中,
在使用代码B时,您
创建一个新的新图,
将元图导入其中(之所以成功,是因为它是一个空图),然后
恢复它。
tf.Graph
API
这段代码使代码A可以工作(我将默认图形重置为新的图形,并删除了Forecast name_scope
)。
def predict(): tf.reset_default_graph() with tf.Session() as sess: saver = tf.train.import_meta_graph("saved_models/testing.meta") saver.restore(sess, "saved_models/testing") loaded_graph = tf.get_default_graph() output_ = loaded_graph.get_tensor_by_name('loss/network/output_layer/BiasAdd:0') _x = loaded_graph.get_tensor_by_name('x:0') print(sess.run(output_, feed_dict={_x: np.array([12003]).reshape([-1, 1])}))