Stochastic gradient descent implementation - MATLAB

I'm trying to implement stochastic gradient descent in MATLAB. I followed the algorithm exactly, but I'm getting very large values of w (the coefficients) for the prediction/fitting function. Is there a mistake in my algorithm?

The algorithm: (image not available)
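The per-sample update that the loop below implements can be written as follows, with feature vector φ(x) = [1, x, x²] (the code's `v`), training pair (x_i, t_i) (the code's `h(i)`, `k(i)`) and learning rate α (the code's `a`). It is the gradient step for the per-sample squared error ½(wφ(x_i)ᵀ − t_i)²:

$$
\phi(x) = \begin{bmatrix} 1 & x & x^2 \end{bmatrix},\qquad
w \leftarrow w - \alpha\,\bigl(w\,\phi(x_i)^{\top} - t_i\bigr)\,\phi(x_i)
$$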

    x = 0:0.1:2*pi;                   % X-axis
    n = size(x,2);
    r = -0.2+(0.4).*rand(n,1);        % random noise in [-0.2, 0.2] added to sin(x)

    t = zeros(1,n);
    y = zeros(1,n);

    for i=1:n
        t(i) = sin(x(i))+r(i);        % adding the noise
        y(i) = sin(x(i));             % the function without noise
    end

    f = randi(n,20,1);                % 20 random indexes in 1..n
                                      % (round(1+rand(20,1)*n) can return n+1)

    h = x(f);                         % choosing random x points
    k = t(f);                         % choosing the matching noisy y points

    m = size(h,2);                    % length of the h vector

    scatter(h,k,'Red');               % drawing the training points (with noise)
    %scatter(x,t,2);
    hold on;
    plot(x,sin(x));                   % plotting the sin function

    w = [0.3 1 0.5];                  % starting point of w
    a = 0.05;                         % learning rate "alpha"

    % ---------------- ALGORITHM ---------------------
    for i=1:m
        v = [1 h(i) h(i).^2];         % feature vector [1 x x^2]
        e = ((w*v') - k(i)).*v;       % (prediction - observation) * v
        w = w - a*e;                  % updating w
    end

    hold on;

    l = 0:1:6;
    g = w(1)+w(2)*l+w(3)*(l.^2);
    plot(l,g,'Yellow');               % drawing the prediction function
asked Feb 25 '11 at 12:02 by Morano88
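One way to see why w blows up with this setup is to check what a single update does to the residual r = w*v' - t of the sample it uses: after w <- w - a*r*v, the residual becomes r*(1 - a*(v*v')), so the step only shrinks it when 0 < a*(v*v') < 2. The sketch below (assuming the same x grid, features [1 x x^2] and a = 0.05 as the code above; the training points are drawn from that grid) counts the points where one step overshoots:

```matlab
% Sketch: per-sample stability of the update w <- w - alpha*(w*v' - t)*v.
% One step multiplies that sample's residual by (1 - alpha*||v||^2), so the
% step is stable for that sample only if 0 < alpha*||v||^2 < 2.
alpha = 0.05;                             % learning rate from the question
x = 0:0.1:2*pi;                           % same grid as the question
V = [ones(numel(x),1), x(:), x(:).^2];    % one feature row [1 x x^2] per point
sqnorm = sum(V.^2, 2);                    % ||v||^2 for every point
resFactor = 1 - alpha*sqnorm;             % residual multiplier after one step
maxSafeAlpha = 2 / max(sqnorm);           % largest alpha stable for every point
unstable = sum(abs(resFactor) >= 1);      % points where a single step overshoots
fprintf('max ||v||^2 = %.1f, stable alpha < %.5f, unstable points: %d of %d\n', ...
        max(sqnorm), maxSafeAlpha, unstable, numel(x));
```

Because the x² feature reaches about 38 at x = 6.2, ||v||^2 is dominated by x⁴, and any point with x above roughly 2.4 makes a = 0.05 overshoot.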


2 Answers

If you use too large a learning rate, SGD is likely to diverge.
The learning rate should converge to zero.

answered Sep 25 '22 at 07:09 by Łukasz Lew
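A minimal sketch of this suggestion, assuming the question's features, noise and starting point: the SGD loop runs for several epochs in random order with a learning rate alpha = alpha0/(1 + step/tau) that starts below the stability bound 2/max(||v||^2) and decays towards zero. The names `alpha0`, `tau` and `nEpochs` and their values are illustrative assumptions, not tuned settings.

```matlab
% Sketch: SGD with a learning rate that starts small and decays to zero.
% alpha0, tau and nEpochs are illustrative values (assumptions).
x = 0:0.1:2*pi;
t = sin(x) + (-0.2 + 0.4*rand(size(x)));   % noisy targets, as in the question
V = [ones(numel(x),1), x(:), x(:).^2];     % feature rows [1 x x^2]

w = [0.3 1 0.5];                           % same starting point as the question
alpha0 = 1e-3;                             % below 2/max(||v||^2), about 1.3e-3
tau = 500;                                 % decay time constant, in steps
nEpochs = 200;
step = 0;
mse = @(w) mean((V*w' - t(:)).^2);         % mean squared error on all points
mseStart = mse(w);

for epoch = 1:nEpochs
    for i = randperm(numel(x))             % visit the samples in random order
        step = step + 1;
        alpha = alpha0 / (1 + step/tau);   % decaying learning rate
        v = V(i,:);
        e = (w*v' - t(i)) * v;             % gradient of 0.5*(w*v' - t(i))^2
        w = w - alpha*e;
    end
end

fprintf('MSE: %.4f -> %.4f, w = [%g %g %g]\n', mseStart, mse(w), w);
```

With unscaled features the x² direction converges much faster than the intercept, so many epochs are needed; rescaling x (for example to [0, 1]) is a separate technique that would allow a larger alpha0.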


Typically, if w ends up with very large values, the model is overfitting. I didn't look at your code carefully, but I think what is missing is a proper regularization term, which prevents the training from overfitting. Also, here:

e = ((w*v') - k(i)).*v;

The v here is not the gradient of the predicted value, is it? According to the algorithm, you should replace it with that gradient. Let's see what happens after doing this.

answered Sep 25 '22 at 07:09 by Hotloo Xiranood
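Both points in this answer can be checked numerically. The sketch below compares the analytic per-sample gradient with a central finite difference, assuming the loss 0.5*(w*v' - t)^2 + 0.5*lambda*(w*w'): lambda = 0 is the squared error the question's update corresponds to, and lambda > 0 adds an L2 (ridge) penalty, one common form of the regularization mentioned above. The sample x = 2.5, t = 0.6 and the step `delta` are hypothetical values chosen for illustration.

```matlab
% Sketch: finite-difference check of the per-sample gradient, assuming
%   L(w) = 0.5*(w*v' - t)^2 + 0.5*lambda*(w*w')
% lambda = 0: plain squared error; lambda > 0: L2 (ridge) regularization.
v = [1 2.5 2.5^2];          % feature row [1 x x^2] for a hypothetical x = 2.5
t = 0.6;                    % hypothetical target value
w = [0.3 1 0.5];            % starting point from the question
delta = 1e-6;               % finite-difference step
lambdas = [0 0.1];
maxDiff = zeros(size(lambdas));

for k = 1:numel(lambdas)
    lambda = lambdas(k);
    L = @(w) 0.5*(w*v' - t)^2 + 0.5*lambda*(w*w');
    grad = (w*v' - t)*v + lambda*w;        % analytic gradient
    numGrad = zeros(size(w));
    for j = 1:numel(w)
        d = zeros(size(w));
        d(j) = delta;
        numGrad(j) = (L(w + d) - L(w - d)) / (2*delta);   % central difference
    end
    maxDiff(k) = max(abs(grad - numGrad));
end

disp(maxDiff)               % largest |analytic - numerical| for each lambda
```

With lambda > 0 the SGD update becomes w = w - a*((w*v' - k(i))*v + lambda*w), which shrinks w towards zero at every step.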