当前位置:  开发笔记 > 编程语言 > 正文

TensorFlow:关于tf.argmax()和tf.equal()的问题

如何解决《TensorFlow:关于tf.argmax()和tf.equal()的问题》经验,为你挑选了2个好方法。

我正在学习TensorFlow,构建一个多层感知器模型.我正在研究一些例子:https://github.com/aymericdamien/TensorFlow-Examples/blob/master/notebooks/3_NeuralNetworks/multilayer_perceptron.ipynb

然后,我在下面的代码中有一些问题:

def multilayer_perceptron(x, weights, biases):
    :
    :

pred = multilayer_perceptron(x, weights, biases)
    :
    :

with tf.Session() as sess:
    sess.run(init)
         :
    correct_prediction = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1))

    accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))
    print ("Accuracy:", accuracy.eval({x: X_test, y: y_test_onehot}))

我想知道做什么tf.argmax(prod,1)tf.argmax(y,1)意味着什么,返回(类型和价值)究竟是什么?并且是correct_prediction变量而不是实际值?

最后,我们如何从tf会话中获取y_test_prediction数组(输入数据时的预测结果X_test)?非常感谢!



1> 小智..:
tf.argmax(input, axis=None, name=None, dimension=None)

返回张量轴上具有最大值的索引.

输入是Tensor,轴描述输入Tensor的哪个轴要减少.对于矢量,使用axis = 0.

对于您的具体情况,让我们使用两个数组并演示这一点

pred = np.array([[31, 23,  4, 24, 27, 34],
                [18,  3, 25,  0,  6, 35],
                [28, 14, 33, 22, 20,  8],
                [13, 30, 21, 19,  7,  9],
                [16,  1, 26, 32,  2, 29],
                [17, 12,  5, 11, 10, 15]])

y = np.array([[31, 23,  4, 24, 27, 34],
                [18,  3, 25,  0,  6, 35],
                [28, 14, 33, 22, 20,  8],
                [13, 30, 21, 19,  7,  9],
                [16,  1, 26, 32,  2, 29],
                [17, 12,  5, 11, 10, 15]])

评估tf.argmax(pred, 1)给出了评估会给出的张量array([5, 5, 2, 1, 3, 0])

评估tf.argmax(y, 1)给出了评估会给出的张量array([5, 5, 2, 1, 3, 0])

tf.equal(x, y, name=None) takes two tensors(x and y) as inputs and returns the truth value of (x == y) element-wise. 

按照我们的例子,tf.equal(tf.argmax(pred, 1),tf.argmax(y, 1))返回一个评估将给出的张量array(1,1,1,1,1,1).

correct_prediction是一个张量,其评估将给出0和1的1-D数组

y_test_prediction可以通过执行获得 pred = tf.argmax(logits, 1)

可以通过以下链接访问tf.argmax和tf.equal的文档.

tf.argmax()https://www.tensorflow.org/api_docs/python/math_ops/sequence_comparison_and_indexing#argmax

tf.equal()https://www.tensorflow.org/versions/master/api_docs/python/control_flow_ops/comparison_operators#equal



2> Salvador Dal..:

阅读文档:

tf.argmax

返回张量轴上具有最大值的索引.

tf.equal

以元素方式返回(x == y)的真值.

tf.cast

将张量转换为新类型.

tf.reduce_mean

计算张量维度的元素平均值.


现在您可以轻松解释它的作用.你y是一个热门编码,所以它有一个1,所有其他都是零.你pred代表了阶级的概率.因此argmax找到最佳预测和正确值的位置.之后,检查它们是否相同.

所以现在你correct_prediction是一个True/False值的向量,其大小等于你想要预测的实例数.您将其转换为浮点数并取平均值.


实际上,在评估模型部分的TF教程中很好地解释了这一部分

推荐阅读
雯颜哥_135
这个屌丝很懒,什么也没留下!
DevBox开发工具箱 | 专业的在线开发工具网站    京公网安备 11010802040832号  |  京ICP备19059560号-6
Copyright © 1998 - 2020 DevBox.CN. All Rights Reserved devBox.cn 开发工具箱 版权所有