Cannot understand MPI_Reduce_scatter in MPI

I am trying to understand the MPI_Reduce_scatter function but it seems that my deductions are always wrong :(
The documentation says (link):

MPI_Reduce_scatter first does an element-wise reduction on vector of count = Σ_i recvcounts[i] elements in the send buffer defined by sendbuf, count, and datatype. Next, the resulting vector of results is split into n disjoint segments, where n is the number of processes in the group. Segment i contains recvcounts[i] elements. The ith segment is sent to process i and stored in the receive buffer defined by recvbuf, recvcounts[i], and datatype.

I have the following (very simple) C program and I expected to get the max of the first recvcounts[i] elements, but it seems that I am doing something wrong...

#include <stdio.h>
#include <stdlib.h>
#include "mpi.h"

#define NUM_PE 5
#define NUM_ELEM 3

char *print(int arr[], int n);

int main(int argc, char *argv[]) {
    int rank, size, i, n;
    int sendbuf[5][3] = {
        {  1,  2,  3 },
        {  4,  5,  6 },
        {  7,  8,  9 },
        { 10, 11, 12 },
        { 13, 14, 15 }
    };
    int recvbuf[15] = {0};
    int recvcounts[5] = {
        3, 3, 3, 3, 3
    };

    MPI_Init(&argc, &argv);
    MPI_Comm_rank(MPI_COMM_WORLD, &rank);
    MPI_Comm_size(MPI_COMM_WORLD, &size);

    n = sizeof(sendbuf[rank]) / sizeof(int);
    printf("sendbuf (thread %d): %s\n", rank, print(sendbuf[rank], n));

    MPI_Reduce_scatter(sendbuf, recvbuf, recvcounts, MPI_INT, MPI_MAX, MPI_COMM_WORLD);

    n = sizeof(recvbuf) / sizeof(int);
    printf("recvbuf (thread %d): %s\n", rank, print(recvbuf, n)); // <--- I receive the same output as with sendbuf :(

    MPI_Finalize();

    return 0;
}

char *print(int arr[], int n) { // body omitted in the original; a minimal sketch that formats arr like the output below
    static char buf[256];
    int pos = snprintf(buf, sizeof(buf), "[ ");
    for (int i = 0; i < n; i++)
        pos += snprintf(buf + pos, sizeof(buf) - pos, "%2d%s", arr[i], i < n - 1 ? ", " : " ]");
    return buf;
}

The first elements of recvbuf come out exactly the same as sendbuf, but I expected recvbuf to contain the max:

$ mpicc 03_reduce_scatter.c
$ mpirun -n 5 ./a.out
sendbuf (thread 4): [ 13, 14, 15 ]
sendbuf (thread 3): [ 10, 11, 12 ]
sendbuf (thread 2): [  7,  8,  9 ]
sendbuf (thread 0): [  1,  2,  3 ]
sendbuf (thread 1): [  4,  5,  6 ]
recvbuf (thread 1): [  4,  5,  6,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0 ]
recvbuf (thread 2): [  7,  8,  9,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0 ]
recvbuf (thread 0): [  1,  2,  3,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0 ]
recvbuf (thread 3): [ 10, 11, 12,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0 ]
recvbuf (thread 4): [ 13, 14, 15,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0 ]
asked Sep 10 '14 by StockBreak

1 Answer

Yeah, the documentation for Reduce_scatter is terse, and it's not super-widely used, so there aren't a lot of examples. The first couple of slides from this OCW MIT lecture have a nice diagram, and suggest a use case.

The key, as is often the case, is to read the MPI standard and pay particular attention to the advice to implementors:

"The MPI_REDUCE_SCATTER routine is functionally equivalent to: an MPI_REDUCE collective operation with count equal to the sum of recvcounts[i] followed by MPI_SCATTERV with sendcounts equal to recvcounts."

So let's walk your example through: this line,

MPI_Reduce_scatter(sendbuf, recvbuf, recvcounts, MPI_INT, MPI_MAX, MPI_COMM_WORLD);

would be the equivalent of this:

int totcounts = 15;  // = sum of {3, 3, 3, 3, 3}
MPI_Reduce({1,2,3...15}, tmpbuffer, totcounts, MPI_INT, MPI_MAX, 0, MPI_COMM_WORLD);
MPI_Scatterv(tmpbuffer, recvcounts, [displacements corresponding to recvcounts], 
              MPI_INT, rcvbuffer, 3, MPI_INT, 0, MPI_COMM_WORLD);

So everyone is going to send in the same numbers {1...15}, and each column of those is going to get max'ed against the others, resulting in { max(1,1,...,1), max(2,2,...,2), ..., max(15,15,...,15) } = {1,2,...,15}.

Then those are going to be scattered to the processors, 3 at a time, resulting in {1,2,3}, {4,5,6}, {7,8,9}...
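
If you want to see that spelled out, here's a minimal standalone sketch of that explicit MPI_Reduce + MPI_Scatterv equivalence for your data (the tmpbuf and displs names are mine, your 5x3 sendbuf is flattened into one 15-element vector, and it assumes you run with exactly 5 processes). It reproduces the first three recvbuf entries you saw on each rank:

#include <stdio.h>
#include "mpi.h"

int main(int argc, char *argv[]) {
    int rank, size;
    int sendbuf[15] = {  1,  2,  3,  4,  5,  6,  7,  8,
                         9, 10, 11, 12, 13, 14, 15 };
    int recvcounts[5] = { 3, 3, 3, 3, 3 };
    int displs[5]     = { 0, 3, 6, 9, 12 };  // where each rank's segment starts
    int tmpbuf[15], recvbuf[3];

    MPI_Init(&argc, &argv);
    MPI_Comm_rank(MPI_COMM_WORLD, &rank);
    MPI_Comm_size(MPI_COMM_WORLD, &size);
    if (size != 5 && rank == 0)
        printf("warning: this sketch assumes 5 processes\n");

    // element-wise MAX over identical 15-element vectors: the root still ends up with {1..15}
    MPI_Reduce(sendbuf, tmpbuf, 15, MPI_INT, MPI_MAX, 0, MPI_COMM_WORLD);

    // hand segment i of the reduced vector (3 elements) to rank i
    MPI_Scatterv(tmpbuf, recvcounts, displs, MPI_INT,
                 recvbuf, 3, MPI_INT, 0, MPI_COMM_WORLD);

    printf("rank %d got: %d %d %d\n", rank, recvbuf[0], recvbuf[1], recvbuf[2]);

    MPI_Finalize();
    return 0;
}

Compile with mpicc and run with mpirun -n 5, just like your original.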

So that's what does happen. How do we get what you want to happen? I understand that you want each row to get max'ed, and each processor to get "its" corresponding row-max. For example, let's say the data looks like this:

Proc 0: 1 5 9 13
Proc 1: 2 6 10 14
Proc 2: 3 7 11 15
Proc 3: 4 8 12 16

and we want to end up with Proc 0 (say) having the max of all the 0th pieces of data, Proc 1 having the max of all the 1st pieces, etc., so we'd get

Proc 0: 4
Proc 1: 8
Proc 2: 12
Proc 3: 16

So let's see how to do that. First, everyone is going to end up with just one value, so all the recvcounts are 1. Second, each process is going to have to send different data. So we'll have something that looks like this:

#include <stdio.h>
#include <stdlib.h>
#include "mpi.h"

int main(int argc, char *argv[]) {
    int rank, size;

    MPI_Init(&argc, &argv);
    MPI_Comm_rank(MPI_COMM_WORLD, &rank);
    MPI_Comm_size(MPI_COMM_WORLD, &size);

    int sendbuf[size];
    int recvbuf;

    // element i on this rank is 1 + rank + size*i, so every rank contributes a distinct value for each slot
    for (int i=0; i<size; i++)
        sendbuf[i] = 1 + rank + size*i;

    printf("Proc %d: ", rank);
    for (int i=0; i<size; i++) printf("%d ", sendbuf[i]);
    printf("\n");

    int recvcounts[size];
    for (int i=0; i<size; i++)
        recvcounts[i] = 1;

    // element-wise max across all ranks, then rank i receives reduced element i (each recvcounts[i] is 1)
    MPI_Reduce_scatter(sendbuf, &recvbuf, recvcounts, MPI_INT, MPI_MAX, MPI_COMM_WORLD);

    printf("Proc %d: %d\n", rank, recvbuf);

    MPI_Finalize();

    return 0;
}

Running gives (output reordered for clarity):

Proc 0: 1 5 9 13 
Proc 1: 2 6 10 14 
Proc 2: 3 7 11 15
Proc 3: 4 8 12 16

Proc 0: 4
Proc 1: 8
Proc 2: 12
Proc 3: 16
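
As an aside: since every entry of recvcounts is the same here (one element per process), MPI_Reduce_scatter_block, added in MPI-2.2, does the same job with a single recvcount instead of an array, assuming your MPI library is recent enough:

MPI_Reduce_scatter_block(sendbuf, &recvbuf, 1, MPI_INT, MPI_MAX, MPI_COMM_WORLD);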
answered by Jonathan Dursi