当前位置:  开发笔记 > 编程语言 > 正文

线性回归的梯度下降不起作用

如何解决《线性回归的梯度下降不起作用》经验,为你挑选了1个好方法。

我正在尝试为线性回归实现一个简单的梯度下降算法.我正在使用Armadillo C++线性代数库,而且我也是Armadillo的新手.这就是我想要做的:

void linRegression(mat &features, mat &targets, double alpha,double error){
    mat theta = ones(features.n_cols+1);
    mat temp = zeros(features.n_cols+1);
    mat features_new = join_horiz(ones(features.n_rows),features);
    mat predictions;
    double con = alpha*(1.0/features.n_rows);
    int j = 0;
    while(j<1000){
        mat step_error = (features_new*theta - targets);
        for(unsigned int i=0;i

但theta的值不断增加并最终达到无穷大.我不确定我做错了什么.



1> Anton..:

我认为while循环中的计算不正确.至少你可以在没有for-loop的情况下做得更优雅.以下是1个功能问题的简短代码:

#include 
#include 

using namespace std;
using namespace arma;

int main(int argc, char** argv)
{
    mat features(10, 1);

    features << 6.110100 << endr
         << 5.527700 << endr
         << 8.518600 << endr
         << 7.003200 << endr
         << 5.859800 << endr
         << 8.382900 << endr
         << 7.476400 << endr
         << 8.578100 << endr
         << 6.486200 << endr
         << 5.054600 << endr;

    mat targets(10, 1);

    targets << 17.59200 << endr
        << 9.130200 << endr
        << 13.66200 << endr
        << 11.85400 << endr
        << 6.823300 << endr
        << 11.88600 << endr
        << 4.348300 << endr
        << 12.00000 << endr
        << 6.598700 << endr
        << 3.816600 << endr;

    mat theta = ones(features.n_cols + 1);

    mat features_new = join_horiz(ones(features.n_rows), features);

    double alpha = 0.01;
    double con = alpha*(1.0 / features.n_rows);

    int j = 0;

    while (j < 20000){
        mat step_error = (features_new*theta - targets);
        theta = theta - con * (features_new.t() * step_error);
        j++;
    }

    theta.print("theta:");

    system("pause");

    return 0;
}

该计划返回:

theta:
   0.5083
   1.3425

通过正规方程方法得到的结果是:

theta:
   0.5071
   1.3427

编辑

你的代码确实是正确的!问题可能出在功能标准化中.我将我的示例扩展为2个特征回归并添加规范化.如果没有标准化,它对我也不起作用.

#include 
#include 

using namespace std;
using namespace arma;

int main(int argc, char** argv)
{

    mat features(10, 2);

    features << 2104 << 3 << endr
         << 1600 << 3 << endr
         << 2400 << 3 << endr
         << 1416 << 2 << endr
         << 3000 << 4 << endr
         << 1985 << 4 << endr
         << 1534 << 3 << endr
         << 1427 << 3 << endr
         << 1380 << 3 << endr
         << 1494 << 3 << endr;

    mat m = mean(features, 0);
    mat s = stddev(features, 0, 0);

    int i,  j;

    //normalization
    for (i = 0; i < features.n_rows; i++)
    {
        for (j = 0; j < features.n_cols; j++)
        {
            features(i, j) = (features(i, j) - m(j))/s(j);
        }
    }

    mat targets(10, 1);

    targets << 399900 << endr
        << 329900 << endr
        << 369000 << endr
        << 232000 << endr
        << 539900 << endr
        << 299900 << endr
        << 314900 << endr
        << 198999 << endr
        << 212000 << endr
        << 242500 << endr;


    mat theta = ones(features.n_cols + 1);

    mat features_new = join_horiz(ones(features.n_rows), features);

    double alpha = 0.01;
    double con = alpha*(1.0 / features.n_rows);

    while (j < 20000){
        mat step_error = (features_new*theta - targets);
        theta = theta - con * (features_new.t() * step_error);
        j++;
    }

    cout << theta << endl;

    system("pause");

    return 0;
}

结果:

THETA:

  3.1390e+005
  9.9704e+004
 -5.6835e+003

推荐阅读
郑小蒜9299_941611_G
这个屌丝很懒,什么也没留下!
DevBox开发工具箱 | 专业的在线开发工具网站    京公网安备 11010802040832号  |  京ICP备19059560号-6
Copyright © 1998 - 2020 DevBox.CN. All Rights Reserved devBox.cn 开发工具箱 版权所有