Why does X.dot(X.T) require so much memory in numpy?

X is an n x p matrix where p is much larger than n. Let's say n = 1000 and p = 500000. When I run:

import numpy as np

X = np.random.randn(1000,500000)
S = X.dot(X.T)

This operation uses a great deal of memory even though the result is only 1000 x 1000, and the memory usage drops back down once the operation finishes. Is there any way around this?
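To put "a great deal of memory" in numbers, here is a quick size check, assuming the default float64 dtype that np.random.randn returns (8 bytes per element). It only does arithmetic on the shapes, so nothing large is allocated.

```python
import numpy as np

n, p = 1000, 500000
itemsize = np.dtype(np.float64).itemsize  # 8 bytes per float64 element

x_bytes = n * p * itemsize   # size of X (and of any full copy of X.T)
s_bytes = n * n * itemsize   # size of the result S = X.dot(X.T)

print(x_bytes / 1e9, "GB for X")   # 4.0 GB for X
print(s_bytes / 1e6, "MB for S")   # 8.0 MB for S
```

Any temporary copy of an input the size of X therefore costs roughly another 4 GB, while the result itself is negligible.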

asked Jan 18 '14 at 18:01 by bluecat


1 Answer

The issue is not that X and X.T are views of the same memory space per se, but rather that X.T is F-contiguous rather than C-contiguous. When you multiply an array by a view of its own transpose, at least one of the two inputs is necessarily F-contiguous: if X is C-contiguous, then X.T is F-contiguous.
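The layout difference can be seen directly in the arrays' flags. This small sketch uses a tiny stand-in for X; the shapes are arbitrary and only for illustration.

```python
import numpy as np

X = np.random.randn(4, 6)   # small stand-in for the real X (C order by default)
XT = X.T                    # a view: no data is copied

# X is stored row-major (C order); X.T reads the same buffer column-major (F order)
print(X.flags['C_CONTIGUOUS'], X.flags['F_CONTIGUOUS'])    # True False
print(XT.flags['C_CONTIGUOUS'], XT.flags['F_CONTIGUOUS'])  # False True
print(np.shares_memory(X, XT))                             # True
```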

In numpy < 1.8, np.dot will create a C-ordered copy of any F-ordered input arrays, not just ones that happen to be views onto the same block of memory.
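To see what such a copy costs, the sketch below makes the same kind of C-ordered copy explicitly with np.ascontiguousarray. It illustrates the allocation only; it is not numpy's internal np.dot code path.

```python
import numpy as np

XT = np.random.randn(200, 300).T     # F-contiguous view, shape (300, 200)
C = np.ascontiguousarray(XT)         # forces a C-ordered copy

print(C.flags['C_CONTIGUOUS'])       # True
print(np.shares_memory(C, XT))       # False: a new buffer was allocated
print(C.nbytes)                      # 480000 bytes = 300 * 200 * 8

# An array that is already C-ordered is passed through without copying
X = XT.T                             # C-contiguous view of the original buffer
print(np.shares_memory(np.ascontiguousarray(X), X))  # True
```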

For example:

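import numpy as np
# %memit is an IPython magic provided by the memory_profiler package
%load_ext memory_profiler

X = np.random.randn(1000,50000)
Y = np.random.randn(50000, 100)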

# X and Y are both C-order, no copy
%memit np.dot(X, Y)
# maximum of 1: 485.554688 MB per loop

# make X Fortran order and Y C-order, now the larger array (X) gets
# copied
X = np.asfortranarray(X)
%memit np.dot(X, Y)
# maximum of 1: 867.070312 MB per loop

# make X C-order and Y Fortran order, now the smaller array (Y) gets
# copied
X = np.ascontiguousarray(X)
Y = np.asfortranarray(Y)
%memit np.dot(X, Y)
# maximum of 1: 523.792969 MB per loop

# make both of them F-ordered, both get copied!
X = np.asfortranarray(X)
%memit np.dot(X, Y)
# maximum of 1: 905.093750 MB per loop
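The extra memory in each case is roughly the size of whichever inputs are not C-ordered. The sketch below models that rule with a hypothetical helper, pre18_copy_overhead (not a numpy function), assuming numpy < 1.8 copies every input that is not C-contiguous, and assuming %memit reports mebibytes (2**20 bytes) under the label "MB".

```python
import numpy as np

MIB = 1024 ** 2  # assumption: %memit's "MB" figures are mebibytes

def pre18_copy_overhead(*arrays):
    """Hypothetical helper: bytes of temporary copies np.dot would make in
    numpy < 1.8, assuming it copies every input that is not C-contiguous."""
    return sum(a.nbytes for a in arrays if not a.flags['C_CONTIGUOUS'])

# Small demo: only the F-ordered argument counts toward the overhead
X = np.asfortranarray(np.random.randn(100, 500))
Y = np.random.randn(500, 10)
print(pre18_copy_overhead(X, Y))   # 400000 = 100 * 500 * 8

# Sizes of the arrays used in the example, computed without allocating them
x_mib = 1000 * 50000 * 8 / MIB
y_mib = 50000 * 100 * 8 / MIB
print(round(x_mib, 1), round(y_mib, 1), round(x_mib + y_mib, 1))  # 381.5 38.1 419.6
```

These figures line up with the differences between the %memit readings in the example (867 - 486 ≈ 382, 524 - 486 ≈ 38, 905 - 486 ≈ 420).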

If copying is a problem (e.g. when X is very large), what can you do about it?

The best option is probably to upgrade to a newer version of numpy: as @perimosocordiae points out, this performance issue was addressed in this pull request.
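A quick way to check your own installation is sketched below. It assumes numpy 1.13 or later, which reports its array buffer allocations to Python's tracemalloc module; memory allocated inside the BLAS library itself is not traced. The array size is arbitrary.

```python
import tracemalloc
import numpy as np

print(np.__version__)
if np.lib.NumpyVersion(np.__version__) < '1.8.0':
    print("np.dot may copy F-ordered inputs on this version")

X = np.random.randn(200, 20000)      # about 32 MB, C-ordered
tracemalloc.start()                  # trace allocations made from here on
S = X.dot(X.T)
_, peak = tracemalloc.get_traced_memory()
tracemalloc.stop()

# Without a copy of X.T, the peak stays near S.nbytes (200 * 200 * 8 bytes);
# a full copy would push it past X.nbytes.
print(peak, S.nbytes, X.nbytes)
```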

If for whatever reason you can't upgrade numpy, there is also a trick for performing fast, BLAS-based dot products without forcing a copy: call the relevant BLAS function directly through scipy.linalg.blas (shamelessly stolen from this answer):

import numpy as np
from scipy.linalg import blas

X = np.random.randn(1000,50000)

%memit res1 = np.dot(X, X.T)
# maximum of 1: 845.367188 MB per loop

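# X.T is F-contiguous, so it can be handed to the Fortran BLAS routine without
# a copy; trans_a=True makes dgemm compute (X.T).T @ X.T = X @ X.T
%memit res2 = blas.dgemm(alpha=1., a=X.T, b=X.T, trans_a=True)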
# maximum of 1: 471.656250 MB per loop

# different BLAS call paths can differ in the last bits, so compare with a tolerance
print(np.allclose(res1, res2))
# True
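Another workaround, not from the original answer, is to accumulate S over column blocks of X, since X.dot(X.T) equals the sum of Xb.dot(Xb.T) over column blocks Xb. Even if np.dot copies its inputs, each temporary then covers only n x block elements instead of all of X. This is a minimal sketch; gram_chunked and the block size are hypothetical choices.

```python
import numpy as np

def gram_chunked(X, block=10000):
    """Hypothetical helper: compute X.dot(X.T) by summing over column blocks.

    Each block product touches only X[:, start:start + block], so any
    temporary copies are bounded by about 2 * n * block elements.
    """
    n, p = X.shape
    S = np.zeros((n, n), dtype=X.dtype)
    for start in range(0, p, block):
        Xb = X[:, start:start + block]   # a view, no copy
        S += Xb.dot(Xb.T)
    return S

X = np.random.randn(100, 5000)
S = gram_chunked(X, block=1024)
print(np.allclose(S, X.dot(X.T)))   # True
```

Smaller blocks lower peak memory at the cost of more (smaller) BLAS calls.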
answered Nov 12 '22 at 13:11 by ali_m