
Writing a vector sum function with SIMD (System.Numerics) and making it faster than a for loop

I wrote a function to add up all the elements of a double[] array using SIMD (System.Numerics.Vector) and the performance is worse than the naïve method.

On my computer, Vector<double>.Count is 4, which means I can create an accumulator of 4 values and run through the array adding up the elements in groups of 4.
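The lane count is not a constant: Vector<double>.Count depends on the SIMD register width the runtime selects for the current CPU (4 doubles for 256-bit registers, 2 for 128-bit). A minimal sketch for checking it on a given machine; the class name is made up for illustration:

```csharp
using System;
using System.Numerics;

public static class LaneCountCheck
{
    public static void Main()
    {
        // True when the JIT maps Vector<T> operations to hardware SIMD instructions.
        Console.WriteLine($"Hardware accelerated: {Vector.IsHardwareAccelerated}");
        // Number of double lanes in one Vector<double> on this machine.
        Console.WriteLine($"Vector<double>.Count: {Vector<double>.Count}");
    }
}
```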

For example, for a 10-element array with a 4-element accumulator and 2 remaining elements, I would get

//     | loop                  | remainder
acc[0] = vector[0] + vector[4] + vector[8]
acc[1] = vector[1] + vector[5] + vector[9]
acc[2] = vector[2] + vector[6] 
acc[3] = vector[3] + vector[7] 

and the result sum = acc[0]+acc[1]+acc[2]+acc[3]
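The lane layout above can be checked with plain scalar code. This is only a simulation of the accumulator, assuming a fixed lane width of 4; LaneSumDemo and AccumulateByLanes are hypothetical names, not part of the code in question:

```csharp
using System;

public static class LaneSumDemo
{
    // Hypothetical helper: simulates a SIMD accumulator with `lanes` slots using scalar adds.
    public static double[] AccumulateByLanes(double[] values, int lanes)
    {
        var acc = new double[lanes];
        for (int i = 0; i < values.Length; i++)
        {
            // Element i lands in lane i % lanes, which matches the layout shown above,
            // including the two remainder elements (indices 8 and 9).
            acc[i % lanes] += values[i];
        }
        return acc;
    }

    public static void Main()
    {
        double[] values = { 1, 2, 3, 4, 5, 6, 7, 8, 9, 10 };
        double[] acc = AccumulateByLanes(values, 4);

        double sum = 0;
        foreach (double lane in acc)
        {
            sum += lane;
        }

        Console.WriteLine(string.Join(", ", acc)); // 15, 18, 10, 12
        Console.WriteLine(sum);                    // 55
    }
}
```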

The code below produces the correct results, but it is slower than a plain loop adding up the values.

public static double SumSimd(this Span<double> a)
{
    var n = System.Numerics.Vector<double>.Count;
    var count = a.Length;
    // divide array into n=4 element groups
    // Example, 57 = 14*4 + 3
    var groups = Math.DivRem(count, n, out int remain);
    var buffer = new double[n];
    // Create buffer with remaining elements (not in groups)
    a.Slice(groups*n, remain).CopyTo(buffer);
    // Scan through all groups and accumulate
    var accumulator = new System.Numerics.Vector<double>(buffer);
    for (int i = 0; i < groups; i++)
    {
        //var next = new System.Numerics.Vector<double>(a, n * i);
        var next = new System.Numerics.Vector<double>(a.Slice(n * i, n));
        accumulator += next;
    }
    var sum = 0.0;
    // Add up the elements of the accumulator
    for (int j = 0; j < n; j++)
    {
        sum += accumulator[j];
    }
    return sum;
}

So my question is: why am I not seeing any benefit from SIMD here?

The baseline code looks like this:

public static double LinAlgSum(this ReadOnlySpan<double> span)
{
    double sum = 0;
    for (int i = 0; i < span.Length; i++)
    {
        sum += span[i];
    }
    return sum;
}

Benchmarked against the loop above, the SIMD code is 5× slower for size=7, 2.5× slower for size=144, and about the same for size=770.

I am running in Release mode using BenchmarkDotNet. Here is the driver class:

public class LinearAlgebraBench
{
    [Params(7, 35, 77, 144, 195, 311, 722)]
    public int Size { get; set; }

    [GlobalSetup] // assumed: no setup attribute survives in the text, and without one A and B stay null
    public void SetupData()
    {
        A = new LinearAlgebra.Vector(Size, (iter) => 2 * Size - iter).ToArray();
        B = new LinearAlgebra.Vector(Size, (iter) => Size/2 + 2* iter).ToArray();
    }

    public double[] A { get; set; }
    public double[] B { get; set; }

    [BenchmarkCategory("Sum"), Benchmark(Baseline = true)]
    public double BenchLinAlgSum()
    {
        return LinearAlgebra.LinearAlgebra.Sum(A.AsSpan().AsReadOnly());
    }

    [BenchmarkCategory("Sum"), Benchmark]
    public double BenchSimdSum()
    {
        return LinearAlgebra.LinearAlgebra.SumSimd(A);
    }
}
JAlex asked Dec 23 '22 15:12


1 Answer

Quoting @JonasH's answer:

"It is also worth noting that the compiler does not seem to produce very efficient SIMD code with the generic API."

I disagree. What matters is that the method is implemented properly. In some cases, using Intrinsics directly instead of the Numerics vector does give a serious boost, but not always.

The issue here is that each measured call is very short, and BenchmarkDotNet generally can't measure such short operations reliably. One possible solution is to wrap the target method in a loop.

Writing a representative benchmark is hard work and I'm probably not good enough at it, but I'll try.

using System;
using System.Linq;
using System.Numerics;
using System.Runtime.InteropServices;
using System.Runtime.Intrinsics;
using System.Runtime.Intrinsics.X86;
using BenchmarkDotNet.Attributes;

public class SumTest
{
    [Params(7, 35, 77, 144, 195, 311, 722)]
    public int Size { get; set; }

    [GlobalSetup] // assumed: no setup attribute survives in the text, and without one A stays null
    public void SetupData()
    {
        A = Enumerable.Range(0, Size).Select(x => 1.1).ToArray();
    }

    public double[] A { get; set; }

    [BenchmarkCategory("Sum"), Benchmark(Baseline = true)]
    public double BenchScalarSum()
    {
        double result = 0;
        for (int i = 0; i < 10000; i++)
            result = SumScalar(A);
        return result;
    }

    [BenchmarkCategory("Sum"), Benchmark]
    public double BenchNumericsSum()
    {
        double result = 0;
        for (int i = 0; i < 10000; i++)
            result = SumNumerics(A);
        return result;
    }

    [BenchmarkCategory("Sum"), Benchmark]
    public double BenchIntrinsicsSum()
    {
        double result = 0;
        for (int i = 0; i < 10000; i++)
            result = SumIntrinsics(A);
        return result;
    }

    public double SumScalar(ReadOnlySpan<double> numbers)
    {
        double result = 0;
        for (int i = 0; i < numbers.Length; i++)
            result += numbers[i];
        return result;
    }

    public double SumNumerics(ReadOnlySpan<double> numbers)
    {
        // Reinterpret the span as whole vectors; leftover elements are handled below.
        ReadOnlySpan<Vector<double>> vectors = MemoryMarshal.Cast<double, Vector<double>>(numbers);
        Vector<double> acc = Vector<double>.Zero;
        for (int i = 0; i < vectors.Length; i++)
            acc += vectors[i];
        // Horizontal sum of the accumulator lanes.
        double result = Vector.Dot(acc, Vector<double>.One);
        for (int i = vectors.Length * Vector<double>.Count; i < numbers.Length; i++)
            result += numbers[i];
        return result;
    }

    public double SumIntrinsics(ReadOnlySpan<double> numbers)
    {
        ReadOnlySpan<Vector256<double>> vectors = MemoryMarshal.Cast<double, Vector256<double>>(numbers);
        Vector256<double> acc = Vector256<double>.Zero;
        for (int i = 0; i < vectors.Length; i++)
            acc = Avx.Add(acc, vectors[i]);
        Vector128<double> r = Sse2.Add(acc.GetUpper(), acc.GetLower());
        // I'm aware that VHADDPD is probably not the most efficient choice, but I'm leaving it for simplicity.
        double result = Sse3.HorizontalAdd(r, r).GetElement(0);
        for (int i = vectors.Length * Vector256<double>.Count; i < numbers.Length; i++)
            result += numbers[i];
        return result;
    }
}

BenchmarkDotNet=v0.12.1, OS=Windows 10.0.19042
Intel Core i7-4700HQ CPU 2.40GHz (Haswell), 1 CPU, 8 logical and 4 physical cores
.NET Core SDK=5.0.203
  [Host]     : .NET Core 5.0.6 (CoreCLR 5.0.621.22011, CoreFX 5.0.621.22011), X64 RyuJIT
  Job-NQCIIR : .NET Core 5.0.6 (CoreCLR 5.0.621.22011, CoreFX 5.0.621.22011), X64 RyuJIT

|             Method | Size |        Mean |     Error |     StdDev |      Median | Ratio | RatioSD |
|------------------- |----- |------------:|----------:|-----------:|------------:|------:|--------:|
|     BenchScalarSum |    7 |    53.34 us |  0.056 us |   0.050 us |    53.30 us |  1.00 |    0.00 |
|   BenchNumericsSum |    7 |    48.95 us |  2.262 us |   6.671 us |    44.95 us |  0.95 |    0.10 |
| BenchIntrinsicsSum |    7 |    55.85 us |  2.089 us |   6.128 us |    51.90 us |  1.07 |    0.10 |
|     BenchScalarSum |   35 |   258.46 us |  2.319 us |   3.541 us |   257.00 us |  1.00 |    0.00 |
|   BenchNumericsSum |   35 |    94.14 us |  1.989 us |   5.705 us |    91.00 us |  0.36 |    0.02 |
| BenchIntrinsicsSum |   35 |    90.82 us |  2.465 us |   7.073 us |    92.10 us |  0.35 |    0.03 |
|     BenchScalarSum |   77 |   541.18 us | 10.401 us |  11.129 us |   536.95 us |  1.00 |    0.00 |
|   BenchNumericsSum |   77 |   161.05 us |  3.171 us |   7.475 us |   159.30 us |  0.30 |    0.01 |
| BenchIntrinsicsSum |   77 |   153.19 us |  3.063 us |   7.906 us |   150.50 us |  0.29 |    0.02 |
|     BenchScalarSum |  144 | 1,166.72 us |  6.945 us |   5.422 us | 1,166.10 us |  1.00 |    0.00 |
|   BenchNumericsSum |  144 |   294.72 us |  5.675 us |  10.520 us |   292.50 us |  0.26 |    0.01 |
| BenchIntrinsicsSum |  144 |   287.18 us |  5.661 us |  13.671 us |   284.20 us |  0.25 |    0.01 |
|     BenchScalarSum |  195 | 1,671.83 us | 32.634 us |  34.918 us | 1,663.30 us |  1.00 |    0.00 |
|   BenchNumericsSum |  195 |   443.19 us |  7.916 us |  11.354 us |   443.10 us |  0.26 |    0.01 |
| BenchIntrinsicsSum |  195 |   444.21 us |  8.876 us |   7.868 us |   443.55 us |  0.27 |    0.01 |
|     BenchScalarSum |  311 | 2,742.78 us | 35.797 us |  29.892 us | 2,745.70 us |  1.00 |    0.00 |
|   BenchNumericsSum |  311 |   778.00 us | 34.173 us | 100.759 us |   719.20 us |  0.30 |    0.04 |
| BenchIntrinsicsSum |  311 |   776.30 us | 29.304 us |  86.404 us |   727.45 us |  0.29 |    0.03 |
|     BenchScalarSum |  722 | 6,607.72 us | 79.263 us |  74.143 us | 6,601.20 us |  1.00 |    0.00 |
|   BenchNumericsSum |  722 | 1,870.81 us | 43.390 us | 127.936 us | 1,850.30 us |  0.28 |    0.02 |
| BenchIntrinsicsSum |  722 | 1,867.57 us | 39.718 us | 117.110 us | 1,851.50 us |  0.28 |    0.02 |

So the Vector-based versions are at least no slower than the baseline: roughly on par at Size=7, and taking about a quarter to a third of the baseline time from Size=35 upwards.
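One thing the timings do not show: the vectorized sum adds the elements in a different order than the scalar loop, so the two results can differ in the last few bits. A minimal sketch for comparing them, reusing the same scalar and Vector<double> approach as above on arrays of 1.1; the SumCheck class and its Main method are only illustrative:

```csharp
using System;
using System.Linq;
using System.Numerics;
using System.Runtime.InteropServices;

public static class SumCheck
{
    // Scalar reference: adds elements strictly left to right.
    public static double SumScalar(ReadOnlySpan<double> numbers)
    {
        double result = 0;
        for (int i = 0; i < numbers.Length; i++)
        {
            result += numbers[i];
        }
        return result;
    }

    // Vector<double> version: whole vectors first, then the scalar tail.
    public static double SumNumerics(ReadOnlySpan<double> numbers)
    {
        ReadOnlySpan<Vector<double>> vectors = MemoryMarshal.Cast<double, Vector<double>>(numbers);
        Vector<double> acc = Vector<double>.Zero;
        for (int i = 0; i < vectors.Length; i++)
        {
            acc += vectors[i];
        }
        double result = Vector.Dot(acc, Vector<double>.One);
        for (int i = vectors.Length * Vector<double>.Count; i < numbers.Length; i++)
        {
            result += numbers[i];
        }
        return result;
    }

    public static void Main()
    {
        foreach (int size in new[] { 7, 35, 77, 144, 195, 311, 722 })
        {
            double[] a = Enumerable.Repeat(1.1, size).ToArray();
            double scalar = SumScalar(a);
            double simd = SumNumerics(a);
            // Any difference comes from floating-point rounding, not from a logic error.
            Console.WriteLine($"{size}: scalar={scalar:R} simd={simd:R} diff={Math.Abs(scalar - simd):E2}");
        }
    }
}
```

If bit-identical results matter, compare against a tolerance rather than with ==.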

As a bonus, let's look at the assembly output from https://sharplab.io/ (x64) for SumScalar, SumNumerics and SumIntrinsics, in that order.

    L0000: vzeroupper
    L0003: mov rax, [rdx]
    L0006: mov edx, [rdx+8]
    L0009: vxorps xmm0, xmm0, xmm0
    L000d: xor ecx, ecx
    L000f: test edx, edx
    L0011: jle short L0022
    L0013: movsxd r8, ecx
    L0016: vaddsd xmm0, xmm0, [rax+r8*8]
    L001c: inc ecx
    L001e: cmp ecx, edx
    L0020: jl short L0013
    L0022: ret

    L0000: sub rsp, 0x28
    L0004: vzeroupper
    L0007: mov rax, [rdx]
    L000a: mov edx, [rdx+8]
    L000d: mov ecx, edx
    L000f: shl rcx, 3
    L0013: shr rcx, 5
    L0017: cmp rcx, 0x7fffffff
    L001e: ja short L0078
    L0020: vxorps ymm0, ymm0, ymm0
    L0024: xor r8d, r8d
    L0027: test ecx, ecx
    L0029: jle short L0040
    L002b: movsxd r9, r8d
    L002e: shl r9, 5
    L0032: vaddpd ymm0, ymm0, [rax+r9]
    L0038: inc r8d
    L003b: cmp r8d, ecx
    L003e: jl short L002b
    L0040: vmulpd ymm0, ymm0, [SumTest.SumNumerics(System.ReadOnlySpan`1<Double>)]
    L0048: vhaddpd ymm0, ymm0, ymm0
    L004c: vextractf128 xmm1, ymm0, 1
    L0052: vaddpd xmm0, xmm0, xmm1
    L0056: shl ecx, 2
    L0059: cmp ecx, edx
    L005b: jge short L0070
    L005d: cmp ecx, edx
    L005f: jae short L007e
    L0061: movsxd r8, ecx
    L0064: vaddsd xmm0, xmm0, [rax+r8*8]
    L006a: inc ecx
    L006c: cmp ecx, edx
    L006e: jl short L005d
    L0070: vzeroupper
    L0073: add rsp, 0x28
    L0077: ret
    L0078: call 0x00007ffc9de2b710
    L007d: int3
    L007e: call 0x00007ffc9de2bc70
    L0083: int3

    L0000: sub rsp, 0x28
    L0004: vzeroupper
    L0007: mov rax, [rdx]
    L000a: mov edx, [rdx+8]
    L000d: mov ecx, edx
    L000f: shl rcx, 3
    L0013: shr rcx, 5
    L0017: cmp rcx, 0x7fffffff
    L001e: ja short L0070
    L0020: vxorps ymm0, ymm0, ymm0
    L0024: xor r8d, r8d
    L0027: test ecx, ecx
    L0029: jle short L0040
    L002b: movsxd r9, r8d
    L002e: shl r9, 5
    L0032: vaddpd ymm0, ymm0, [rax+r9]
    L0038: inc r8d
    L003b: cmp r8d, ecx
    L003e: jl short L002b
    L0040: vextractf128 xmm1, ymm0, 1
    L0046: vaddpd xmm0, xmm1, xmm0
    L004a: vhaddpd xmm0, xmm0, xmm0
    L004e: shl ecx, 2
    L0051: cmp ecx, edx
    L0053: jge short L0068
    L0055: cmp ecx, edx
    L0057: jae short L0076
    L0059: movsxd r8, ecx
    L005c: vaddsd xmm0, xmm0, [rax+r8*8]
    L0062: inc ecx
    L0064: cmp ecx, edx
    L0066: jl short L0055
    L0068: vzeroupper
    L006b: add rsp, 0x28
    L006f: ret
    L0070: call 0x00007ffc9de2b710
    L0075: int3
    L0076: call 0x00007ffc9de2bc70
    L007b: int3

Here you can see that the JIT produces almost the same code for Vector<T> as for Vector256<T>.

aepot answered May 13 '23 06:05
