I am trying to implement a suggestion from the answers to: Tensorflow: How to save/restore a model?
I have an object that wraps a `tensorflow` model in an `sklearn`-like style.
```python
import numpy as np
import tensorflow as tf

class tflasso():
    saver = tf.train.Saver()
    def __init__(self,
                 learning_rate = 2e-2,
                 training_epochs = 5000,
                 display_step = 50,
                 BATCH_SIZE = 100,
                 ALPHA = 1e-5,
                 checkpoint_dir = "./",
                 ):
        ...

    def _create_network(self):
        ...

    def _load_(self, sess, checkpoint_dir = None):
        if checkpoint_dir:
            self.checkpoint_dir = checkpoint_dir

        print("loading a session")
        ckpt = tf.train.get_checkpoint_state(self.checkpoint_dir)
        if ckpt and ckpt.model_checkpoint_path:
            self.saver.restore(sess, ckpt.model_checkpoint_path)
        else:
            raise Exception("no checkpoint found")
        return

    def fit(self, train_X, train_Y , load = True):
        self.X = train_X
        self.xlen = train_X.shape[1]
        # n_samples = y.shape[0]

        self._create_network()
        tot_loss = self._create_loss()
        optimizer = tf.train.AdagradOptimizer( self.learning_rate).minimize(tot_loss)

        # Initializing the variables
        init = tf.initialize_all_variables()
        " training per se"
        getb = batchgen( self.BATCH_SIZE)

        yvar = train_Y.var()
        print(yvar)
        # Launch the graph
        NUM_CORES = 3  # Choose how many cores to use.
        sess_config = tf.ConfigProto(inter_op_parallelism_threads=NUM_CORES,
                                     intra_op_parallelism_threads=NUM_CORES)
        with tf.Session(config= sess_config) as sess:
            sess.run(init)
            if load:
                self._load_(sess)
            # Fit all training data
            for epoch in range( self.training_epochs):
                for (_x_, _y_) in getb(train_X, train_Y):
                    _y_ = np.reshape(_y_, [-1, 1])
                    sess.run(optimizer, feed_dict={ self.vars.xx: _x_, self.vars.yy: _y_})
                # Display logs per epoch step
                if (1+epoch) % self.display_step == 0:
                    cost = sess.run(tot_loss,
                                    feed_dict={ self.vars.xx: train_X,
                                                self.vars.yy: np.reshape(train_Y, [-1, 1])})
                    rsq = 1 - cost / yvar
                    logstr = "Epoch: {:4d}\tcost = {:.4f}\tR^2 = {:.4f}".format((epoch+1), cost, rsq)
                    print(logstr )
                    self.saver.save(sess, self.checkpoint_dir + 'model.ckpt',
                                    global_step= 1+ epoch)

            print("Optimization Finished!")
        return self
```
When I run:
```python
tfl = tflasso()
tfl.fit(train_X, train_Y, load=False)
```
I get the output:
```
Epoch: 50 cost = 38.4705 R^2 = -1.2036 b1: 0.118122
Epoch: 100 cost = 26.4506 R^2 = -0.5151 b1: 0.133597
Epoch: 150 cost = 22.4330 R^2 = -0.2850 b1: 0.142261
Epoch: 200 cost = 20.0361 R^2 = -0.1477 b1: 0.147998
```
However, when I try to restore the parameters (even without killing the object):
```python
tfl.fit(train_X, train_Y, load=True)
```
I get strange results. First, the loaded value does not correspond to the saved one.
```
loading a session
loaded b1: 0.1  <------- Loaded another value than saved
Epoch: 50 cost = 30.8483 R^2 = -0.7670 b1: 0.137484
```
What is the correct way to load, possibly checking the saved variables first?
TL;DR: You should try to rework this class so that `self._create_network()` is called (i) only once, and (ii) before the `tf.train.Saver()` is constructed.
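There are two subtle issues here, which arise from the structure of the code and from the default behavior of the `tf.train.Saver` constructor. When you construct a saver with no arguments (as in your code), it collects the set of variables that exist in the graph at that moment and adds ops to the graph for saving and restoring them. In your code, `saver` is a class attribute, so it is constructed when the `tflasso` class is defined, before you even call `tflasso()`, and at that point there are no variables (because `_create_network()` has not been called yet). As a result, the checkpoint should be empty.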
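The snapshot behavior is easy to observe in isolation. The sketch below assumes TensorFlow 2.x driven through the `tf.compat.v1` graph-mode API (under TensorFlow 1.x the same calls exist without the `compat.v1` prefix); the variables `a` and `b` are hypothetical. The saver is built after `a` but before `b`, and `tf.train.list_variables` then reports which names actually made it into the checkpoint:

```python
import os
import tempfile

import tensorflow as tf

tf1 = tf.compat.v1
tf1.disable_eager_execution()  # Saver and Session need graph mode under TF 2.x

ckpt_prefix = os.path.join(tempfile.mkdtemp(), "model.ckpt")

with tf1.Graph().as_default():
    a = tf1.Variable(1.0, name="a")  # exists when the saver is built
    saver = tf1.train.Saver()        # var_list=None: snapshot of the current variables
    b = tf1.Variable(2.0, name="b")  # created later, so this saver never sees it
    with tf1.Session() as sess:
        sess.run(tf1.global_variables_initializer())
        saver.save(sess, ckpt_prefix)

# Inspect the checkpoint: tf.train.list_variables returns (name, shape) pairs.
saved_names = [name for name, _ in tf.train.list_variables(ckpt_prefix)]
print(saved_names)  # "a" is present, "b" is not
```

The same `tf.train.list_variables` call, pointed at your `checkpoint_dir`, is a quick way to check what a checkpoint contains before trying to restore it.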
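The second issue is that, by default, a saved checkpoint is a map from each variable's name (the op name, without the `:0` output suffix) to its current value. If you create two variables with the same name, TensorFlow automatically "uniquifies" them: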
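```python
import tensorflow as tf  # TensorFlow 1.x graph mode

v = tf.Variable(0.0, name="weights")
assert v.op.name == "weights"    # v.name itself is "weights:0"
w = tf.Variable(0.0, name="weights")
assert w.op.name == "weights_1"  # The "_1" is added by TensorFlow.
```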
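Since checkpoint keys are these variable names, uniquified names are stored verbatim as separate entries. A minimal sketch under the same `tf.compat.v1` assumption, where `make_weights` is a hypothetical builder that always requests the name `"weights"`; `tf.train.load_checkpoint` reads the file back as a plain name-to-value mapping:

```python
import os
import tempfile

import tensorflow as tf

tf1 = tf.compat.v1
tf1.disable_eager_execution()  # graph mode, as in TensorFlow 1.x


def make_weights(value):
    # Hypothetical network builder: always asks for the name "weights".
    return tf1.Variable(value, name="weights")


ckpt_prefix = os.path.join(tempfile.mkdtemp(), "model.ckpt")

with tf1.Graph().as_default():
    v = make_weights(1.0)      # op name "weights"
    w = make_weights(2.0)      # op name "weights_1", uniquified by TensorFlow
    saver = tf1.train.Saver()  # built after both variables exist
    with tf1.Session() as sess:
        sess.run(tf1.global_variables_initializer())
        saver.save(sess, ckpt_prefix)

# Read the checkpoint back as {variable name: saved value}.
reader = tf.train.load_checkpoint(ckpt_prefix)
saved = {name: reader.get_tensor(name)
         for name in reader.get_variable_to_shape_map()}
print(sorted(saved))
```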
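The consequence is that when you call `self._create_network()` during the second call to `tfl.fit()`, the variables will all get different names from the ones stored in the checkpoint (or the ones that would have been stored, had the saver been constructed after the network). (You can avoid this behavior by passing a dictionary that maps names to `Variable` objects to the saver constructor, but this is usually quite awkward.)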
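The name-dictionary workaround can be sketched as follows, again assuming `tf.compat.v1` graph mode; `build_network` is a hypothetical stand-in for `_create_network()`, and the value `42.0` simply plays the role of a trained weight. The second call to the builder produces `weights_1`, so a second saver is given an explicit `{checkpoint_name: variable}` dictionary to load the stored `"weights"` entry into it:

```python
import os
import tempfile

import tensorflow as tf

tf1 = tf.compat.v1
tf1.disable_eager_execution()  # graph mode, as in TensorFlow 1.x


def build_network():
    # Hypothetical stand-in for tflasso._create_network().
    return tf1.Variable(0.0, name="weights")


ckpt_prefix = os.path.join(tempfile.mkdtemp(), "model.ckpt")

with tf1.Graph().as_default():
    w_first = build_network()  # op name "weights"
    saver = tf1.train.Saver()
    with tf1.Session() as sess:
        sess.run(tf1.global_variables_initializer())
        sess.run(w_first.assign(42.0))  # pretend this is the trained value
        saver.save(sess, ckpt_prefix)

    w_second = build_network()  # same graph, so it becomes "weights_1"
    # Map the checkpoint key "weights" onto the renamed variable explicitly.
    restorer = tf1.train.Saver({"weights": w_second})
    with tf1.Session() as sess:
        restorer.restore(sess, ckpt_prefix)  # restore also initializes w_second
        restored = sess.run(w_second)

print(w_second.op.name, restored)
```

It works, but every variable needs an entry in the dictionary, which is why restructuring the class is usually the better option.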
There are two main workarounds:
The first is to recreate the entire model in each call to `tflasso.fit()`: define a new `tf.Graph`, then build the network in that graph and create a `tf.train.Saver` for it.
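A minimal sketch of this first workaround, assuming `tf.compat.v1` graph mode. To keep it short and runnable, the lasso network is replaced by a hypothetical linear regression (`FreshGraphModel`, the variables `w` and `b`, and the synthetic data are all illustrative, and `GradientDescentOptimizer` stands in for the question's Adagrad optimizer). Each `fit()` builds a new `tf.Graph`, so the variable names are identical every time, and the saver is created only after the network exists:

```python
import os
import tempfile

import numpy as np
import tensorflow as tf

tf1 = tf.compat.v1
tf1.disable_eager_execution()  # graph mode, as in TensorFlow 1.x


class FreshGraphModel:
    """Hypothetical sklearn-style wrapper: every fit() builds a brand-new graph."""

    def __init__(self, checkpoint_dir, learning_rate=0.1, epochs=200):
        self.checkpoint_dir = checkpoint_dir
        self.learning_rate = learning_rate
        self.epochs = epochs

    def fit(self, X, y, load=False):
        graph = tf1.Graph()
        with graph.as_default():
            # Network: built once per graph, so the names are always "w" and "b".
            xx = tf1.placeholder(tf.float32, [None, X.shape[1]])
            yy = tf1.placeholder(tf.float32, [None, 1])
            w = tf1.Variable(tf.zeros([X.shape[1], 1]), name="w")
            b = tf1.Variable(0.0, name="b")
            loss = tf.reduce_mean(tf.square(tf.matmul(xx, w) + b - yy))
            train_op = tf1.train.GradientDescentOptimizer(
                self.learning_rate).minimize(loss)
            init = tf1.global_variables_initializer()
            saver = tf1.train.Saver()  # created after the network, so it sees w and b

        with tf1.Session(graph=graph) as sess:
            sess.run(init)
            if load:
                ckpt = tf1.train.latest_checkpoint(self.checkpoint_dir)
                if ckpt is None:
                    raise FileNotFoundError("no checkpoint found")
                saver.restore(sess, ckpt)
            feed = {xx: X, yy: y.reshape(-1, 1)}
            for _ in range(self.epochs):
                sess.run(train_op, feed_dict=feed)
            self.loss_ = sess.run(loss, feed_dict=feed)
            saver.save(sess, os.path.join(self.checkpoint_dir, "model.ckpt"))
        return self


# Usage on synthetic data (illustrative only).
rng = np.random.RandomState(0)
X = rng.randn(100, 2)
y = X @ np.array([2.0, -1.0]) + 0.5
ckpt_dir = tempfile.mkdtemp()

trained = FreshGraphModel(ckpt_dir).fit(X, y)                       # train, then save
resumed = FreshGraphModel(ckpt_dir, epochs=1).fit(X, y, load=True)  # new graph, restore
print(trained.loss_, resumed.loss_)
```

Because the second model lives in a brand-new graph, its variables are again named `w` and `b`, so the restore picks up the trained values. The cost is rebuilding the whole graph on every call.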
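The second, which I recommend, is to create the network, and then the `tf.train.Saver`, in the `tflasso` constructor, and to reuse this graph on each call to `tflasso.fit()`. Note that you might need to do some more work to restructure things (in particular, I'm not sure what you do with `self.X` and `self.xlen`), but it should be possible to achieve this with placeholders and feeding.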
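A minimal sketch of the recommended structure, under the same `tf.compat.v1` assumption and with the same hypothetical linear model in place of the lasso network. The graph, network, and saver are built exactly once in `__init__`; the training data never become attributes like `self.X` and instead reach the graph through placeholders and `feed_dict`, while a hypothetical `n_features` constructor argument plays the role of `self.xlen`:

```python
import os
import tempfile

import numpy as np
import tensorflow as tf

tf1 = tf.compat.v1
tf1.disable_eager_execution()  # graph mode, as in TensorFlow 1.x


class BuildOnceModel:
    """Hypothetical wrapper: graph, network and Saver are created once, in __init__."""

    def __init__(self, n_features, checkpoint_dir, learning_rate=0.1):
        self.checkpoint_dir = checkpoint_dir
        self.graph = tf1.Graph()
        with self.graph.as_default():
            # Placeholders replace per-call attributes such as self.X / self.xlen.
            self.xx = tf1.placeholder(tf.float32, [None, n_features], name="xx")
            self.yy = tf1.placeholder(tf.float32, [None, 1], name="yy")
            w = tf1.Variable(tf.zeros([n_features, 1]), name="w")
            b = tf1.Variable(0.0, name="b")
            self.loss = tf.reduce_mean(tf.square(tf.matmul(self.xx, w) + b - self.yy))
            self.train_op = tf1.train.GradientDescentOptimizer(
                learning_rate).minimize(self.loss)
            self.init = tf1.global_variables_initializer()
            self.saver = tf1.train.Saver()  # built last, so it covers w and b

    def fit(self, X, y, epochs=200, load=False):
        # Only runs ops that already exist; no new variables are created here.
        feed = {self.xx: X, self.yy: y.reshape(-1, 1)}
        with tf1.Session(graph=self.graph) as sess:
            sess.run(self.init)
            if load:
                ckpt = tf1.train.latest_checkpoint(self.checkpoint_dir)
                if ckpt is None:
                    raise FileNotFoundError("no checkpoint found")
                self.saver.restore(sess, ckpt)
            for _ in range(epochs):
                sess.run(self.train_op, feed_dict=feed)
            self.loss_ = sess.run(self.loss, feed_dict=feed)
            self.saver.save(sess, os.path.join(self.checkpoint_dir, "model.ckpt"))
        return self


# Usage on synthetic data (illustrative only).
rng = np.random.RandomState(0)
X = rng.randn(100, 2)
y = X @ np.array([2.0, -1.0]) + 0.5

model = BuildOnceModel(n_features=2, checkpoint_dir=tempfile.mkdtemp())
n_ops = len(model.graph.get_operations())
model.fit(X, y)                       # train from scratch, then save
model.fit(X, y, epochs=1, load=True)  # same graph and names, so restore works
print(model.loss_, len(model.graph.get_operations()) == n_ops)
```

Since `fit()` only runs existing ops, calling it repeatedly does not grow the graph, and the names the saver writes are always the names it later restores.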