所以我的问题是我正在运行TensorFlow教程中的初学者级别代码,并根据我的需要对其进行了修改,但是当我使用它print sess.run(accuracy, feed_dict={x: x_test, y_: y_test})
时总是打印出1.0,现在它总是猜测0并打印出~93%准确性.当我使用时tf.argmin(y,1), tf.argmin(y_,1)
,它会猜测所有1并且产生约7%的准确率.将两者相加,它等于100%.我不知道如何tf.argmin
猜测1并tf.argmax
猜测0.显然代码有问题.请看一下,让我知道我可以做些什么来解决这个问题.我认为在培训期间代码出错了,但我可能错了.
import tensorflow as tf import numpy as np from numpy import genfromtxt data = genfromtxt('cs-training.csv',delimiter=',') # Training data test_data = genfromtxt('cs-test.csv',delimiter=',') # Test data x_train = [] for i in data: x_train.append(i[1:]) x_train = np.array(x_train) y_train = [] for i in data: if i[0] == 0: y_train.append([1., i[0]]) else: y_train.append([0., i[0]]) y_train = np.array(y_train) where_are_NaNs = isnan(x_train) x_train[where_are_NaNs] = 0 x_test = [] for i in test_data: x_test.append(i[1:]) x_test = np.array(x_test) y_test = [] for i in test_data: if i[0] == 0: y_test.append([1., i[0]]) else: y_test.append([0., i[0]]) y_test = np.array(y_test) where_are_NaNs = isnan(x_test) x_test[where_are_NaNs] = 0 x = tf.placeholder("float", [None, 10]) W = tf.Variable(tf.zeros([10,2])) b = tf.Variable(tf.zeros([2])) y = tf.nn.softmax(tf.matmul(x,W) + b) y_ = tf.placeholder("float", [None,2]) cross_entropy = -tf.reduce_sum(y_*tf.log(y)) train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy) init = tf.initialize_all_variables() sess = tf.Session() sess.run(init) print "...Training..." g = 0 for i in range(len(x_train)): sess.run(train_step, feed_dict={x: [x_train[g]], y_: [y_train[g]]}) g += 1
在这一点上,如果我做它print [x_train[g]]
和print [y_train[g]]
,这是结果是什么样子.
[array([ 7.66126609e-01, 4.50000000e+01, 2.00000000e+00, 8.02982129e-01, 9.12000000e+03, 1.30000000e+01, 0.00000000e+00, 6.00000000e+00, 0.00000000e+00, 2.00000000e+00])] [array([ 0., 1.])]
好吧,让我们继续吧.
correct_prediction = tf.equal(tf.argmax(y,1), tf.argmax(y_,1)) accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float")) print sess.run(accuracy, feed_dict={x: x_test, y_: y_test}) 0.929209
这个百分比不会改变.无论我为2个类(1或0)创建的onehot,它都在猜测所有零.
这是一个数据 -
print x_train[:10] [[ 7.66126609e-01 4.50000000e+01 2.00000000e+00 8.02982129e-01 9.12000000e+03 1.30000000e+01 0.00000000e+00 6.00000000e+00 0.00000000e+00 2.00000000e+00] [ 9.57151019e-01 4.00000000e+01 0.00000000e+00 1.21876201e-01 2.60000000e+03 4.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00 1.00000000e+00] [ 6.58180140e-01 3.80000000e+01 1.00000000e+00 8.51133750e-02 3.04200000e+03 2.00000000e+00 1.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00] [ 2.33809776e-01 3.00000000e+01 0.00000000e+00 3.60496820e-02 3.30000000e+03 5.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00] [ 9.07239400e-01 4.90000000e+01 1.00000000e+00 2.49256950e-02 6.35880000e+04 7.00000000e+00 0.00000000e+00 1.00000000e+00 0.00000000e+00 0.00000000e+00] [ 2.13178682e-01 7.40000000e+01 0.00000000e+00 3.75606969e-01 3.50000000e+03 3.00000000e+00 0.00000000e+00 1.00000000e+00 0.00000000e+00 1.00000000e+00] [ 3.05682465e-01 5.70000000e+01 0.00000000e+00 5.71000000e+03 0.00000000e+00 8.00000000e+00 0.00000000e+00 3.00000000e+00 0.00000000e+00 0.00000000e+00] [ 7.54463648e-01 3.90000000e+01 0.00000000e+00 2.09940017e-01 3.50000000e+03 8.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00] [ 1.16950644e-01 2.70000000e+01 0.00000000e+00 4.60000000e+01 0.00000000e+00 2.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00] [ 1.89169052e-01 5.70000000e+01 0.00000000e+00 6.06290901e-01 2.36840000e+04 9.00000000e+00 0.00000000e+00 4.00000000e+00 0.00000000e+00 2.00000000e+00]] print y_train[:10] [[ 0. 1.] [ 1. 0.] [ 1. 0.] [ 1. 0.] [ 1. 0.] [ 1. 0.] [ 1. 0.] [ 1. 0.] [ 1. 0.] [ 1. 0.]] print x_test[:20] [[ 4.83539240e-02 4.40000000e+01 0.00000000e+00 3.02297622e-01 7.48500000e+03 1.10000000e+01 0.00000000e+00 1.00000000e+00 0.00000000e+00 2.00000000e+00] [ 9.10224439e-01 4.20000000e+01 5.00000000e+00 1.72900000e+03 0.00000000e+00 5.00000000e+00 2.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00] [ 2.92682927e-01 5.80000000e+01 0.00000000e+00 3.66480079e-01 3.03600000e+03 7.00000000e+00 0.00000000e+00 1.00000000e+00 0.00000000e+00 1.00000000e+00] [ 3.11547538e-01 3.30000000e+01 1.00000000e+00 3.55431993e-01 4.67500000e+03 1.10000000e+01 0.00000000e+00 1.00000000e+00 0.00000000e+00 1.00000000e+00] [ 0.00000000e+00 7.20000000e+01 0.00000000e+00 2.16630600e-03 6.00000000e+03 9.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00] [ 2.79217052e-01 4.50000000e+01 1.00000000e+00 4.89921122e-01 6.84500000e+03 8.00000000e+00 0.00000000e+00 2.00000000e+00 0.00000000e+00 2.00000000e+00] [ 0.00000000e+00 7.80000000e+01 0.00000000e+00 0.00000000e+00 0.00000000e+00 1.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00] [ 9.10363487e-01 2.80000000e+01 0.00000000e+00 4.99451497e-01 6.38000000e+03 8.00000000e+00 0.00000000e+00 2.00000000e+00 0.00000000e+00 0.00000000e+00] [ 6.36595797e-01 4.40000000e+01 0.00000000e+00 7.85457163e-01 4.16600000e+03 6.00000000e+00 0.00000000e+00 1.00000000e+00 0.00000000e+00 0.00000000e+00] [ 1.41549211e-01 2.60000000e+01 0.00000000e+00 2.68407434e-01 4.25000000e+03 4.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00] [ 4.14101100e-03 7.80000000e+01 0.00000000e+00 2.26362500e-03 5.74200000e+03 7.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00] [ 9.99999900e-01 6.00000000e+01 0.00000000e+00 1.20000000e+02 0.00000000e+00 2.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00] [ 6.28525944e-01 4.70000000e+01 0.00000000e+00 1.13100000e+03 0.00000000e+00 5.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00 2.00000000e+00] [ 4.02283095e-01 6.00000000e+01 0.00000000e+00 3.79442065e-01 8.63800000e+03 1.00000000e+01 0.00000000e+00 1.00000000e+00 0.00000000e+00 0.00000000e+00] [ 5.70997900e-03 8.10000000e+01 0.00000000e+00 2.17382000e-04 2.30000000e+04 4.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00] [ 4.71171849e-01 5.10000000e+01 0.00000000e+00 1.53700000e+03 0.00000000e+00 1.40000000e+01 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00] [ 1.42395210e-02 8.20000000e+01 0.00000000e+00 7.40466500e-03 2.70000000e+03 1.00000000e+01 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00] [ 4.67455800e-02 3.70000000e+01 0.00000000e+00 1.48010090e-02 9.12000000e+03 8.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00 4.00000000e+00] [ 9.99999900e-01 4.70000000e+01 0.00000000e+00 3.54604127e-01 1.10000000e+04 1.10000000e+01 0.00000000e+00 2.00000000e+00 0.00000000e+00 3.00000000e+00] [ 8.96417860e-02 2.70000000e+01 0.00000000e+00 8.14664000e-03 5.40000000e+03 6.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00]] print y_test[:20] [[ 1. 0.] [ 0. 1.] [ 1. 0.] [ 1. 0.] [ 1. 0.] [ 1. 0.] [ 1. 0.] [ 0. 1.] [ 1. 0.] [ 1. 0.] [ 1. 0.] [ 1. 0.] [ 1. 0.] [ 1. 0.] [ 1. 0.] [ 1. 0.] [ 1. 0.] [ 1. 0.] [ 1. 0.] [ 1. 0.]]
dga.. 5
tl; dr:上面发布的示例代码计算交叉熵的方式在数值上并不健全.请tf.nn.cross_entropy_with_logits
改用.
(响应问题的v1,已经改变):我担心你的训练实际上并没有完成或工作,基于nan
你展示的x_train数据中的s.我建议先修复它 - 并确定它们出现并修复该错误的原因,并查看nan
您的测试集中是否还有s.可能也有助于显示x_test和y_test.
最后,我相信y_
有关x的处理方式存在错误.代码被编写为好像y_
是一个热门矩阵,但是当你展示时y_train[:10]
,它只有10个元素,而不是10*num_classes
类别.我怀疑那里有一个bug.当你在轴1上进行argmax时,你总是会得到一个充满零的向量(因为那个轴上只有一个元素,所以当然它是最大元素).将它与在估计值上产生始终为零的输出的错误相结合,并且您总是产生"正确"的答案.:)
修订版本的更新 在更改后的版本中,如果运行它并在每次执行结束时打印出W,请将代码更改为如下所示:
_, w_out, b_out = sess.run([train_step, W, b], feed_dict={x: [x_train[g]], y_: [y_train[g]]})
你会发现W充满了nan
s.要对此进行调试,您可以盯着您的代码查看是否存在可以发现的数学问题,或者您可以通过管道检测它们以查看它们出现的位置.我们试试吧.首先,是什么cross_entropy
?(添加cross_entropy
到run
语句中的事物列表并打印出来)
Cross entropy: inf
大!所以为什么?嗯,一个答案就是:
y = [0, 1] tf.log(y) = [-inf, 0]
这是y的有效可能输出,但是你的交叉熵计算不稳定.你可以手动添加一些epsilons以避免角落情况,或者用来tf.nn.softmax_cross_entropy_with_logits
为你做.我推荐后者:
yprime = tf.matmul(x,W)+b y = tf.nn.softmax(yprime) cross_entropy = tf.nn.softmax_cross_entropy_with_logits(yprime, y_)
我不保证您的模型可以正常工作,但这应该可以解决您当前的NaN问题.
tl; dr:上面发布的示例代码计算交叉熵的方式在数值上并不健全.请tf.nn.cross_entropy_with_logits
改用.
(响应问题的v1,已经改变):我担心你的训练实际上并没有完成或工作,基于nan
你展示的x_train数据中的s.我建议先修复它 - 并确定它们出现并修复该错误的原因,并查看nan
您的测试集中是否还有s.可能也有助于显示x_test和y_test.
最后,我相信y_
有关x的处理方式存在错误.代码被编写为好像y_
是一个热门矩阵,但是当你展示时y_train[:10]
,它只有10个元素,而不是10*num_classes
类别.我怀疑那里有一个bug.当你在轴1上进行argmax时,你总是会得到一个充满零的向量(因为那个轴上只有一个元素,所以当然它是最大元素).将它与在估计值上产生始终为零的输出的错误相结合,并且您总是产生"正确"的答案.:)
修订版本的更新 在更改后的版本中,如果运行它并在每次执行结束时打印出W,请将代码更改为如下所示:
_, w_out, b_out = sess.run([train_step, W, b], feed_dict={x: [x_train[g]], y_: [y_train[g]]})
你会发现W充满了nan
s.要对此进行调试,您可以盯着您的代码查看是否存在可以发现的数学问题,或者您可以通过管道检测它们以查看它们出现的位置.我们试试吧.首先,是什么cross_entropy
?(添加cross_entropy
到run
语句中的事物列表并打印出来)
Cross entropy: inf
大!所以为什么?嗯,一个答案就是:
y = [0, 1] tf.log(y) = [-inf, 0]
这是y的有效可能输出,但是你的交叉熵计算不稳定.你可以手动添加一些epsilons以避免角落情况,或者用来tf.nn.softmax_cross_entropy_with_logits
为你做.我推荐后者:
yprime = tf.matmul(x,W)+b y = tf.nn.softmax(yprime) cross_entropy = tf.nn.softmax_cross_entropy_with_logits(yprime, y_)
我不保证您的模型可以正常工作,但这应该可以解决您当前的NaN问题.