Compute sum_i f(i) x(i) x(i)' fast?

Tags: julia

I'm trying to compute the sum over i of f(i) * x(i) * x(i)', where x(i) is a column vector, x(i)' is its transpose, and f(i) is a scalar. In other words, it's a weighted sum of outer products.
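If each x(i)' is stored as row i of an N×d matrix X (the x in the code that follows), the whole sum collapses into a single matrix product. That identity is what makes a one-line bsxfun or broadcast version possible:

$$
H \;=\; \sum_{i=1}^{N} f(i)\, x(i)\, x(i)^{\top} \;=\; X^{\top} \operatorname{diag}(f)\, X \;=\; X^{\top} (f \odot X)
$$

Here f ⊙ X multiplies row i of X by f(i), which is exactly what bsxfun(@times, f, x) computes, so only one N×d temporary and one BLAS matrix product are needed.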

In MATLAB, this can be done quickly with bsxfun. The following code runs in 260 ms on my laptop (MacBook Air 2010):

N = 1e5;
d = 100;
f = randn(N, 1);
x = randn(N, d);
% H = zeros(d, d);

tic;
H = x' * bsxfun(@times, f, x);
toc

I've been trying to make Julia do the same job, but I can't get it to run as fast.

N = int(1e5);
d = 100;
f = randn(N);
x = randn(N, d);

function hess1(x, f)
    N, d = size(x);
    temp = zeros(N, d);
    @simd for kk = 1:N
        @inbounds temp[kk, :] = f[kk] * x[kk, :];
    end
    H = x' * temp;
end

function hess2(x, f)
    N, d = size(x);
    H2 = zeros(d,d);
    @simd for k = 1:N
        @inbounds H2 += f[k] * x[k, :]' * x[k, :];
    end
    return H2
end

function hess3(x, f)
    N, d = size(x);
    H3 = zeros(d,d);
    for k = 1:N
        for k1 = 1:d
            @simd for k2 = 1:d
                @inbounds H3[k1, k2] += x[k, k1] * x[k, k2] * f[k];
            end
        end
    end
    return H3
end

The results are:

@time H1 = hess1(x, f);
@time H2 = hess2(x, f);
@time H3 = hess3(x, f);
elapsed time: 0.776116469 seconds (262480224 bytes allocated, 26.49% gc time)
elapsed time: 30.496472345 seconds (16385442496 bytes allocated, 56.07% gc time)
elapsed time: 2.769934563 seconds (80128 bytes allocated)

hess1 is like MATLAB's bsxfun version but slower; hess3 uses no temporary memory but is significantly slower. My best Julia code is 3 times slower than MATLAB.
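Part of hess3's cost is its loop order: the innermost loop runs over k2, so both H3[k1, k2] and x[k, k2] are walked along a row, while Julia stores arrays column by column. A minimal loop-only sketch for Julia 1.x (the name hess_loops is hypothetical) puts the sum over k innermost, so it reads columns i and j of x contiguously, and computes only the upper triangle because H is symmetric. It still has no BLAS-style blocking, so I would not expect it to beat a BLAS-backed x' * (f .* x); I have not timed it.

```julia
# Hypothetical loop-only variant of hess3 (Julia 1.x); the name hess_loops is illustrative.
function hess_loops(x::AbstractMatrix{T}, f::AbstractVector{T}) where {T<:AbstractFloat}
    N, d = size(x)
    length(f) == N || throw(DimensionMismatch("length(f) must equal size(x, 1)"))
    H = zeros(T, d, d)
    @inbounds for j in 1:d
        for i in 1:j                  # upper triangle only; H is symmetric
            s = zero(T)
            @simd for k in 1:N        # walks down columns i and j of x (contiguous memory)
                s += x[k, i] * f[k] * x[k, j]
            end
            H[i, j] = s
            H[j, i] = s               # mirror into the lower triangle
        end
    end
    return H
end

x = randn(200, 8)
f = randn(200)
H = hess_loops(x, f)
```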

How can I make this julia code faster?

IJulia gist: http://nbviewer.ipython.org/gist/memming/669fb8e78af3338ebf6f

Julia version: 0.3.0-rc1
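The code here targets Julia 0.3. In Julia 1.x, x[k, :] returns a plain length-d Vector rather than a 1×d matrix, so the x[k, :]' * x[k, :] in hess2 becomes a scalar dot product instead of a d×d outer product; for a vector v the outer product is v * v'. A minimal reference transcription of the definition for Julia 1.x (hess_reference is a hypothetical name), slow but handy for checking faster variants against:

```julia
# Direct transcription of H = sum_k f[k] * x[k,:] * x[k,:]' for Julia 1.x.
# hess_reference is an illustrative name; it is meant for correctness checks, not speed.
function hess_reference(x::AbstractMatrix, f::AbstractVector)
    N, d = size(x)
    H = zeros(eltype(x), d, d)
    for k in 1:N
        v = x[k, :]              # row k as a length-d Vector (a copy)
        H .+= f[k] .* (v * v')   # v * v' is the d×d outer product
    end
    return H
end

x = randn(50, 4)
f = randn(50)
H = hess_reference(x, f)
```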

EDIT: I tested on a more powerful computer (3.5 GHz Intel i7, 4 cores, 256 kB L2, 8 MB L3):

  • MATLAB R2014a without -singleCompThread: 0.053 s
  • MATLAB R2014a with -singleCompThread: 0.080 s (@tholy's suggestion)
  • Julia 0.3.0-rc1
    • hess1 elapsed time: 0.215406904 seconds (262498648 bytes allocated, 32.74% gc time)
    • hess2 elapsed time: 10.722578699 seconds (16384080176 bytes allocated, 62.20% gc time)
    • hess3 elapsed time: 1.065504355 seconds (80176 bytes allocated)
    • bsxfunstyle elapsed time: 0.063540168 seconds (80081072 bytes allocated, 25.04% gc time) (@IainDunning's solution)

Indeed, using broadcast is much faster and comparable to MATLAB's bsxfun.

asked Aug 09 '14 at 21:08 by Memming


2 Answers

You are looking for the broadcast function. Here is the relevant issue discussing the functionality and naming.

I implemented your version as well as a broadcast version; here is what I found:

srand(1988)
N = 100_000
d = 100
f = randn(N, 1)
x = randn(N, d)

function hess1(x, f)
    N, d = size(x);
    temp = zeros(N, d);
    @simd for kk = 1:N
        @inbounds temp[kk, :] = f[kk] * x[kk, :];
    end
    H = x' * temp;
end

function bsxfunstyle(x, f)
    x' * broadcast(*,f,x)
end

# Warmup
hess1(x,f)
bsxfunstyle(x, f)

# For real
println("Hess1")
@time H1 = hess1(x, f)
println("Broadcast")
@time H2 = bsxfunstyle(x, f)

# Check solutions are identical
println(sum(abs(H1-H2)))

with output

Hess1
elapsed time: 0.324256216 seconds (262498648 bytes allocated, 33.95% gc time)
Broadcast
elapsed time: 0.126647594 seconds (80080696 bytes allocated, 20.22% gc time)
0.0
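In Julia 1.x, srand has become Random.seed!, elementwise abs needs dot syntax (abs.), and broadcast(*, f, x) is usually written f .* x. A sketch of the same comparison under those assumptions, using smaller sizes for a quick check and adding an equivalent Diagonal formulation (diagstyle is a hypothetical name; Diagonal(f) stores only the vector, so no N×N matrix is formed):

```julia
using Random, LinearAlgebra

Random.seed!(1988)
N, d = 1_000, 10          # smaller than the benchmark above, just to check results
f = randn(N)
x = randn(N, d)

# f .* x scales row i of x by f[i]; f is broadcast across the columns
bsxfunstyle(x, f) = x' * (f .* x)

# the same quantity as x' * Diagonal(f) * x; Diagonal is a lazy wrapper around f
diagstyle(x, f) = x' * (Diagonal(f) * x)

H1 = bsxfunstyle(x, f)
H2 = diagstyle(x, f)
println(sum(abs.(H1 .- H2)))   # should print zero or a value at rounding level
```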
answered Nov 14 '22 at 11:11 by IainDunning


There are several performance issues with your functions:

  • you're creating temporary arrays with x[kk, :];
  • you're traversing the matrices row by row, while they are stored in column-major order;
  • you're using x' (which first transposes the matrix) rather than At_mul_B(x, ...).
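The three points can be made concrete in Julia 1.x, whose semantics differ somewhat from the 0.3 release in the question: slicing copies while @view aliases, strides shows that stepping down a column is the contiguous direction, and transpose(x) is a lazy wrapper around x rather than a copy (At_mul_B itself was removed in Julia 1.0 in favor of transpose(x) * y). A small sketch:

```julia
x = randn(5, 3)

# 1. Slicing copies: x[2, :] is a new Vector, while @view x[2, :] aliases x.
row_copy = x[2, :]
row_view = @view x[2, :]
row_copy[1] = 100.0        # x is unchanged
row_view[2] = -100.0       # writes through to x[2, 2]

# 2. Column-major storage: moving down a column steps 1 element in memory,
#    moving along a row steps size(x, 1) elements.
println(strides(x))        # (1, 5)

# 3. transpose(x) (like x') wraps x instead of copying it.
xt = transpose(x)
println(parent(xt) === x)  # true
```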

A simple modification gives better performance:

N = 100_000
d = 100
f = randn(N)
x = randn(N, d)

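function hess(x, f)
    N, d = size(x)
    temp = zeros(N, d)
    # columns in the outer loop, rows in the inner loop: x and temp are walked
    # in memory (column-major) order, one element at a time, with no row copies
    @inbounds for k1 = 1:d
        @simd for kk = 1:N
           temp[kk, k1] = f[kk] * x[kk, k1]
        end
    end
    # At_mul_B(x, temp) computes x' * temp without materializing x'
    # (in Julia >= 0.7 this is written transpose(x) * temp)
    H = At_mul_B(x, temp)
end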
@time hess(x, f)
# 0.067636 seconds (9 allocations: 76.371 MB, 11.24% gc time)
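For Julia 1.x, where At_mul_B no longer exists, here is a sketch of the same approach with caller-supplied buffers and LinearAlgebra.mul!, so that code recomputing the Hessian many times does not allocate temp and H on every call. The name hess! is hypothetical, and I have not benchmarked it against the timings above.

```julia
using LinearAlgebra

# Hypothetical in-place variant (name hess! is illustrative) of the hess function above.
# The caller supplies H (d×d) and temp (N×d), so repeated calls reuse the same memory.
function hess!(H::Matrix{T}, temp::Matrix{T}, x::Matrix{T}, f::Vector{T}) where {T<:AbstractFloat}
    N, d = size(x)
    size(temp) == size(x) && length(f) == N ||
        throw(DimensionMismatch("temp must match size(x) and f must have length size(x, 1)"))
    @inbounds for j in 1:d           # columns outer ...
        @simd for k in 1:N           # ... rows inner, matching column-major storage
            temp[k, j] = f[k] * x[k, j]
        end
    end
    mul!(H, transpose(x), temp)      # H = x' * temp, written into H without materializing x'
    return H
end

N, d = 1_000, 20
x = randn(N, d)
f = randn(N)
H = zeros(d, d)
temp = similar(x)
hess!(H, temp, x, f)
```

The loop filling temp could equally be written temp .= f .* x; the explicit loop just mirrors the answer's code.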
answered Nov 14 '22 at 10:11 by Matthew