
Optimisation of 4D tensor rotation

I need to rotate a 3x3x3x3 fourth-order tensor more than 100k times per time step in a Stokes solver. The rotated tensor is Crot[i,j,k,l] = Crot[i,j,k,l] + Q[m,i] * Q[n,j] * Q[o,k] * Q[p,l] * C[m,n,o,p], with all indices running from 1 to 3.

So far I have naively written the following code in Julia:

Q    = rand(3,3)
C    = rand(3,3,3,3)
Crot = Array{Float64}(undef,3,3,3,3)
function rotation_4d!(Crot::Array{Float64,4},Q::Array{Float64,2},C::Array{Float64,4})
aux = 0.0
for i = 1:3
    for j = 1:3
        for k = 1:3
            for l = 1:3

                for m = 1:3
                    for n = 1:3
                        for o = 1:3
                            for p = 1:3                                     
                                aux += Q[m,i] * Q[n,j] * Q[o,k] * Q[p,l] * C[m,n,o,p];
                            end
                        end
                    end
                end

                Crot[i,j,k,l] += aux

            end
        end
    end
end

end

With:

@btime rotation_4d!(Crot,Q,C)
14.255 μs (0 allocations: 0 bytes)

Is there any way to optimise the code?

AdM asked Dec 14 '22 08:12


2 Answers

I timed the various einsum packages. Einsum is faster than the original loops just by virtue of adding @inbounds. TensorOperations is slower for such small arrays. LoopVectorization takes an age to compile here, but the end result is the fastest of the three.

(I presume you meant to zero aux once per output element, i.e. for l = 1:3; aux = 0.0; for m = 1:3. I also set Crot .= 0 so that the result does not accumulate on top of junk.)
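
A minimal sketch of that presumed fix, with aux reset once per output element. The name rotation_4d_ref! is hypothetical and used only for illustration. Everything else follows the loops in the question.

```julia
# Sketch of the corrected reference loops (rotation_4d_ref! is a hypothetical name).
function rotation_4d_ref!(Crot::Array{Float64,4}, Q::Array{Float64,2}, C::Array{Float64,4})
    for i = 1:3, j = 1:3, k = 1:3, l = 1:3
        aux = 0.0  # reset for every output element, so terms do not leak between elements
        for m = 1:3, n = 1:3, o = 1:3, p = 1:3
            aux += Q[m,i] * Q[n,j] * Q[o,k] * Q[p,l] * C[m,n,o,p]
        end
        Crot[i,j,k,l] += aux  # accumulates, so Crot should start from zeros
    end
    return Crot
end

# Sanity check: rotating by the identity matrix should return C unchanged.
Qid  = [1.0 0 0; 0 1 0; 0 0 1]
Cin  = rand(3,3,3,3)
Cout = zeros(3,3,3,3)
rotation_4d_ref!(Cout, Qid, Cin)
println(Cout ≈ Cin)
```

Because the function accumulates into Crot, the output array has to be zeroed before each call if a fresh result is wanted.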

using BenchmarkTools # provides @btime
@btime rotation_4d!($Crot,$Q,$C)  # 14.556 μs (0 allocations: 0 bytes)
Crot .= 0; # surely!
rotation_4d!(Crot,Q,C)
res = copy(Crot);

using Einsum # just adds @inbounds really
rot_ei!(Crot,Q,C) = @einsum Crot[i,j,k,l] += Q[m,i] * Q[n,j] * Q[o,k] * Q[p,l] * C[m,n,o,p]
Crot .= 0;
rot_ei!(Crot,Q,C) ≈ res # true
@btime rot_ei!($Crot,$Q,$C);      # 7.445 μs (0 allocations: 0 bytes)

using TensorOperations # sends to BLAS
rot_to!(Crot,Q,C) = @tensor Crot[i,j,k,l] += Q[m,i] * Q[n,j] * Q[o,k] * Q[p,l] * C[m,n,o,p]
Crot .= 0; 
rot_to!(Crot,Q,C) ≈ res # true
@btime rot_to!($Crot,$Q,$C);      # 22.810 μs (106 allocations: 11.16 KiB)

using Tullio, LoopVectorization
rot_lv!(Crot,Q,C) = @tullio Crot[i,j,k,l] += Q[m,i] * Q[n,j] * Q[o,k] * Q[p,l] * C[m,n,o,p]  tensor=false
Crot .= 0; 
@time rot_lv!(Crot,Q,C) ≈ res # 50 seconds!
@btime rot_lv!($Crot,$Q,$C);      # 2.662 μs (8 allocations: 256 bytes)

However, this is still an awful algorithm. The rotation is just four small contractions with Q, one per index, but the nested loops redo each of them many times. Doing them in series is much faster: each step costs 81 * 3 = 243 multiplications, so 4 * 243 = 972 (that is, 9*4 * 27) in total, instead of 4 * 9^4 = 26244 for the simple nesting above.

function rot2_ein!(Crot, Q, C)
    @einsum mid[m,n,k,l] := Q[o,k] * Q[p,l] * C[m,n,o,p]
    @einsum Crot[i,j,k,l] += Q[m,i] * Q[n,j] * mid[m,n,k,l]
end
Crot .= 0; rot2_ein!(Crot,Q,C) ≈ res # true 
@btime rot2_ein!($Crot, $Q, $C);  # 1.585 μs (2 allocations: 784 bytes)

function rot4_ein!(Crot, Q, C) # overwrites Crot without addition
    @einsum Crot[m,n,o,l] = Q[p,l] * C[m,n,o,p]
    @einsum Crot[m,n,k,l] = Q[o,k] * Crot[m,n,o,l]
    @einsum Crot[m,j,k,l] = Q[n,j] * Crot[m,n,k,l]
    @einsum Crot[i,j,k,l] = Q[m,i] * Crot[m,j,k,l]
end
rot4_ein!(Crot,Q,C) ≈ res # true
@btime rot4_ein!($Crot, $Q, $C);  # 1.006 μs
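
For readers who want to avoid a package dependency, the same series of four contractions can be sketched with plain loops. This is only an illustration under stated assumptions: the name rot4_loops! and the scratch buffers A and B are hypothetical, and the two buffers are used so that no step reads from the array it is currently writing to. It has not been benchmarked against the versions above.

```julia
# Sketch: contract one index at a time, alternating between two scratch buffers.
# rot4_loops!, A and B are hypothetical names, not part of any package.
function rot4_loops!(Crot::Array{Float64,4}, Q::Array{Float64,2}, C::Array{Float64,4},
                     A::Array{Float64,4} = similar(C), B::Array{Float64,4} = similar(C))
    # Check sizes once, because @inbounds below disables per-access checks.
    size(Q) == (3, 3) || throw(DimensionMismatch("Q must be 3×3"))
    for X in (Crot, C, A, B)
        size(X) == (3, 3, 3, 3) || throw(DimensionMismatch("tensors must be 3×3×3×3"))
    end
    @inbounds begin
        # Step 1: A[m,n,o,l] = sum over p of Q[p,l] * C[m,n,o,p]
        for l = 1:3, o = 1:3, n = 1:3, m = 1:3
            s = 0.0
            for p = 1:3
                s += Q[p,l] * C[m,n,o,p]
            end
            A[m,n,o,l] = s
        end
        # Step 2: B[m,n,k,l] = sum over o of Q[o,k] * A[m,n,o,l]
        for l = 1:3, k = 1:3, n = 1:3, m = 1:3
            s = 0.0
            for o = 1:3
                s += Q[o,k] * A[m,n,o,l]
            end
            B[m,n,k,l] = s
        end
        # Step 3: A[m,j,k,l] = sum over n of Q[n,j] * B[m,n,k,l]
        for l = 1:3, k = 1:3, j = 1:3, m = 1:3
            s = 0.0
            for n = 1:3
                s += Q[n,j] * B[m,n,k,l]
            end
            A[m,j,k,l] = s
        end
        # Step 4: Crot[i,j,k,l] = sum over m of Q[m,i] * A[m,j,k,l] (overwrites Crot)
        for l = 1:3, k = 1:3, j = 1:3, i = 1:3
            s = 0.0
            for m = 1:3
                s += Q[m,i] * A[m,j,k,l]
            end
            Crot[i,j,k,l] = s
        end
    end
    return Crot
end

# Usage: preallocate the buffers once and reuse them across calls.
Q = rand(3,3); C = rand(3,3,3,3)
Crot = similar(C); A = similar(C); B = similar(C)
rot4_loops!(Crot, Q, C, A, B)
```

Like rot4_ein!, this overwrites Crot rather than adding to it. Passing A and B explicitly avoids the two allocations that the default arguments would otherwise make on every call.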
mcabbott answered Dec 29 '22 08:12


You're doing a lot of indexing here, and therefore a lot of bounds checking. One way to shave off some time here is to use the @inbounds macro, which turns bounds checking off. Rewriting your code as:

function rotation_4d!(Crot::Array{Float64,4},Q::Array{Float64,2},C::Array{Float64,4})
    # @inbounds on the outer loop also covers the nested inner loop
    @inbounds for i = 1:3, j = 1:3, k = 1:3, l = 1:3
        aux = 0.0  # reset for each output element
        for m = 1:3, n = 1:3, o = 1:3, p = 1:3
            aux += Q[m,i] * Q[n,j] * Q[o,k] * Q[p,l] * C[m,n,o,p]
        end
        Crot[i,j,k,l] += aux
    end
end

gives me a roughly 3x speedup (6μs vs 18μs on my system).

The Julia manual covers @inbounds in its performance tips. Note, however, that with bounds checking off you must make sure yourself that all array dimensions are correctly sized, which makes hardcoded ranges like the ones in your function risky. Consider using some of Julia's built-in iteration syntax (like eachindex), or size(Q, 1) if the loop ranges need to follow the size of the inputs.
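
As a rough sketch of that advice, the loop ranges can be taken from size(Q, 1) and the sizes validated once before @inbounds is applied. The name rotation_4d_checked! is hypothetical, and the up-front checks are one possible design, not the only way to keep @inbounds safe.

```julia
# Sketch: size-generic loops with checks done once, before bounds checking is disabled.
# rotation_4d_checked! is a hypothetical name used for illustration.
function rotation_4d_checked!(Crot::Array{T,4}, Q::Matrix{T}, C::Array{T,4}) where {T}
    d = size(Q, 1)
    size(Q) == (d, d)       || throw(DimensionMismatch("Q must be square"))
    size(C) == (d, d, d, d) || throw(DimensionMismatch("C must be d×d×d×d with d = size(Q, 1)"))
    size(Crot) == size(C)   || throw(DimensionMismatch("Crot must have the same size as C"))
    # All indices below stay within 1:d, which the checks above guarantee is valid.
    @inbounds for l = 1:d, k = 1:d, j = 1:d, i = 1:d
        aux = zero(T)  # reset for each output element
        for p = 1:d, o = 1:d, n = 1:d, m = 1:d
            aux += Q[m,i] * Q[n,j] * Q[o,k] * Q[p,l] * C[m,n,o,p]
        end
        Crot[i,j,k,l] += aux  # accumulates, as in the question
    end
    return Crot
end

# Usage with the 3x3x3x3 case from the question:
Q = rand(3,3); C = rand(3,3,3,3)
Crot = zeros(3,3,3,3)
rotation_4d_checked!(Crot, Q, C)
```

A mismatched input now raises a DimensionMismatch instead of silently reading out of bounds.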

Nils Gudat answered Dec 29 '22 08:12