I'm trying to implement a simple gradient boosting algorithm for regression in R. This is what I came up with so far, but the error is not plateauing as I expected. Any suggestions?
data("OrchardSprays")
niter <- 10
learn <- 0.05
y <- OrchardSprays$decrease
yhat <- rep(0,nrow(OrchardSprays))
weight <- rep(1,nrow(OrchardSprays))
loss <- function(y,yhat) (y - yhat)^2
for (i in seq(niter))
{
model <- lm(decrease~.,weights=weight,data=OrchardSprays)
yhat <- yhat + weight * (predict(model) - yhat) / i
error <- mean(loss(y,yhat))
weight <- weight + learn * (loss(y,yhat) - error) / error
cat(i,"error:",error,"\n")
}
output:
1 error: 319.5881
2 error: 318.6175
3 error: 317.9368
4 error: 317.6112
5 error: 317.6369
6 error: 317.9772
7 error: 318.5833
8 error: 319.4047
9 error: 320.3939
10 error: 321.5086
Not sure if this helps, but if you lower your starting weights and increase the number of iterations, the error gets much closer to zero. However, it still doesn't plateau: the error starts climbing again at iteration 103. I also added the statement weight <- ifelse(weight < 0.0, 0.0, weight) to avoid the "missing or negative weights not allowed" error raised by the lm function.
data("OrchardSprays")
niter <- 105
learn <- 0.05
y <- OrchardSprays$decrease
yhat <- rep(0.0,nrow(OrchardSprays))
weight <- rep(0.2,nrow(OrchardSprays))
loss <- function(y,yhat) (y - yhat)^2
error <- mean(loss(y,yhat))
cat("initial error:",error,"\n")
for (i in seq(niter))
{
model <- lm(decrease~.,weights=weight,data=OrchardSprays)
yhat <- yhat + weight * (predict(model) - yhat) / i
error <- mean(loss(y,yhat))
weight <- weight + learn * (loss(y,yhat) - error) / error
weight <- ifelse(weight < 0.0, 0.0, weight)
cat(i,"error:",error,"\n")
}
output:
initial error: 3308.922
1 error: 2232.762
2 error: 1707.971
3 error: 1360.834
4 error: 1110.503
5 error: 921.2804
6 error: 776.4314
7 error: 663.5947
8 error: 574.2603
9 error: 502.2455
10 error: 443.2639
11 error: 394.2983
12 error: 353.1736
13 error: 318.2869
14 error: 288.4326
15 error: 262.6827
16 error: 240.3086
17 error: 220.7289
18 error: 203.4741
19 error: 188.1632
20 error: 174.4876
21 error: 162.1971
22 error: 151.0889
23 error: 140.9982
24 error: 131.7907
25 error: 123.3567
26 error: 115.6054
27 error: 108.4606
28 error: 101.8571
29 error: 95.73825
30 error: 90.05343
31 error: 84.75755
32 error: 79.81715
33 error: 75.19618
34 error: 70.86006
35 error: 66.77859
36 error: 62.92584
37 error: 59.28014
38 error: 55.8239
39 error: 52.54784
40 error: 49.44272
41 error: 46.49915
42 error: 43.71022
43 error: 41.07119
44 error: 38.57908
45 error: 36.23237
46 error: 34.03907
47 error: 32.00558
48 error: 30.12923
49 error: 28.39891
50 error: 26.80582
51 error: 25.33449
52 error: 23.97077
53 error: 22.70327
54 error: 21.52714
55 error: 20.43589
56 error: 19.42552
57 error: 18.48629
58 error: 17.60916
59 error: 16.78986
60 error: 16.02315
61 error: 15.30303
62 error: 14.62663
63 error: 13.99066
64 error: 13.39205
65 error: 12.82941
66 error: 12.30349
67 error: 11.811
68 error: 11.34883
69 error: 10.91418
70 error: 10.50448
71 error: 10.11723
72 error: 9.751116
73 error: 9.405197
74 error: 9.076175
75 error: 8.761231
76 error: 8.458107
77 error: 8.165144
78 error: 7.884295
79 error: 7.615498
80 error: 7.356618
81 error: 7.106186
82 error: 6.86324
83 error: 6.627176
84 error: 6.39777
85 error: 6.17544
86 error: 5.961616
87 error: 5.756781
88 error: 5.561157
89 error: 5.375131
90 error: 5.19945
91 error: 5.034539
92 error: 4.880956
93 error: 4.739453
94 error: 4.610629
95 error: 4.495216
96 error: 4.393571
97 error: 4.306144
98 error: 4.233587
99 error: 4.176799
100 error: 4.136802
101 error: 4.114575
102 error: 4.111308
103 error: 4.1278
104 error: 4.164539
105 error: 4.221389
I'll admit to not having written a weight optimizer in ages, so I may be off base. I'd start by recording the yhat vector on every iteration. See if the values are either oscillating or disappearing towards zero (as I'm not sure whether you're helping or hurting by dividing by i).
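One way to do that, keeping your loop otherwise unchanged, is to stash each iteration's yhat as a row of a matrix and plot it afterwards. This is just a sketch; yhat_history is a name I've made up for the bookkeeping:
yhat_history <- matrix(NA_real_, nrow = niter, ncol = nrow(OrchardSprays))
for (i in seq(niter))
{
  model <- lm(decrease ~ ., weights = weight, data = OrchardSprays)
  yhat <- yhat + weight * (predict(model) - yhat) / i
  yhat_history[i, ] <- yhat   # record the fitted values at this iteration
  error <- mean(loss(y, yhat))
  weight <- weight + learn * (loss(y, yhat) - error) / error
  weight <- ifelse(weight < 0.0, 0.0, weight)
  cat(i, "error:", error, "\n")
}
matplot(yhat_history, type = "l", xlab = "iteration", ylab = "yhat")   # one line per observation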
Similarly, take a look at the R^2 values from each iteration of lm(). If they're very close to 1, you may simply have run into the precision limits of lm().
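Pulling the R^2 out is one extra line inside the loop, for example:
r2 <- summary(model)$r.squared   # R^2 of the current weighted lm fit
cat(i, "error:", error, "R^2:", r2, "\n")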
It would be helpful if you could provide the source of your algorithm so we could check the code against the equations you're implementing.
Update: A quick look at Wikipedia yields the following: "Several open-source R packages are available: gbm, mboost, gbev." I strongly recommend you study those packages, including their source code, to see if they'll meet your needs.
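For comparison, a minimal gbm fit on the same data might look like the following; the distribution and the tuning values (n.trees, shrinkage) are illustrative assumptions on my part, not recommendations:
library(gbm)
data("OrchardSprays")
# squared-error ("gaussian") boosting, analogous to the loss in the question
fit <- gbm(decrease ~ ., data = OrchardSprays, distribution = "gaussian",
           n.trees = 100, shrinkage = 0.05)
yhat <- predict(fit, OrchardSprays, n.trees = 100)
mean((OrchardSprays$decrease - yhat)^2)   # training MSE, for comparison with the loop above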