
Error in predict() glmnet function: not-yet-implemented method

Tags: r, glmnet

When I call predict() on a glmnet fit, I get the error shown below the code.

mydata <- read.csv("data.csv")
x <- mydata[,1:4]
y <- mydata[,5]
data <- cbind(x,y)
model <- model.matrix(y~., data=data)
ridgedata <- model[,-1]
train <- sample(1:dim(ridgedata)[1], round(0.8*dim(ridgedata)[1]))
test <- setdiff(1:dim(ridgedata)[1],train)
x_train <- data[train, ]
y_train <- data$y[train]
x_test <- data[test, ]
y_test <- data$y[test]
k=5
grid =10^seq(10,-2, length =100)
fit <- cv.glmnet(model,y,k=k,lambda = grid)
lambda_min <- fit$lambda.min
fit_test <- predict(fit, newx=x_test,s=lambda_min)

The error is as follows:

Error in as.matrix(cbind2(1, newx) %*% nbeta) :
  error in evaluating the argument 'x' in selecting a method for function 'as.matrix': Error in cbind2(1, newx) %*% nbeta :
  not-yet-implemented method for <data.frame> %*% <dgCMatrix>

I tried debugging, but I am not sure where the code as.matrix(cbind2(1, newx) %*% nbeta) is being used or what is causing this error.

RDPD asked Feb 16 '16 16:02



1 Answer

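Your original data frame has a factor (categorical) variable among the predictor variables. When you use model.matrix() it does something sensible with this variable: it expands it into numeric dummy columns and returns a numeric matrix. predict() on a glmnet fit expects newx to be a numeric matrix with the same columns as the matrix used for fitting, so if you just pass the data frame directly, it doesn't know what to do. Convert the test data with model.matrix() first: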

# Build a numeric design matrix for the test rows, laid out like `model`
newX <- model.matrix(~ . - y, data = x_test)
fit_test <- predict(fit, newx = newX, s = lambda_min)
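
To make the conversion concrete, here is a minimal base-R sketch. The data frame `toy` and its columns are made up for illustration and are not the asker's data. It shows the factor being expanded into dummy columns, and one possible variation on the fix: build the design matrix once from the full data and split it by row index, so the training and test matrices are guaranteed to share the same columns.

```r
# Hypothetical toy data: one numeric predictor and one factor
toy <- data.frame(y = c(1.2, 0.7, 3.1, 2.4, 1.9, 0.3),
                  a = 1:6,
                  g = factor(c("u", "v", "w", "u", "v", "w")))

# model.matrix() turns the factor into numeric dummy columns
X <- model.matrix(y ~ ., data = toy)
colnames(X)   # "(Intercept)" "a" "gv" "gw"

# Build the matrix once, then split by row index so that the
# training and test matrices have exactly the same columns
train <- c(1, 2, 3, 4)
test <- setdiff(seq_len(nrow(toy)), train)
X_train <- X[train, , drop = FALSE]
X_test <- X[test, , drop = FALSE]

is.matrix(X_test)                                  # TRUE
identical(colnames(X_train), colnames(X_test))     # TRUE
```

Splitting an already-built matrix also avoids calling model.matrix() on a small test subset in which some factor levels might not occur.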

By the way, you could have reproduced this with a minimal, made-up example containing just a few lines of data. For example, this setup is enough to trigger the same problem (I called the data dd rather than "data", because the latter is a built-in function in R):

set.seed(101)
dd <- data.frame(y=rnorm(5),
            a=1:5,b=2:6,c=3:7,d=letters[1:5])
model <- model.matrix(y~., data=dd)
n <- nrow(dd)
train <- sample(1:n, size=round(0.8*n))
test <- setdiff(1:n,train)
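
For completeness, here is a hedged, self-contained sketch of a full reproduction. It is not part of the original answer: it uses a slightly larger made-up data set (20 rows) so that the fit is well behaved, calls glmnet() directly instead of cv.glmnet(), and picks s = 0.1 arbitrarily. The exact error message may differ between glmnet versions, so the sketch only checks that an error occurs.

```r
library(glmnet)

set.seed(101)
# Made-up data: three numeric predictors and one factor
dd <- data.frame(y = rnorm(20),
                 a = rnorm(20), b = rnorm(20), c = rnorm(20),
                 d = factor(rep(letters[1:4], times = 5)))

# Design matrix built once from the full data; the intercept column
# is dropped because glmnet fits its own intercept
X <- model.matrix(y ~ ., data = dd)[, -1]
n <- nrow(dd)
train <- sample(1:n, size = round(0.8 * n))
test <- setdiff(1:n, train)

fit <- glmnet(X[train, ], dd$y[train])

# Passing the data frame as newx fails; capture the error instead of stopping
res <- tryCatch(predict(fit, newx = dd[test, ], s = 0.1),
                error = function(e) e)
inherits(res, "error")   # TRUE

# Passing the matching rows of the numeric matrix works
pred <- predict(fit, newx = X[test, , drop = FALSE], s = 0.1)
dim(pred)                # 4 1
```

The only difference between the failing and the working call is the type and layout of newx: a data frame in the first case, rows of the same numeric matrix used for fitting in the second.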
Ben Bolker answered Sep 18 '22 20:09
