当前位置:  开发笔记 > 编程语言 > 正文

TensorFlow检查点保存并读取

如何解决《TensorFlow检查点保存并读取》经验,为你挑选了1个好方法。

我有一个基于TensorFlow的神经网络和一组变量.

培训功能如下:

def train(load = True, step)
    """
    Defining the neural network is skipped here
    """

    train_step = tf.train.AdamOptimizer(1e-4).minimize(mse)
    # Saver
    saver = tf.train.Saver()

    if not load:
        # Initalizing variables
        sess.run(tf.initialize_all_variables())
    else:
        saver.restore(sess, 'Variables/map.ckpt')
        print 'Model Restored!'

    # Perform stochastic gradient descent
    for i in xrange(step):
        train_step.run(feed_dict = {x: train, y_: label})

    # Save model
    save_path = saver.save(sess, 'Variables/map.ckpt')
    print 'Model saved in file: ', save_path
    print 'Training Done!'

我正在调用这样的训练函数:

# First train
train(False, 1)
# Following train
for i in xrange(10):
    train(True, 10)

我做了这种训练,因为我需要将不同的数据集提供给我的模型.但是,如果我以这种方式调用train函数,TensorFlow将生成错误消息,指示它无法从文件中读取已保存的模型.

经过一些实验,我发现这是因为检查点保存很慢.在将文件写入磁盘之前,下一个列车功能将开始读取,从而产生错误.

我曾尝试使用time.sleep()函数在每次调用之间做一些延迟,但它不起作用.

任何人都知道如何解决这种写/读错误?非常感谢你!



1> mrry..:

您的代码中存在一个微妙的问题:每次调用train()函数时,对于所有模型变量和神经网络的其余部分,将更多节点添加到同一TensorFlow图中.这意味着每次构造a时tf.train.Saver(),它都包含以前调用的所有变量train().每次重新创建模型时,都会使用额外的_N后缀创建变量,以便为它们指定唯一的名称:

    Saver用变量构造var_a,var_b.

    包裹带变量构成var_a,var_b,var_a_1,var_b_1.

    包裹带变量构成var_a,var_b,var_a_1,var_b_1,var_a_2,var_b_2.

    等等

a的默认行为tf.train.Saver是将每个变量与相应op的名称相关联.这意味着var_a_1不会初始化var_a,因为它们最终会有不同的名称.

解决方案是每次调用时创建一个新图表train().解决它的最简单方法是更改​​主程序,为每个调用创建一个新图形,train()如下所示:

# First train
with tf.Graph().as_default():
    train(False, 1)

# Following train
for i in xrange(10):
    with tf.Graph().as_default():
        train(True, 10)

...或者,等效地,您可以withtrain()函数内移动块.

推荐阅读
家具销售_903
这个屌丝很懒,什么也没留下!
DevBox开发工具箱 | 专业的在线开发工具网站    京公网安备 11010802040832号  |  京ICP备19059560号-6
Copyright © 1998 - 2020 DevBox.CN. All Rights Reserved devBox.cn 开发工具箱 版权所有