GADGET-4
myalltoall.cc
/*******************************************************************************
 * \copyright This file is part of the GADGET4 N-body/SPH code developed
 * \copyright by Volker Springel. Copyright (C) 2014-2020 by Volker Springel
 * \copyright (vspringel@mpa-garching.mpg.de) and all contributing authors.
 ******************************************************************************/

#include "gadgetconfig.h"

#include <math.h>
#include <mpi.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

#include "../data/allvars.h"
#include "../data/dtypes.h"
#include "../data/mymalloc.h"
#include "../mpi_utils/mpi_utils.h"

#define PCHAR(a) ((char *)a)

/* This function prepares an Alltoallv operation.
   sendcnt: must have as many entries as there are tasks in comm;
            must be set by the caller
   recvcnt: must have as many entries as there are tasks in comm;
            will be set on return
   rdispls: must have as many entries as there are tasks in comm, or be NULL;
            if not NULL, will be set on return
   method:  0 or 1 uses the standard Alltoall() approach, 10 the one-sided approach
   returns: the number of entries needed in the recvbuf */
int myMPI_Alltoallv_new_prep(int *sendcnt, int *recvcnt, int *rdispls, MPI_Comm comm, int method)
{
  int rank, nranks;
  MPI_Comm_size(comm, &nranks);
  MPI_Comm_rank(comm, &rank);

  if(method == 0 || method == 1)
    myMPI_Alltoall(sendcnt, 1, MPI_INT, recvcnt, 1, MPI_INT, comm);
  else if(method == 10)
    {
      for(int i = 0; i < nranks; ++i)
        recvcnt[i] = 0;
      recvcnt[rank] = sendcnt[rank];  // local communication
      MPI_Win win;
      MPI_Win_create(recvcnt, nranks * sizeof(int), sizeof(int), MPI_INFO_NULL, comm, &win);
      MPI_Win_fence(0, win);
      for(int i = 1; i < nranks; ++i)  // remote communication
        {
          int tgt = (rank + i) % nranks;
          if(sendcnt[tgt] != 0)
            MPI_Put(&sendcnt[tgt], 1, MPI_INT, tgt, rank, 1, MPI_INT, win);
        }
      MPI_Win_fence(0, win);
      MPI_Win_free(&win);
    }
  else
    Terminate("bad communication method");

  int total = 0;
  for(int i = 0; i < nranks; ++i)
    {
      if(rdispls)
        rdispls[i] = total;
      total += recvcnt[i];
    }
  return total;
}
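/* Usage sketch (illustrative only, not part of the original file; the MPI_DOUBLE
   payload, the buffer names and the helper function are assumptions): the prep
   call above fills recvcnt/rdispls and returns the total receive size, after
   which myMPI_Alltoallv_new() below carries out the actual exchange.

     void exchange_doubles(double *sendbuf, int *sendcnt, int *sdispls, MPI_Comm comm)
     {
       int nranks;
       MPI_Comm_size(comm, &nranks);

       int *recvcnt = (int *)Mem.mymalloc("recvcnt", nranks * sizeof(int));
       int *rdispls = (int *)Mem.mymalloc("rdispls", nranks * sizeof(int));

       // how many elements arrive from each task, and where they go in recvbuf
       int nrecv = myMPI_Alltoallv_new_prep(sendcnt, recvcnt, rdispls, comm, 0);

       double *recvbuf = (double *)Mem.mymalloc("recvbuf", nrecv * sizeof(double));
       myMPI_Alltoallv_new(sendbuf, sendcnt, sdispls, MPI_DOUBLE, recvbuf, recvcnt, rdispls, MPI_DOUBLE, comm, 0);

       // ... process recvbuf ...

       Mem.myfree(recvbuf);
       Mem.myfree(rdispls);
       Mem.myfree(recvcnt);
     }
 */

/* Performs an Alltoallv exchange (typically prepared with myMPI_Alltoallv_new_prep()
   above), selecting one of several communication strategies:
     method 0:  the library's MPI_Alltoallv()
     method 1:  pair-wise blocking myMPI_Sendrecv() in hypercube order
     method 2:  non-blocking MPI_Irecv()/MPI_Issend() followed by MPI_Waitall()
     method 10: one-sided MPI_Get() of the remote send-buffer segments
   sendcnt/sdispls and recvcnt/rdispls must each hold one entry per task in comm.
   If MPI_HYPERCUBE_ALLTOALL is defined, method 1 is always used. */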
void myMPI_Alltoallv_new(void *sendbuf, int *sendcnt, int *sdispls, MPI_Datatype sendtype, void *recvbuf, int *recvcnt, int *rdispls,
                         MPI_Datatype recvtype, MPI_Comm comm, int method)
{
  int rank, nranks, itsz;
  MPI_Comm_size(comm, &nranks);
  MPI_Comm_rank(comm, &rank);
  MPI_Type_size(sendtype, &itsz);
  size_t tsz = itsz;  // to enforce size_t data type in later computations

#ifdef MPI_HYPERCUBE_ALLTOALL
  method = 1;
#endif

  if(method == 0)  // standard Alltoallv
    MPI_Alltoallv(sendbuf, sendcnt, sdispls, sendtype, recvbuf, recvcnt, rdispls, recvtype, comm);
  else if(method == 1)  // blocking sendrecv
    {
      if(sendtype != recvtype)
        Terminate("bad MPI communication types");
      int lptask = 1;
      while(lptask < nranks)
        lptask <<= 1;
      int tag = 42;

      if(recvcnt[rank] > 0)  // local communication
        memcpy(PCHAR(recvbuf) + tsz * rdispls[rank], PCHAR(sendbuf) + tsz * sdispls[rank], tsz * recvcnt[rank]);

      for(int ngrp = 1; ngrp < lptask; ngrp++)
        {
          int otask = rank ^ ngrp;
          if(otask < nranks)
            if(sendcnt[otask] > 0 || recvcnt[otask] > 0)
              myMPI_Sendrecv(PCHAR(sendbuf) + tsz * sdispls[otask], sendcnt[otask], sendtype, otask, tag,
                             PCHAR(recvbuf) + tsz * rdispls[otask], recvcnt[otask], recvtype, otask, tag, comm, MPI_STATUS_IGNORE);
        }
    }
  else if(method == 2)  // asynchronous communication
    {
      if(sendtype != recvtype)
        Terminate("bad MPI communication types");
      int lptask = 1;
      while(lptask < nranks)
        lptask <<= 1;
      int tag = 42;

      MPI_Request *requests = (MPI_Request *)Mem.mymalloc("requests", 2 * nranks * sizeof(MPI_Request));
      int n_requests = 0;

      if(recvcnt[rank] > 0)  // local communication
        memcpy(PCHAR(recvbuf) + tsz * rdispls[rank], PCHAR(sendbuf) + tsz * sdispls[rank], tsz * recvcnt[rank]);

      for(int ngrp = 1; ngrp < lptask; ngrp++)
        {
          int otask = rank ^ ngrp;
          if(otask < nranks)
            if(recvcnt[otask] > 0)
              MPI_Irecv(PCHAR(recvbuf) + tsz * rdispls[otask], recvcnt[otask], recvtype, otask, tag, comm, &requests[n_requests++]);
        }

      for(int ngrp = 1; ngrp < lptask; ngrp++)
        {
          int otask = rank ^ ngrp;
          if(otask < nranks)
            if(sendcnt[otask] > 0)
              MPI_Issend(PCHAR(sendbuf) + tsz * sdispls[otask], sendcnt[otask], sendtype, otask, tag, comm, &requests[n_requests++]);
        }

      MPI_Waitall(n_requests, requests, MPI_STATUSES_IGNORE);
      Mem.myfree(requests);
    }
  else if(method == 10)
    {
      if(sendtype != recvtype)
        Terminate("bad MPI communication types");
      int *disp_at_sender = (int *)Mem.mymalloc("disp_at_sender", nranks * sizeof(int));
      disp_at_sender[rank] = sdispls[rank];
      MPI_Win win;
      // TODO: supply an info object with "no_locks"
      MPI_Win_create(sdispls, nranks * sizeof(int), sizeof(int), MPI_INFO_NULL, comm, &win);
      MPI_Win_fence(0, win);
      for(int i = 1; i < nranks; ++i)
        {
          int tgt = (rank + i) % nranks;
          if(recvcnt[tgt] != 0)
            MPI_Get(&disp_at_sender[tgt], 1, MPI_INT, tgt, rank, 1, MPI_INT, win);
        }
      MPI_Win_fence(0, win);
      MPI_Win_free(&win);
      if(recvcnt[rank] > 0)  // first take care of local communication
        memcpy(PCHAR(recvbuf) + tsz * rdispls[rank], PCHAR(sendbuf) + tsz * sdispls[rank], tsz * recvcnt[rank]);
      MPI_Win_create(sendbuf, (sdispls[nranks - 1] + sendcnt[nranks - 1]) * tsz, tsz, MPI_INFO_NULL, comm, &win);
      MPI_Win_fence(0, win);
      for(int i = 1; i < nranks; ++i)  // now the rest, start with right neighbour
        {
          int tgt = (rank + i) % nranks;
          if(recvcnt[tgt] != 0)
            MPI_Get(PCHAR(recvbuf) + tsz * rdispls[tgt], recvcnt[tgt], sendtype, tgt, disp_at_sender[tgt], recvcnt[tgt], sendtype,
                    win);
        }
      MPI_Win_fence(0, win);
      MPI_Win_free(&win);
      Mem.myfree(disp_at_sender);
    }
  else
    Terminate("bad communication method");
}
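/* All-to-all exchange with counts and displacements given as size_t arrays, in
   units of elements of 'len' bytes. For big_flag == 0 (and without
   MPI_HYPERCUBE_ALLTOALL) the counts are converted to int and passed to the
   library's MPI_Alltoallv(). Otherwise the exchange falls back to the pair-wise
   hypercube protocol based on myMPI_Sendrecv(), which is the appropriate choice
   when individual message sizes may exceed the range of an int byte count. */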
void myMPI_Alltoallv(void *sendb, size_t *sendcounts, size_t *sdispls, void *recvb, size_t *recvcounts, size_t *rdispls, int len,
                     int big_flag, MPI_Comm comm)
{
  char *sendbuf = (char *)sendb;
  char *recvbuf = (char *)recvb;

#ifndef MPI_HYPERCUBE_ALLTOALL
  if(big_flag == 0)
    {
      int ntask;
      MPI_Comm_size(comm, &ntask);

      int *scount = (int *)Mem.mymalloc("scount", ntask * sizeof(int));
      int *rcount = (int *)Mem.mymalloc("rcount", ntask * sizeof(int));
      int *soff = (int *)Mem.mymalloc("soff", ntask * sizeof(int));
      int *roff = (int *)Mem.mymalloc("roff", ntask * sizeof(int));

      for(int i = 0; i < ntask; i++)
        {
          scount[i] = sendcounts[i] * len;
          rcount[i] = recvcounts[i] * len;
          soff[i] = sdispls[i] * len;
          roff[i] = rdispls[i] * len;
        }

      MPI_Alltoallv(sendbuf, scount, soff, MPI_BYTE, recvbuf, rcount, roff, MPI_BYTE, comm);

      Mem.myfree(roff);
      Mem.myfree(soff);
      Mem.myfree(rcount);
      Mem.myfree(scount);
    }
  else
#endif
    {
      /* here we definitely have some large messages. We default to the
       * pair-wise protocol, which should be most robust anyway.
       */
      int ntask, thistask, ptask;
      MPI_Comm_size(comm, &ntask);
      MPI_Comm_rank(comm, &thistask);

      for(ptask = 0; ntask > (1 << ptask); ptask++)
        ;

      for(int ngrp = 0; ngrp < (1 << ptask); ngrp++)
        {
          int target = thistask ^ ngrp;

          if(target < ntask)
            {
              if(sendcounts[target] > 0 || recvcounts[target] > 0)
                myMPI_Sendrecv(sendbuf + sdispls[target] * len, sendcounts[target] * len, MPI_BYTE, target, TAG_PDATA + ngrp,
                               recvbuf + rdispls[target] * len, recvcounts[target] * len, MPI_BYTE, target, TAG_PDATA + ngrp, comm,
                               MPI_STATUS_IGNORE);
            }
        }
    }
}
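/* Same as myMPI_Alltoallv() above, except that the counts and displacements are
   given as int arrays instead of size_t arrays. */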
void my_int_MPI_Alltoallv(void *sendb, int *sendcounts, int *sdispls, void *recvb, int *recvcounts, int *rdispls, int len,
                          int big_flag, MPI_Comm comm)
{
  char *sendbuf = (char *)sendb;
  char *recvbuf = (char *)recvb;

#ifndef MPI_HYPERCUBE_ALLTOALL
  if(big_flag == 0)
    {
      int ntask;
      MPI_Comm_size(comm, &ntask);

      int *scount = (int *)Mem.mymalloc("scount", ntask * sizeof(int));
      int *rcount = (int *)Mem.mymalloc("rcount", ntask * sizeof(int));
      int *soff = (int *)Mem.mymalloc("soff", ntask * sizeof(int));
      int *roff = (int *)Mem.mymalloc("roff", ntask * sizeof(int));

      for(int i = 0; i < ntask; i++)
        {
          scount[i] = sendcounts[i] * len;
          rcount[i] = recvcounts[i] * len;
          soff[i] = sdispls[i] * len;
          roff[i] = rdispls[i] * len;
        }

      MPI_Alltoallv(sendbuf, scount, soff, MPI_BYTE, recvbuf, rcount, roff, MPI_BYTE, comm);

      Mem.myfree(roff);
      Mem.myfree(soff);
      Mem.myfree(rcount);
      Mem.myfree(scount);
    }
  else
#endif
    {
      /* here we definitely have some large messages. We default to the
       * pair-wise protocol, which should be most robust anyway.
       */
      int ntask, thistask, ptask;
      MPI_Comm_size(comm, &ntask);
      MPI_Comm_rank(comm, &thistask);

      for(ptask = 0; ntask > (1 << ptask); ptask++)
        ;

      for(int ngrp = 0; ngrp < (1 << ptask); ngrp++)
        {
          int target = thistask ^ ngrp;

          if(target < ntask)
            {
              if(sendcounts[target] > 0 || recvcounts[target] > 0)
                myMPI_Sendrecv(sendbuf + sdispls[target] * len, sendcounts[target] * len, MPI_BYTE, target, TAG_PDATA + ngrp,
                               recvbuf + rdispls[target] * len, recvcounts[target] * len, MPI_BYTE, target, TAG_PDATA + ngrp, comm,
                               MPI_STATUS_IGNORE);
            }
        }
    }
}
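/* Drop-in replacement for MPI_Alltoall(). Without MPI_HYPERCUBE_ALLTOALL it
   simply forwards to the library call. With MPI_HYPERCUBE_ALLTOALL defined, the
   exchange is carried out as pair-wise myMPI_Sendrecv() operations in hypercube
   order: in round ngrp every task exchanges with the partner thistask ^ ngrp, so
   all tasks are paired up without deadlock, and the task's own block is copied
   locally with memcpy() at the end. */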
int myMPI_Alltoall(const void *sendbuf, int sendcount, MPI_Datatype sendtype, void *recvbuf, int recvcount, MPI_Datatype recvtype,
                   MPI_Comm comm)
{
#ifndef MPI_HYPERCUBE_ALLTOALL
  return MPI_Alltoall(sendbuf, sendcount, sendtype, recvbuf, recvcount, recvtype, comm);
#else
  int ntask, ptask, thistask, size_sendtype, size_recvtype;

  MPI_Comm_rank(comm, &thistask);
  MPI_Comm_size(comm, &ntask);

  MPI_Type_size(sendtype, &size_sendtype);
  MPI_Type_size(recvtype, &size_recvtype);

  for(ptask = 0; ntask > (1 << ptask); ptask++)
    ;

  for(int ngrp = 1; ngrp < (1 << ptask); ngrp++)
    {
      int recvtask = thistask ^ ngrp;

      if(recvtask < ntask)
        myMPI_Sendrecv((char *)sendbuf + recvtask * sendcount * size_sendtype, sendcount, sendtype, recvtask, TAG_PDATA + ngrp,
                       (char *)recvbuf + recvtask * recvcount * size_recvtype, recvcount, recvtype, recvtask, TAG_PDATA + ngrp, comm,
                       MPI_STATUS_IGNORE);
    }

  memcpy((char *)recvbuf + thistask * recvcount * size_recvtype, (char *)sendbuf + thistask * sendcount * size_sendtype,
         sendcount * size_sendtype);

  return 0;

#endif
}