当前位置:  开发笔记 > 人工智能 > 正文

TensorFlow中的基本神经网络

如何解决《TensorFlow中的基本神经网络》经验,为你挑选了1个好方法。

我试图在TensorFlow中实现一个非常基本的神经网络,但我遇到了一些问题.这是一个非常基本的网络,作为值(小时或睡眠和学习时间)的输入,并预测测试的分数(我在你的管上找到了这个例子).所以基本上我只有一个隐藏层有三个单元,每个单元计算一个激活函数(sigmoid),成本函数是平方误差的总和,我使用梯度下降来最小化它.所以问题是,当我使用训练数据训练网并尝试使用相同的训练数据进行一些预测时,结果不完全匹配,并且它们看起来也很奇怪,因为看起来彼此相等.

import tensorflow as tf
import numpy as np
import input_data

sess = tf.InteractiveSession()

# create a 2-D version of input for plotting
trX = np.matrix(([3,5], [5,1],[10,2]), dtype=float)
trY = np.matrix(([85], [82], [93]), dtype=float) # 3X1 matrix
trX = trX / np.max(trX, axis=0)
trY = trY / 100 # 100 is the maximum score allowed

teX = np.matrix(([3,5]), dtype=float)
teY = np.matrix(([85]), dtype=float)
teX = teX/np.amax(teX, axis=0)
teY = teY/100

def init_weights(shape):
    return tf.Variable(tf.random_normal(shape, stddev=0.01))

def model(X, w_h, w_o):
    z2 = tf.matmul(X, w_h)
    a2 = tf.nn.sigmoid(z2) # this is a basic mlp, think 2 stacked logistic regressions
    z3 = tf.matmul(a2, w_o)
    yHat = tf.nn.sigmoid(z3)
    return yHat # note that we dont take the softmax at the end because our cost fn does that for us

X = tf.placeholder("float", [None, 2])
Y = tf.placeholder("float", [None, 1])

W1 = init_weights([2, 3]) # create symbolic variables
W2 = init_weights([3, 1])

sess.run(tf.initialize_all_variables())

py_x = model(X, W1, W2)

cost = tf.reduce_mean(tf.square(py_x - Y))
train_op = tf.train.GradientDescentOptimizer(0.5).minimize(cost) # construct an optimizer
predict_op = py_x

sess.run(train_op, feed_dict={X: trX, Y: trY})

print sess.run(predict_op, feed_dict={X: trX})

sess.close()

它产生:

[[0.51873487] [0.51874501] [0.51873082]]

我相信它应该与训练数据结果类似.

我对神经网络和机器学习都很陌生,所以请原谅我任何错误,提前谢谢.



1> mrry..:

您的网络未接受培训的主要原因是声明:

sess.run(train_op, feed_dict={X: trX, Y: trY})

......只执行一次.在TensorFlow中,运行train_op(或从其返回的任何操作Optimizer.minimize()只会导致网络采用单个梯度下降步骤.您应该在循环中执行它以执行迭代训练,并且权重最终会收敛.

另外两个提示:(i)如果您在每个步骤中提供训练数据的子集而不是整个数据集,则可以实现更快的收敛; (ii)学习率0.5可能太高(虽然这取决于数据).

推荐阅读
linjiabin43
这个屌丝很懒,什么也没留下!
DevBox开发工具箱 | 专业的在线开发工具网站    京公网安备 11010802040832号  |  京ICP备19059560号-6
Copyright © 1998 - 2020 DevBox.CN. All Rights Reserved devBox.cn 开发工具箱 版权所有