Simple gradient boosting algorithm

I'm trying to implement a simple gradient boosting algorithm for regression in R. This is what I've come up with so far, but the error is not plateauing as I had expected. Any suggestions?

data("OrchardSprays")
niter  <- 10
learn  <- 0.05
y      <- OrchardSprays$decrease
yhat   <- rep(0,nrow(OrchardSprays))
weight <- rep(1,nrow(OrchardSprays))
loss   <- function(y,yhat) (y - yhat)^2

for (i in seq(niter))
{
    model  <- lm(decrease~.,weights=weight,data=OrchardSprays)
    yhat   <- yhat + weight * (predict(model) - yhat) / i
    error  <- mean(loss(y,yhat))
    weight <- weight + learn * (loss(y,yhat) - error) / error
    cat(i,"error:",error,"\n")
}

output:

1 error: 319.5881 
2 error: 318.6175 
3 error: 317.9368 
4 error: 317.6112 
5 error: 317.6369 
6 error: 317.9772 
7 error: 318.5833 
8 error: 319.4047 
9 error: 320.3939 
10 error: 321.5086 
asked Sep 12 '11 by darckeen

2 Answers

Not sure if this helps, but if you lower the starting weights and increase the number of iterations, the error gets much closer to zero. It still doesn't plateau, though: the error starts climbing again at iteration 103. I also added the statement weight <- ifelse(weight < 0.0, 0.0, weight) to clamp negative weights, which otherwise trigger the "missing or negative weights not allowed" error from lm.

data("OrchardSprays")
niter  <- 105
learn  <- 0.05
y      <- OrchardSprays$decrease
yhat   <- rep(0.0,nrow(OrchardSprays))
weight <- rep(0.2,nrow(OrchardSprays))
loss   <- function(y,yhat) (y - yhat)^2

error  <- mean(loss(y,yhat))
cat("initial error:",error,"\n")

for (i in seq(niter))
{   
    model  <- lm(decrease~.,weights=weight,data=OrchardSprays)
    yhat   <- yhat + weight * (predict(model) - yhat) / i
    error  <- mean(loss(y,yhat))
    weight <- weight + learn * (loss(y,yhat) - error) / error
    weight <- ifelse(weight < 0.0, 0.0, weight)
    cat(i,"error:",error,"\n")
}

output:

initial error: 3308.922

1 error: 2232.762 
2 error: 1707.971 
3 error: 1360.834 
4 error: 1110.503 
5 error: 921.2804 
6 error: 776.4314 
7 error: 663.5947 
8 error: 574.2603 
9 error: 502.2455 
10 error: 443.2639 
11 error: 394.2983 
12 error: 353.1736 
13 error: 318.2869 
14 error: 288.4326 
15 error: 262.6827 
16 error: 240.3086 
17 error: 220.7289 
18 error: 203.4741 
19 error: 188.1632 
20 error: 174.4876 
21 error: 162.1971 
22 error: 151.0889 
23 error: 140.9982 
24 error: 131.7907 
25 error: 123.3567 
26 error: 115.6054 
27 error: 108.4606 
28 error: 101.8571 
29 error: 95.73825 
30 error: 90.05343 
31 error: 84.75755 
32 error: 79.81715 
33 error: 75.19618 
34 error: 70.86006 
35 error: 66.77859 
36 error: 62.92584 
37 error: 59.28014 
38 error: 55.8239 
39 error: 52.54784 
40 error: 49.44272 
41 error: 46.49915 
42 error: 43.71022 
43 error: 41.07119 
44 error: 38.57908 
45 error: 36.23237 
46 error: 34.03907 
47 error: 32.00558 
48 error: 30.12923 
49 error: 28.39891 
50 error: 26.80582 
51 error: 25.33449 
52 error: 23.97077 
53 error: 22.70327 
54 error: 21.52714 
55 error: 20.43589 
56 error: 19.42552 
57 error: 18.48629 
58 error: 17.60916 
59 error: 16.78986 
60 error: 16.02315 
61 error: 15.30303 
62 error: 14.62663 
63 error: 13.99066 
64 error: 13.39205 
65 error: 12.82941 
66 error: 12.30349 
67 error: 11.811 
68 error: 11.34883 
69 error: 10.91418 
70 error: 10.50448 
71 error: 10.11723 
72 error: 9.751116 
73 error: 9.405197 
74 error: 9.076175 
75 error: 8.761231 
76 error: 8.458107 
77 error: 8.165144 
78 error: 7.884295 
79 error: 7.615498 
80 error: 7.356618 
81 error: 7.106186 
82 error: 6.86324 
83 error: 6.627176 
84 error: 6.39777 
85 error: 6.17544 
86 error: 5.961616 
87 error: 5.756781 
88 error: 5.561157 
89 error: 5.375131 
90 error: 5.19945 
91 error: 5.034539 
92 error: 4.880956 
93 error: 4.739453 
94 error: 4.610629 
95 error: 4.495216 
96 error: 4.393571 
97 error: 4.306144 
98 error: 4.233587 
99 error: 4.176799 
100 error: 4.136802 
101 error: 4.114575 
102 error: 4.111308 
103 error: 4.1278 
104 error: 4.164539 
105 error: 4.221389
answered Oct 05 '22 by Quadbinary


I'll admit to not having written a weight-optimizer in ages, so I may be off base. I'd start by recording the yhat vector on every iteration and checking whether the values are oscillating or collapsing towards zero (I'm not sure whether dividing by i is helping or hurting you).
Similarly, take a look at the R^2 value from each iteration of lm(). If those are very close to 1, you may simply have run into the practical sensitivity limit of lm().
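For example, a minimal diagnostic sketch along these lines (mirroring the question's loop; the history-tracking variables are just illustrative names) might look like:

# Sketch: track yhat and R^2 per iteration to diagnose oscillation or collapse
data("OrchardSprays")
niter  <- 10
learn  <- 0.05
y      <- OrchardSprays$decrease
yhat   <- rep(0, nrow(OrchardSprays))
weight <- rep(1, nrow(OrchardSprays))
loss   <- function(y, yhat) (y - yhat)^2

yhat_history <- matrix(NA, nrow = niter, ncol = nrow(OrchardSprays))
r2_history   <- numeric(niter)

for (i in seq(niter)) {
    model             <- lm(decrease ~ ., weights = weight, data = OrchardSprays)
    yhat              <- yhat + weight * (predict(model) - yhat) / i
    error             <- mean(loss(y, yhat))
    weight            <- weight + learn * (loss(y, yhat) - error) / error
    yhat_history[i, ] <- yhat                      # keep every iteration's predictions
    r2_history[i]     <- summary(model)$r.squared  # R^2 of each weighted fit
}

matplot(yhat_history, type = "l")  # look for oscillation or values drifting to zero
r2_history                         # values near 1 suggest lm() has little left to fit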

It would be helpful if you could provide the source of your algorithm so we could check the code against the equations you're implementing.

Update: A quick look at Wikipedia yields the following: "Several open-source R packages are available: gbm,[6] mboost, gbev." I strongly recommend you study those packages, including their source code, to see if they'll meet your needs.
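As a point of comparison (this is not the asker's algorithm, and the parameter values below are just illustrative, not tuned), a minimal sketch of boosting the same data with the gbm package might look like:

# Sketch only: gradient boosting on OrchardSprays via the gbm package
library(gbm)
data("OrchardSprays")

fit <- gbm(decrease ~ ., data = OrchardSprays,
           distribution = "gaussian",   # squared-error loss, as in the question
           n.trees = 500,
           shrinkage = 0.05,            # learning rate
           interaction.depth = 2,
           bag.fraction = 1.0)          # no subsampling, to keep the comparison simple

yhat <- predict(fit, OrchardSprays, n.trees = 500)
mean((OrchardSprays$decrease - yhat)^2)  # training error of the boosted model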

answered Oct 05 '22 by Carl Witthoft