当前位置:  开发笔记 > 编程语言 > 正文

在Tensorflow中恢复变量子集

如何解决《在Tensorflow中恢复变量子集》经验,为你挑选了1个好方法。

我在tensorflow中训练生成对抗网络(GAN),基本上我们有两个不同的网络,每个网络都有自己的优化器.

self.G, self.layer = self.generator(self.inputCT,batch_size_tf)
self.D, self.D_logits = self.discriminator(self.GT_1hot)

...

self.g_optim = tf.train.MomentumOptimizer(self.learning_rate_tensor, 0.9).minimize(self.g_loss, global_step=self.global_step)

self.d_optim = tf.train.AdamOptimizer(self.learning_rate, beta1=0.5) \
                      .minimize(self.d_loss, var_list=self.d_vars)

问题是我首先训练其中一个网络(g)然后,我想一起训练g和d.但是,当我调用load函数时:

self.sess.run(tf.initialize_all_variables())
self.sess.graph.finalize()

self.load(self.checkpoint_dir)

def load(self, checkpoint_dir):
    print(" [*] Reading checkpoints...")

    ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
    if ckpt and ckpt.model_checkpoint_path:
        ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
        self.saver.restore(self.sess, ckpt.model_checkpoint_path)
        return True
    else:
        return False

我有这样的错误(有更多的追溯):

Tensor name "beta2_power" not found in checkpoint files checkpoint/MR2CT.model-96000

我可以恢复g网络并继续使用该功能进行训练,但是当我想从头开始标记d,并且从存储的模型中获取g时,我有这个错误.



1> mrry..:

要恢复变量子集,必须创建一个新变量tf.train.Saver并将其传递给可选var_list参数中要恢复的特定变量列表.

默认情况下,a tf.train.Saver将创建操作,以便(i)在调用时保存图形中的每个变量,saver.save()以及(ii)在调用时查找(按名称)给定检查点中的每个变量saver.restore().虽然这适用于大多数常见方案,但您必须提供更多信息以使用变量的特定子集:

    如果您只想恢复变量的子集,可以通过调用获取这些变量的列表tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=G_NETWORK_PREFIX),假设您将"g"网络放在公共with tf.name_scope(G_NETWORK_PREFIX):tf.variable_scope(G_NETWORK_PREFIX):块中.然后,您可以将此列表传递给tf.train.Saver构造函数.

    如果要还原变量的子集和/或检查点中的变量具有不同的名称,则可以将字典作为var_list参数传递.默认情况下,检查点中的每个变量都与一个相关联,该是其tf.Variable.name属性的值.如果目标图中的名称不同(例如,因为您添加了作用域前缀),则可以指定将字符串键(在检查点文件中)映射到tf.Variable对象(在目标图中)的字典.

推荐阅读
wurtjq
这个屌丝很懒,什么也没留下!
DevBox开发工具箱 | 专业的在线开发工具网站    京公网安备 11010802040832号  |  京ICP备19059560号-6
Copyright © 1998 - 2020 DevBox.CN. All Rights Reserved devBox.cn 开发工具箱 版权所有