我正在研究Google关于如何在Android上部署和使用预先训练的Tensorflow图(模型)的示例.此示例使用以下.pb
文件:
https://storage.googleapis.com/download.tensorflow.org/models/inception5h.zip
这是指向自动下载的文件的链接.
该示例显示了如何将.pb
文件加载到Tensorflow会话并使用它来执行分类,但是.pb
在训练图形之后(例如,在Python中)似乎没有提到如何生成这样的文件.
有没有关于如何做到这一点的例子?
编辑:该freeze_graph.py
脚本是TensorFlow存储库的一部分,现在可用作生成协议缓冲区的工具,该协议缓冲区表示来自现有TensorFlow GraphDef
和已保存检查点的"冻结"训练模型.它使用与下面描述的相同的步骤,但它更容易使用.
目前该过程没有很好的记录(并且需要改进),但大致的步骤如下:
建立并训练您的模型作为tf.Graph
被调用者g_1
.
获取每个变量的最终值并将它们存储为numpy数组(使用Session.run()
).
在新的tf.Graph
调用中g_2
,tf.constant()
使用在步骤2中获取的相应numpy数组的值为每个变量创建张量.
使用tf.import_graph_def()
从复制节点g_1
进入g_2
,并使用input_map
参数替换每个变量g_1
与相应的tf.constant()
在第三步建立张量你也可能需要使用input_map
指定新的输入张量(如更换输入管道用tf.placeholder()
).使用return_elements
参数指定预测输出张量的名称.
调用g_2.as_graph_def()
以获取图的协议缓冲区表示.
(注意:生成的图形将在图形中有额外的节点用于训练.虽然它不是公共API的一部分,但您可能希望使用内部graph_util.extract_sub_graph()
函数从图形中剥离这些节点.)
freeze_graph()
除了我之前使用的答案之外,只有将其称为脚本才有用,有一个非常好的功能可以为您完成所有繁重的工作,并且适合从您的普通模型训练代码中调用.
convert_variables_to_constants()
做两件事:
它通过用常量替换变量来冻结权重
它删除与前馈预测无关的节点
假设sess
您的tf.Session()
并且"output"
是您的预测节点的名称,以下代码将您的最小图表序列化为文本和二进制protobuf.
from tensorflow.python.framework.graph_util import convert_variables_to_constants minimal_graph = convert_variables_to_constants(sess, sess.graph_def, ["output"]) tf.train.write_graph(minimal_graph, '.', 'minimal_graph.proto', as_text=False) tf.train.write_graph(minimal_graph, '.', 'minimal_graph.txt', as_text=True)