 

Julia - Linear combination of row-wise outer products

I have a matrix A of dimension (n, m) and a matrix B of dimension (n, p). For each of the n rows, I would like to compute the outer product of row i of A with row i of B; each of these outer products is an (m, p) matrix. I also have a vector x of length n, and I would like to multiply each of these matrices by the corresponding entry of x and sum them all up. How can I do this efficiently?

# Parameters
n, m, p = 100, 10, 3
# Matrices & Vectors
A, B, x = randn(n, m), randn(n, p), randn(n)
# Slow method
result = zeros(m, p)
for i in 1:n
    result += x[i] * (A[i, :] * B[i, :]')
end
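Written out entry by entry, element (j, k) of the result is Σᵢ x[i] · A[i, j] · B[i, k], which is exactly element (j, k) of Aᵀ · diag(x) · B. The sketch below checks that equivalence numerically. It is an illustration, not the asker's or the answerer's method, and the function names `loop_sum` and `matrix_form` are hypothetical.

```julia
using LinearAlgebra  # for Diagonal

# Reference version: the loop from the question, wrapped in a function
function loop_sum(A, B, x)
    result = zeros(size(A, 2), size(B, 2))
    for i in axes(A, 1)
        result += x[i] * (A[i, :] * B[i, :]')
    end
    return result
end

# The same quantity as one product: result[j, k] = sum over i of x[i] * A[i, j] * B[i, k]
matrix_form(A, B, x) = A' * Diagonal(x) * B

n, m, p = 100, 10, 3
A, B, x = randn(n, m), randn(n, p), randn(n)

# The summation order differs, so compare approximately rather than with ==
@assert loop_sum(A, B, x) ≈ matrix_form(A, B, x)
println(size(matrix_form(A, B, x)))  # (10, 3)
```

Since `x .* B` scales row i of B by x[i], `A' * (x .* B)` is an equivalent way to write the same product without the `Diagonal` wrapper. Either form ends in a single dense matrix multiply; whether that beats a tuned loop at these sizes is worth benchmarking rather than assuming.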
asked Oct 24 '25 at 04:10 by Physics_Student



1 Answer

General tips to save time and memory:

  1. Put your code in a function instead of the global scope, and make sure every variable that function uses comes from its arguments rather than from (non-constant) global variables. That way, Julia's compiler can infer the types of the variables and generate specialized, optimized code.

  2. Reduce allocations where possible; this code has plenty of opportunities. The changes below are exactly what distinguish old_way from new_way, and together they give a roughly 5-6x speedup and cut the allocations down to 1 (the result matrix itself).

    • When slicing an array, use @view to avoid the default behavior of allocating a copy.
    • Update result in place with .+=. Plain += allocates a new array and rebinds the variable result to it.
    • For elementwise operations like x[i] * ..., chaining dotted operators fuses them into a single loop and avoids allocating intermediate arrays.
    • The matrix product of a column (m×1) vector and a row (1×p) vector, i.e. an outer product, can be written as a broadcasted elementwise multiplication (a .* b'), which then fuses with the other dotted operations.
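The last three points can be checked directly on small toy data. The vectors `a`, `b` and the matrix `M` below are made up for illustration and are not from the question.

```julia
# Toy data (not from the question)
a = [1.0, 2.0, 3.0]   # column vector, length 3
b = [10.0, 20.0]      # length 2; b' is a 1x2 row

# Broadcasting a column against a row gives the same 3x2 outer product as a * b'
@assert a .* b' == a * b'

# .+= writes into the existing array
result = zeros(3, 2)
before = result
result .+= a .* b'
@assert result === before    # still the same array object

# += builds a new array and rebinds the name `result`
result += a .* b'
@assert result !== before    # `before` still holds the old array
println(before[1, 1], " ", result[1, 1])  # 10.0 20.0

# A view shares memory with its parent; plain slicing makes a copy
M = [1.0 2.0; 3.0 4.0]
row_view = @view M[1, :]
row_copy = M[2, :]
row_view[1] = 100.0
row_copy[1] = -1.0
@assert M[1, 1] == 100.0     # changed through the view
@assert M[2, 1] == 3.0       # writing to the copy did not touch M
```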
n, m, p = 100, 10, 3
A, B, x = randn(n, m), randn(n, p), randn(n)

# Methods below do not use the above global variables

function old_way(A, B, x, n, m, p)
  result = zeros(m, p)
  for i in 1:n
    result += x[i] * (A[i, :] * B[i, :]')
  end
  result
end

function new_way(A, B, x, n, m, p)
  result = zeros(m, p)
  for i in 1:n
    result .+= x[i] .* ( @view(A[i, :]) .* @view(B[i, :])' )
  end
  result
end

using BenchmarkTools

@btime old_way(A, B, x, n, m, p);
# 36.753 μs (501 allocations: 125.33 KiB)

@btime new_way(A, B, x, n, m, p);
# 6.542 μs (1 allocation: 336 bytes)

old_way(A, B, x, n, m, p) == new_way(A, B, x, n, m, p)
# true

The example above avoids global variables, and the example below shows why. If your code is inside a function but still reads non-constant global variables, performance is generally worse, and the allocation-reducing changes actually backfire:

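# Methods below read n, m, p as (non-constant) global variables

function old_oops(A, B, x)
  result = zeros(m, p)
  for i in 1:n
    result += x[i] * (A[i, :] * B[i, :]')
  end
  result
end

function new_oops(A, B, x)
  result = zeros(m, p)
  for i in 1:n
    result .+= x[i] .* ( @view(A[i, :]) .* @view(B[i, :])' )
  end
  result
end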

@btime old_oops(A, B, x);
# 95.317 μs (1802 allocations: 153.48 KiB)

@btime new_oops(A, B, x);
# 235.191 μs (1302 allocations: 81.61 KiB)
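If you would rather not pass n, m, p at all, one option is to read them off the arrays inside the function, so nothing comes from global scope. The sketch below, `new_sized` (a hypothetical name), reuses the loop body of `new_way` and adds a size check of my own. Declaring the globals `const` would be another way to let the compiler know their types.

```julia
# Hypothetical variant: no global n, m, p and no extra size arguments
function new_sized(A, B, x)
    n, m = size(A)
    p = size(B, 2)
    # Fail loudly on mismatched inputs instead of indexing out of bounds later
    size(B, 1) == n == length(x) || throw(DimensionMismatch("A, B and x must all have n rows"))
    result = zeros(m, p)
    for i in 1:n
        # Same fused, view-based update as new_way
        result .+= x[i] .* (@view(A[i, :]) .* @view(B[i, :])')
    end
    return result
end

A, B, x = randn(100, 10), randn(100, 3), randn(100)
println(size(new_sized(A, B, x)))  # (10, 3)
```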
answered Oct 27 '25 at 00:10 by BatWannaBe
