I have trained a model using the tf.data.Dataset API, so my training code looks like this:

    import tensorflow as tf

    with graph.as_default():
        # Input pipeline: read TFRecords, scale features, shuffle, and pad-batch.
        train_dataset = tf.data.TFRecordDataset(tfrecord_path)
        train_dataset = train_dataset.map(scale_features, num_parallel_calls=n_workers)
        train_dataset = train_dataset.shuffle(10000)
        train_dataset = train_dataset.padded_batch(batch_size, padded_shapes={...})

        # Feedable iterator: the dataset to read from is selected at run time via `handle`.
        handle = tf.placeholder(tf.string, shape=[])
        iterator = tf.data.Iterator.from_string_handle(
            handle, train_dataset.output_types, train_dataset.output_shapes)
        batch = iterator.get_next()

        ...  # Model code ...

        # Concrete iterator over the training data; its string handle is fed to `handle`.
        train_iterator = train_dataset.make_initializable_iterator()

    with tf.Session(graph=graph) as sess:
        train_handle = sess.run(train_iterator.string_handle())
        sess.run(tf.global_variables_initializer())
        for epoch in range(n_epochs):
            sess.run(train_iterator.initializer)
            while True:
                try:
                    sess.run(optimizer, feed_dict={handle: train_handle})
                except tf.errors.OutOfRangeError:
                    break  # End of epoch

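Now, after training the model, I want to run inference on examples that are not in the dataset, and I'm not sure how to do that.
To be clear, I know how to use another dataset: for example, at test time I just pass the test set's handle.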
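For reference, the handle switching mentioned above can be sketched in a self-contained way. This is only an illustration under several assumptions: `scale_features` here is a hypothetical stand-in for the real scaling function, `run_with_handles` is a made-up helper name, the "model" is a single multiplication, the data comes from in-memory NumPy arrays rather than TFRecords, and the `tf.compat.v1` aliases are used so the same code also runs under TensorFlow 2.

```python
import numpy as np
import tensorflow as tf

tf1 = tf.compat.v1  # TF1-style graph/session API, also available in TF 2.x


def scale_features(x):
    # Hypothetical stand-in for the real scaling function.
    return x / 10.0


def run_with_handles(train_values, test_values):
    """Runs the same graph over two datasets, selected through a string handle."""
    graph = tf.Graph()
    with graph.as_default():
        train_dataset = (tf.data.Dataset.from_tensor_slices(train_values)
                         .map(scale_features).batch(2))
        test_dataset = (tf.data.Dataset.from_tensor_slices(test_values)
                        .map(scale_features).batch(2))

        # Feedable iterator: the source dataset is chosen at run time.
        handle = tf1.placeholder(tf.string, shape=[])
        iterator = tf1.data.Iterator.from_string_handle(
            handle,
            tf1.data.get_output_types(train_dataset),
            tf1.data.get_output_shapes(train_dataset))
        batch = iterator.get_next()
        output = batch * 2.0  # Stand-in for the model.

        train_iterator = tf1.data.make_initializable_iterator(train_dataset)
        test_iterator = tf1.data.make_initializable_iterator(test_dataset)

    results = {}
    with tf1.Session(graph=graph) as sess:
        for name, it in [("train", train_iterator), ("test", test_iterator)]:
            it_handle = sess.run(it.string_handle())
            sess.run(it.initializer)
            outputs = []
            while True:
                try:
                    outputs.append(sess.run(output, feed_dict={handle: it_handle}))
                except tf.errors.OutOfRangeError:
                    break  # This dataset is exhausted.
            results[name] = np.concatenate(outputs)
    return results


results = run_with_handles(
    np.array([10.0, 20.0, 30.0], dtype=np.float32),
    np.array([40.0], dtype=np.float32))
print(results)
```

Both datasets pass through the same `map(scale_features)` step, so whichever handle is fed, the model sees scaled inputs.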
My question is this: given the scaling step in the pipeline and the fact that the network expects a handle, how would I make a prediction for a new example that has not been written to a TFRecord?
If I modified `batch` directly, I would be responsible for doing the scaling beforehand, which I would like to avoid if possible.
So how should I run inference on single examples with a model trained the tf.data.Dataset way?
(This is not for production purposes; it is for evaluating what happens when I change specific features.)