I have two networks: a Model that generates output and an Adversary that grades that output.
Both have been trained separately, but now I need to combine their outputs during a single session.
I'm trying to implement the solution proposed in this post: Running multiple pre-trained Tensorflow nets at the same time
My code:
```python
with tf.name_scope("model"):
    model = Model(args)
with tf.name_scope("adv"):
    adversary = Adversary(adv_args)

# ...

with tf.Session() as sess:
    tf.global_variables_initializer().run()

    # Get the variables specific to the `Model`.
    # Also strip off the superfluous ":0", which for some reason is not saved in the checkpoint.
    model_varlist = {v.name.lstrip("model/")[:-2]: v
                     for v in tf.global_variables() if v.name[:5] == "model"}
    model_saver = tf.train.Saver(var_list=model_varlist)
    model_ckpt = tf.train.get_checkpoint_state(args.save_dir)
    model_saver.restore(sess, model_ckpt.model_checkpoint_path)

    # Get the variables specific to the `Adversary`
    adv_varlist = {v.name.lstrip("adv/")[:-2]: v
                   for v in tf.global_variables() if v.name[:3] == "adv"}
    adv_saver = tf.train.Saver(var_list=adv_varlist)
    adv_ckpt = tf.train.get_checkpoint_state(adv_args.save_dir)
    adv_saver.restore(adv_sess := sess, adv_ckpt.model_checkpoint_path) if False else adv_saver.restore(sess, adv_ckpt.model_checkpoint_path)
```
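As an aside, unrelated to the restore problem but a latent bug in the snippet above: `str.lstrip` removes a *set of characters* from the left, not a literal prefix, so `v.name.lstrip("model/")` can eat into the variable name itself. A minimal sketch of a safe prefix strip (the variable name here is just an illustration, not one from the actual checkpoint):

```python
def strip_prefix(name, prefix):
    """Remove a literal prefix from a name; str.lstrip removes a character set instead."""
    return name[len(prefix):] if name.startswith(prefix) else name

# lstrip treats "model/" as the character set {m, o, d, e, l, /}:
print("model/moving_avg:0".lstrip("model/"))        # "ving_avg:0" -- the leading "mo" is eaten too
print(strip_prefix("model/moving_avg:0", "model/")) # "moving_avg:0"
```

On Python 3.9+, `str.removeprefix("model/")` does the same thing as this helper.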
The problem

The call to model_saver.restore() appears to do nothing. In another module I use a saver built with tf.train.Saver(tf.global_variables()) and it restores the checkpoint fine.
The model has model.tvars = tf.trainable_variables(). To check what was happening, I used sess.run() to extract the tvars before and after the restore. Each time the initial randomly-assigned values are still in place; the checkpoint's values are never assigned.
Any ideas why model_saver.restore() appears to do nothing?
This took a long time to solve, so I'm posting my (probably imperfect) solution in case anyone else needs it.
To diagnose the problem I manually looped over the variables and assigned them one by one. I then noticed that after assigning a variable, its name would change. This is described here: TensorFlow checkpoint save and read
Based on the advice in that post, I ran each model in its own graph. That also means each graph has to run in its own session, which changes how session management is handled.
First I created two graphs:
```python
model_graph = tf.Graph()
with model_graph.as_default():
    model = Model(args)

adv_graph = tf.Graph()
with adv_graph.as_default():
    adversary = Adversary(adv_args)
```
Then two sessions:
```python
adv_sess = tf.Session(graph=adv_graph)
sess = tf.Session(graph=model_graph)
```
Then, within each session, I initialized the variables and restored each graph separately:
```python
with sess.as_default():
    with model_graph.as_default():
        tf.global_variables_initializer().run()
        model_saver = tf.train.Saver(tf.global_variables())
        model_ckpt = tf.train.get_checkpoint_state(args.save_dir)
        model_saver.restore(sess, model_ckpt.model_checkpoint_path)

with adv_sess.as_default():
    with adv_graph.as_default():
        tf.global_variables_initializer().run()
        adv_saver = tf.train.Saver(tf.global_variables())
        adv_ckpt = tf.train.get_checkpoint_state(adv_args.save_dir)
        adv_saver.restore(adv_sess, adv_ckpt.model_checkpoint_path)
```
Whenever one of the sessions is needed, I wrap any tf calls for that session in a `with sess.as_default():` block. Finally, I close the sessions manually:
```python
sess.close()
adv_sess.close()
```