当前位置:  开发笔记 > 编程语言 > 正文

在同一Tensorflow会话中从Saver加载两个模型

如何解决《在同一Tensorflow会话中从Saver加载两个模型》经验,为你挑选了1个好方法。

我有两个网络:一个Model生成输出,一个Adversary对输出进行分级.

两者都经过单独培训,但现在我需要在单个会话期间将它们的输出结合起来.

我试图实现这篇文章中提出的解决方案:同时运行多个预先训练的Tensorflow网络

我的代码

with tf.name_scope("model"):
    model = Model(args)
with tf.name_scope("adv"):
    adversary = Adversary(adv_args)

#...

with tf.Session() as sess:
    tf.global_variables_initializer().run()

    # Get the variables specific to the `Model`
    # Also strip out the surperfluous ":0" for some reason not saved in the checkpoint
    model_varlist = {v.name.lstrip("model/")[:-2]: v 
                     for v in tf.global_variables() if v.name[:5] == "model"}
    model_saver = tf.train.Saver(var_list=model_varlist)
    model_ckpt = tf.train.get_checkpoint_state(args.save_dir)
    model_saver.restore(sess, model_ckpt.model_checkpoint_path)

    # Get the variables specific to the `Adversary`
    adv_varlist = {v.name.lstrip("avd/")[:-2]: v 
                   for v in tf.global_variables() if v.name[:3] == "adv"}
    adv_saver = tf.train.Saver(var_list=adv_varlist)
    adv_ckpt = tf.train.get_checkpoint_state(adv_args.save_dir)
    adv_saver.restore(sess, adv_ckpt.model_checkpoint_path)

问题

对函数的调用model_saver.restore()似乎什么都不做.在另一个模块中,我使用了一个saver,tf.train.Saver(tf.global_variables())它恢复了检查点.

该模型有model.tvars = tf.trainable_variables().要检查发生了什么,我曾经sess.run()提取tvars恢复之前和之后.每次使用初始随机分配的变量时,都不会分配检查点的变量.

对于为什么model_saver.restore()似乎什么都不做的任何想法?



1> TheCriticalI..:

解决这个问题花了很长时间,所以我发布了我可能不完美的解决方案,以防其他人需要它.

为了诊断问题,我手动循环遍历每个变量并逐个分配它们.然后我注意到在分配变量后,名称会改变.这里描述:TensorFlow检查点保存和读取

根据该帖子中的建议,我在自己的图表中运行了每个模型.这也意味着我必须在自己的会话中运行每个图形.这意味着以不同方式处理会话管理.

首先我创建了两个图表

model_graph = tf.Graph()
with model_graph.as_default():
    model = Model(args)

adv_graph = tf.Graph()
with adv_graph.as_default():
    adversary = Adversary(adv_args)

然后两节

adv_sess = tf.Session(graph=adv_graph)
sess = tf.Session(graph=model_graph)

然后我在每个会话中初始化变量并分别恢复每个图形

with sess.as_default():
    with model_graph.as_default():
        tf.global_variables_initializer().run()
        model_saver = tf.train.Saver(tf.global_variables())
        model_ckpt = tf.train.get_checkpoint_state(args.save_dir)
        model_saver.restore(sess, model_ckpt.model_checkpoint_path)

with adv_sess.as_default():
    with adv_graph.as_default():
        tf.global_variables_initializer().run()
        adv_saver = tf.train.Saver(tf.global_variables())
        adv_ckpt = tf.train.get_checkpoint_state(adv_args.save_dir)
        adv_saver.restore(adv_sess, adv_ckpt.model_checkpoint_path)

每当需要每个会话时,我都会tf在这个会话中包含任何函数with sess.as_default():.最后我手动关闭会话

sess.close()
adv_sess.close()

推荐阅读
虎仔球妈_459
这个屌丝很懒,什么也没留下!
DevBox开发工具箱 | 专业的在线开发工具网站    京公网安备 11010802040832号  |  京ICP备19059560号-6
Copyright © 1998 - 2020 DevBox.CN. All Rights Reserved devBox.cn 开发工具箱 版权所有