
compact/efficient replacement for diag(X V X^T)?

When making predictions for a linear statistical model we usually have a model matrix X of predictors corresponding to the points at which we want to make predictions; a vector of coefficients beta; and a variance-covariance matrix V. Computing the predictions is just X %*% beta. The most straightforward way to compute the variances of the predictions is

diag(X %*% V %*% t(X))

or slightly more efficiently

diag(X %*% tcrossprod(V,X))

However, this is very inefficient, because it constructs an n*n matrix when all we really want is the diagonal. I know I could write some Rcpp-loopy thing that would compute just the diagonal terms, but I'm wondering if there is an existing linear algebra trick in R that will nicely do what I want ... (if someone wants to write the Rcpp-loopy thing for me as an answer I wouldn't object, but I'd prefer a pure-R solution)

FWIW predict.lm seems to do something clever by multiplying X by the inverse of the R component of the QR-decomposition of the lm; I'm not sure that's always going to be available, but it might be a good starting point (see here)
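A rough sketch of the same idea without the lm's QR decomposition, assuming V is symmetric positive definite so that chol(V) exists: if V = R^T R with R upper triangular, then X V X^T = (X R^T)(X R^T)^T, so the diagonal is just the row sums of squares of X R^T. The data below (X, V, n, p) are made up for illustration; this is not what predict.lm itself does, only an analogous factorization trick.

```r
set.seed(1)
n <- 200; p <- 5
X <- matrix(rnorm(n * p), n, p)                        # made-up model matrix
V <- crossprod(matrix(rnorm(p * p), p, p)) + diag(p)   # made-up symmetric positive-definite V

R <- chol(V)                            # upper triangular, t(R) %*% R equals V
v_chol  <- rowSums(tcrossprod(X, R)^2)  # row sums of squares of X %*% t(R); only n x p intermediates
v_naive <- diag(X %*% V %*% t(X))       # reference: builds the full n x n matrix

all.equal(v_chol, v_naive)
```

This only applies when V is positive definite; a singular or rank-deficient V would need a different factorization.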

Ben Bolker asked Jul 10 '19 01:07



1 Answer

Along the lines of this Octave/Matlab question, for two matrices A and B, we can use the fact that the nth diagonal entry of AB is the product of the nth row of A with the nth column of B. We can naively extend that to the case of three matrices, ABC. I have not considered how to optimize in the case where C=A^T, but aside from that, this code looks like a promising speedup:

start_time <- Sys.time()

A <- matrix(1:1000000, nrow = 1000, ncol = 1000)
B <- matrix(1000000:1, nrow = 1000, ncol = 1000)

# Run only one of these two lines at a time (comment out the other)
res <- diag(A %*% B %*% t(A))         # ~0.47s
# res <- rowSums(A * t(B %*% t(A)))   # ~0.27s

end_time <- Sys.time()

print(end_time - start_time)

Using tcrossprod did not appear to speed things up when I ran this code. However, the row-sum-dot-product approach already appears to be a lot more efficient, at least on this silly example. The saving comes from the last step: B %*% t(A) is still formed in full, but the final matrix product with A is replaced by an elementwise product followed by rowSums, so only the terms that contribute to the diagonal are computed. With diag, the full triple product is computed first and everything off the diagonal is then thrown away.
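For the C=A^T case from the question, the same row-sum idea seems to simplify further: the ith diagonal entry of X V X^T is the dot product of row i of X V with row i of X, so no n-by-n matrix is needed at all. A minimal sketch, with made-up X and V (the names n, p, v_fast and v_naive are just for illustration):

```r
set.seed(42)
n <- 500; p <- 4
X <- matrix(rnorm(n * p), n, p)              # made-up model matrix
V <- crossprod(matrix(rnorm(p * p), p, p))   # made-up covariance-like matrix

v_fast  <- rowSums((X %*% V) * X)        # only n x p intermediates
v_naive <- diag(X %*% V %*% t(X))        # reference: builds the full n x n matrix

all.equal(v_fast, v_naive)
```

The gain should grow with n, since the naive version costs on the order of n^2 * p operations and n^2 memory, while the row-sum version costs on the order of n * p^2 operations and n * p memory.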

davewy answered Nov 15 '22 04:11
