我在tensorflow中训练生成对抗网络(GAN),基本上我们有两个不同的网络,每个网络都有自己的优化器.
self.G, self.layer = self.generator(self.inputCT,batch_size_tf) self.D, self.D_logits = self.discriminator(self.GT_1hot) ... self.g_optim = tf.train.MomentumOptimizer(self.learning_rate_tensor, 0.9).minimize(self.g_loss, global_step=self.global_step) self.d_optim = tf.train.AdamOptimizer(self.learning_rate, beta1=0.5) \ .minimize(self.d_loss, var_list=self.d_vars)
问题是我首先训练其中一个网络(g)然后,我想一起训练g和d.但是,当我调用load函数时:
self.sess.run(tf.initialize_all_variables()) self.sess.graph.finalize() self.load(self.checkpoint_dir) def load(self, checkpoint_dir): print(" [*] Reading checkpoints...") ckpt = tf.train.get_checkpoint_state(checkpoint_dir) if ckpt and ckpt.model_checkpoint_path: ckpt_name = os.path.basename(ckpt.model_checkpoint_path) self.saver.restore(self.sess, ckpt.model_checkpoint_path) return True else: return False
我有这样的错误(有更多的追溯):
Tensor name "beta2_power" not found in checkpoint files checkpoint/MR2CT.model-96000
我可以恢复g网络并继续使用该功能进行训练,但是当我想从头开始标记d,并且从存储的模型中获取g时,我有这个错误.
要恢复变量子集,必须创建一个新变量tf.train.Saver
并将其传递给可选var_list
参数中要恢复的特定变量列表.
默认情况下,a tf.train.Saver
将创建操作,以便(i)在调用时保存图形中的每个变量,saver.save()
以及(ii)在调用时查找(按名称)给定检查点中的每个变量saver.restore()
.虽然这适用于大多数常见方案,但您必须提供更多信息以使用变量的特定子集:
如果您只想恢复变量的子集,可以通过调用获取这些变量的列表tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=G_NETWORK_PREFIX)
,假设您将"g"网络放在公共with tf.name_scope(G_NETWORK_PREFIX):
或tf.variable_scope(G_NETWORK_PREFIX):
块中.然后,您可以将此列表传递给tf.train.Saver
构造函数.
如果要还原变量的子集和/或检查点中的变量具有不同的名称,则可以将字典作为var_list
参数传递.默认情况下,检查点中的每个变量都与一个键相关联,该键是其tf.Variable.name
属性的值.如果目标图中的名称不同(例如,因为您添加了作用域前缀),则可以指定将字符串键(在检查点文件中)映射到tf.Variable
对象(在目标图中)的字典.