
Understanding multinomial nnet

Tags: r, nnet

I am trying to understand the code behind nnet. I currently get different results when I split a multinomial factor into binary indicator columns myself instead of using the formula method.

library(nnet)

set.seed(123)
y <- class.ind(iris$Species)
x <- as.matrix(iris[,1:4])
fit1 <- nnet(x, y, size = 3, decay = .1)

# weights:  27
#initial  value 164.236516 
#iter  10 value 102.567531
#iter  20 value 58.229722
#iter  30 value 39.720137
#iter  40 value 25.049530
#iter  50 value 23.671837
#iter  60 value 23.602392
#iter  70 value 23.601927
#final  value 23.601926 
#converged

pred1 <- predict(fit1, iris[,1:4])
rowSums(head(pred1))
[1] 1.032197661 1.033700173 1.032750746 1.034229149 1.032052937 1.032539980

set.seed(123)
fit2 <- nnet(Species ~ ., data = iris, size = 3, decay = .1)

# weights:  27
#initial  value 158.508573 
#iter  10 value 37.167558
#iter  20 value 26.815839
#iter  30 value 23.746418
#iter  40 value 23.698182
#iter  50 value 23.697907
#final  value 23.697907 
#converged

pred2 <- predict(fit2, iris[,1:4])
rowSums(head(pred2))
1 2 3 4 5 6 
1 1 1 1 1 1 

I know I can just use the latter approach (the formula method), but I want to understand why the results differ when the source code of nnet.formula appears to split the factor in the same way.

cdeterman asked Nov 08 '22 14:11


1 Answer

As noted by @user20650, the softmax argument is different. Inside nnet.formula there is the section:

if (length(lev) == 2L) {
    y <- as.vector(unclass(y)) - 1
    res <- nnet.default(x, y, w, entropy = TRUE, ...)
    res$lev <- lev
}
else {
    y <- class.ind(y)
    res <- nnet.default(x, y, w, softmax = TRUE, ...)
    res$lev <- lev
}

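Here softmax is set to TRUE, whereas nnet.default leaves it FALSE unless told otherwise. The matrix call in the question therefore fits independent logistic output units by least squares, and the predicted rows need not sum to 1. Passing softmax = TRUE in the nnet call fixes the problem, and the two approaches now match.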
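To see why the row sums differ, here is a minimal base-R sketch of the two kinds of output unit. The vector z is a made-up set of raw output values for one observation, not something taken from either fit.

```r
# Hypothetical raw (pre-activation) outputs for one observation, three classes
z <- c(setosa = 2.0, versicolor = 0.5, virginica = -1.0)

# Logistic output units: each value is squashed independently of the others
logistic <- 1 / (1 + exp(-z))

# Softmax output units: the values are normalised jointly
softmax <- exp(z) / sum(exp(z))

sum(logistic)  # generally not 1
sum(softmax)   # 1, up to floating-point error
```

Nothing ties the logistic outputs to each other, which is why the row sums in the question come out slightly above 1.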

set.seed(123)  # same seed as fit2, so both fits start from the same random weights
fit <- nnet(x, y, size = 3, decay = .1, softmax = TRUE)
pred <- predict(fit, iris[,1:4])
rowSums(head(pred))
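
If you want to check the match yourself, something along these lines should work. It is only a sketch: the names fit_matrix, fit_formula, p_matrix and p_formula are mine, and it assumes both fits are started from the same seed so that the initial random weights coincide.

```r
library(nnet)

x <- as.matrix(iris[, 1:4])
y <- class.ind(iris$Species)

# Matrix interface, with softmax requested explicitly
set.seed(123)
fit_matrix <- nnet(x, y, size = 3, decay = 0.1, softmax = TRUE, trace = FALSE)

# Formula interface, which sets softmax = TRUE itself for a 3-level factor
set.seed(123)
fit_formula <- nnet(Species ~ ., data = iris, size = 3, decay = 0.1, trace = FALSE)

p_matrix  <- predict(fit_matrix, x)
p_formula <- predict(fit_formula, iris[, 1:4])

# Row sums of the softmax predictions
range(rowSums(p_matrix))

# Compare the two prediction matrices, ignoring dimnames
all.equal(unname(p_matrix), unname(p_formula))
```

trace = FALSE only silences the iteration log; it does not change the fit.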

cdeterman answered Nov 15 '22 07:11