Gradient descent in Java

I've recently started the AI-Class on Coursera, and I have a question about my implementation of the gradient descent algorithm.

Here's my current implementation (I actually just "translated" the mathematical expressions into Java code):

```java
public class GradientDescent {

    private static final double TOLERANCE = 1E-11;

    private double theta0;
    private double theta1;

    public double getTheta0() {
        return theta0;
    }

    public double getTheta1() {
        return theta1;
    }

    public GradientDescent(double theta0, double theta1) {
        this.theta0 = theta0;
        this.theta1 = theta1;
    }

    public double getHypothesisResult(double x){
        return theta0 + theta1*x;
    }

    private double getResult(double[][] trainingData, boolean enableFactor){
        double result = 0;
        for (int i = 0; i < trainingData.length; i++) {
            result = (getHypothesisResult(trainingData[i][0]) - trainingData[i][1]);
            if (enableFactor) result = result*trainingData[i][0];
        }
        return result;
    }

    public void train(double learningRate, double[][] trainingData){
        int iteration = 0;
        double delta0, delta1;
        do{
            iteration++;
            System.out.println("SUBS: " + (learningRate*((double) 1/trainingData.length))*getResult(trainingData, false));
            double temp0 = theta0 - learningRate*(((double) 1/trainingData.length)*getResult(trainingData, false));
            double temp1 = theta1 - learningRate*(((double) 1/trainingData.length)*getResult(trainingData, true));
            delta0 = theta0-temp0; delta1 = theta1-temp1;
            theta0 = temp0; theta1 = temp1;
        }while((Math.abs(delta0) + Math.abs(delta1)) > TOLERANCE);
        System.out.println(iteration);
    }
}
```

The code works quite well, but only if I choose a very small alpha (here called `learningRate`). If it's higher than 0.00001, it diverges.

Do you have any suggestions on how to optimize the implementation, or an explanation for the "Alpha-Issue" and a possible solution for it?

Update:

Here's the `main` method, including some sample inputs:

```java
private static final double[][] TDATA = {{200, 20000},{300, 41000},{900, 141000},{800, 41000},{400, 51000},{500, 61500}};

public static void main(String[] args) {
    GradientDescent gd = new GradientDescent(0,0);
    gd.train(0.00001, TDATA);
    System.out.println("THETA0: " + gd.getTheta0() + " - THETA1: " + gd.getTheta1());
    System.out.println("PREDICTION: " + gd.getHypothesisResult(300));
}
```

The mathematical expression of gradient descent is as follows:

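$$\theta_0 := \theta_0 - \alpha \frac{1}{m} \sum_{i=1}^{m} \left( h_\theta(x^{(i)}) - y^{(i)} \right)$$

$$\theta_1 := \theta_1 - \alpha \frac{1}{m} \sum_{i=1}^{m} \left( h_\theta(x^{(i)}) - y^{(i)} \right) x^{(i)}$$

where $h_\theta(x) = \theta_0 + \theta_1 x$, $m$ is the number of training examples, and both parameters are updated simultaneously.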
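
To illustrate where the learning-rate limit comes from (a sketch, not part of the original post): the batch update that `train` performs minimizes the quadratic cost J(θ) = (1/2m) Σ (h(x) - y)², whose Hessian is H = (1/m) Σ [[1, x], [x, x²]]. Fixed-step gradient descent on a quadratic converges only when alpha < 2/λ_max(H). The hypothetical class `StepSizeBound` and its method `maxStableLearningRate` compute that bound for the x values in `TDATA`, assuming the gradient is summed over all examples (the posted `getResult` only keeps the last example, which yields a different bound).

```java
public class StepSizeBound {

    // For J(theta) = (1/2m) * sum((theta0 + theta1*x - y)^2) the Hessian is
    // H = (1/m) * sum over examples of [[1, x], [x, x*x]]. Fixed-step gradient
    // descent on this quadratic converges only when alpha < 2 / lambdaMax(H).
    public static double maxStableLearningRate(double[] xs) {
        int m = xs.length;
        double a = 1.0; // H[0][0] = (1/m) * sum(1)
        double b = 0.0; // H[0][1] = H[1][0] = (1/m) * sum(x)
        double d = 0.0; // H[1][1] = (1/m) * sum(x^2)
        for (double x : xs) {
            b += x / m;
            d += x * x / m;
        }
        // Largest eigenvalue of the symmetric 2x2 matrix [[a, b], [b, d]].
        double center = (a + d) / 2.0;
        double halfDiff = (a - d) / 2.0;
        double lambdaMax = center + Math.sqrt(halfDiff * halfDiff + b * b);
        return 2.0 / lambdaMax;
    }

    public static void main(String[] args) {
        // x values (first column) of TDATA from the question.
        double[] xs = {200, 300, 900, 800, 400, 500};
        System.out.println("alpha must stay below about " + maxStableLearningRate(xs));
    }
}
```

For the sample inputs the bound comes out at roughly 6.0e-6, driven almost entirely by the mean of x², which is about 3.3e5. If x is centered and scaled as suggested in the first answer, H becomes diagonal with entries 1 and mean(x'²) < 1, so the bound rises to 2.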

asked Aug 23 '15 at 18:08 by Bastian


2 Answers

To solve this issue, normalize the data with the formula (Xi - mu)/s, where Xi is the current training set value, mu is the average of the values in the current column, and s is the maximum value minus the minimum value of that column. This puts the training data roughly into the range between -1 and 1, which allows higher learning rates and makes gradient descent converge faster. Afterwards, however, the predicted result has to be denormalized.
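
A minimal sketch of this approach, assuming both columns (x and y) are normalized and the gradient is summed over all examples. The class `NormalizedGradientDescent`, the `Scale` helper, the learning rate of 0.5 and the fixed iteration count are illustrative choices, not part of the original answer.

```java
public class NormalizedGradientDescent {

    // Scaling parameters for one column: mu = mean, s = max - min.
    public static final class Scale {
        final double mu;
        final double s;

        public Scale(double[] values) {
            double sum = 0, min = Double.POSITIVE_INFINITY, max = Double.NEGATIVE_INFINITY;
            for (double v : values) {
                sum += v;
                min = Math.min(min, v);
                max = Math.max(max, v);
            }
            this.mu = sum / values.length;
            this.s = max - min; // assumes the column is not constant
        }

        public double normalize(double v)   { return (v - mu) / s; }
        public double denormalize(double v) { return v * s + mu; }
    }

    // Batch gradient descent on already-normalized data; returns {theta0, theta1}.
    public static double[] train(double[] xs, double[] ys, double alpha, int iterations) {
        double theta0 = 0, theta1 = 0;
        int m = xs.length;
        for (int it = 0; it < iterations; it++) {
            double grad0 = 0, grad1 = 0;
            for (int i = 0; i < m; i++) {
                double error = theta0 + theta1 * xs[i] - ys[i];
                grad0 += error;          // accumulate over all examples
                grad1 += error * xs[i];
            }
            // Both gradients were computed from the old thetas (simultaneous update).
            theta0 -= alpha * grad0 / m;
            theta1 -= alpha * grad1 / m;
        }
        return new double[] {theta0, theta1};
    }

    public static void main(String[] args) {
        double[][] data = {{200, 20000}, {300, 41000}, {900, 141000},
                           {800, 41000}, {400, 51000}, {500, 61500}};
        double[] rawX = new double[data.length], rawY = new double[data.length];
        for (int i = 0; i < data.length; i++) {
            rawX[i] = data[i][0];
            rawY[i] = data[i][1];
        }
        Scale sx = new Scale(rawX), sy = new Scale(rawY);
        double[] xs = new double[data.length], ys = new double[data.length];
        for (int i = 0; i < data.length; i++) {
            xs[i] = sx.normalize(rawX[i]);
            ys[i] = sy.normalize(rawY[i]);
        }

        // alpha = 0.5 is an illustrative value that converges on this scaled data.
        double[] theta = train(xs, ys, 0.5, 10_000);

        // Normalize the new input with the training statistics, predict,
        // then map the prediction back to the original y scale.
        double xNew = sx.normalize(300);
        double prediction = sy.denormalize(theta[0] + theta[1] * xNew);
        System.out.println("PREDICTION: " + prediction);
    }
}
```

Because y is scaled too, the model's raw output is in normalized units, and `denormalize` maps it back with y = y' * s + mu. A new input must be normalized with the training set's mu and s rather than with statistics computed from the new value.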

answered Oct 14 '22 at 12:10 by Bastian


```java
private double getResult(double[][] trainingData, boolean enableFactor){
    double result = 0;
    for (int i = 0; i < trainingData.length; i++) {
        result = (getHypothesisResult(trainingData[i][0]) - trainingData[i][1]);
        if (enableFactor) result = result*trainingData[i][0];
    }
    return result;
}
```

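In this function, the `result` variable is overwritten on every iteration, so the previous value is lost. As a consequence, only the last item in the array is used in the calculation; the rest of the training data has no effect. The loop needs to accumulate the values (`result += ...`) instead of assigning them.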
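
A minimal sketch of the fix, assuming the rest of the question's class stays as posted: each example's error (or error times x) is added to `result` instead of replacing it. The class name `FixedGradientSum` is hypothetical; only `getResult` and `getHypothesisResult` come from the question.

```java
public class FixedGradientSum {

    private final double theta0;
    private final double theta1;

    public FixedGradientSum(double theta0, double theta1) {
        this.theta0 = theta0;
        this.theta1 = theta1;
    }

    public double getHypothesisResult(double x) {
        return theta0 + theta1 * x;
    }

    // Sum of the errors (or of error * x when enableFactor is true)
    // over ALL training examples, as the gradient formula requires.
    public double getResult(double[][] trainingData, boolean enableFactor) {
        double result = 0;
        for (double[] example : trainingData) {
            double error = getHypothesisResult(example[0]) - example[1];
            result += enableFactor ? error * example[0] : error; // accumulate, don't overwrite
        }
        return result;
    }
}
```

Once every example contributes, the usable learning rate still depends on the scale of x, which is where the normalization from the other answer helps.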

answered Oct 14 '22 at 12:10 by Semih SÜZEN