12#include "gadgetconfig.h"
20#include "../data/allvars.h"
21#include "../data/dtypes.h"
22#include "../data/mymalloc.h"
23#include "../mpi_utils/mpi_utils.h"
/* Reinterpret a pointer as (char *) so that byte offsets can be added to it.
   The argument is parenthesized so that an expression argument such as
   PCHAR(p + 1) is evaluated before the cast rather than after it. */
#define PCHAR(a) ((char *)(a))
/* Fragment of myMPI_Alltoallv_new_prep(sendcnt, recvcnt, rdispls, comm, method).
   Several original lines of this function are missing from this excerpt; the
   comments below describe only the statements that are visible. */
39 MPI_Comm_size(comm, &nranks);
40 MPI_Comm_rank(comm, &rank);
/* NOTE(review): the body of this branch (methods 0 and 1) is not visible here. */
42 if(method == 0 || method == 1)
/* Loop over all ranks; its body is not visible in this excerpt. */
46 for(
int i = 0; i < nranks; ++i)
/* The local entry is taken directly from sendcnt, without communication. */
48 recvcnt[rank] = sendcnt[rank];
/* Expose recvcnt as an RMA window so that every other rank can write the
   count it will send to this rank into slot [sender rank].
   NOTE(review): sizeof(MPI_INT) is the size of the MPI_Datatype handle, not of
   a C int. Where MPI_Datatype is a pointer type (e.g. Open MPI) this is 8 bytes,
   so the window claims twice the bytes of an int array of nranks entries and the
   displacement unit is 8: the MPI_Put below at displacement `rank` would land at
   byte offset 8*rank. sizeof(int) appears to be intended - confirm and fix. */
50 MPI_Win_create(recvcnt, nranks *
sizeof(MPI_INT),
sizeof(MPI_INT), MPI_INFO_NULL, comm, &win);
/* Open the RMA access epoch. */
51 MPI_Win_fence(0, win);
/* Visit every other rank in the order rank+1, rank+2, ... (mod nranks) and put
   sendcnt[tgt] into slot `rank` of tgt's recvcnt window.
   NOTE(review): an original line between the tgt computation and the put is
   missing here; it may guard the put (e.g. skip zero counts) - confirm. */
52 for(
int i = 1; i < nranks; ++i)
54 int tgt = (rank + i) % nranks;
56 MPI_Put(&sendcnt[tgt], 1, MPI_INT, tgt, rank, 1, MPI_INT, win);
/* Close the epoch; after this fence the puts targeting this rank are complete.
   NOTE(review): no MPI_Win_free for this window is visible in this excerpt. */
58 MPI_Win_fence(0, win);
/* Final loop over all ranks; its body is not visible here (presumably it fills
   rdispls from recvcnt - TODO confirm). */
65 for(
int i = 0; i < nranks; ++i)
/* All-to-all-v exchange with a selectable implementation (`method`).
   Visible variants:
   - a direct call to MPI_Alltoallv;
   - a pairwise exchange with partner rank ^ ngrp via myMPI_Sendrecv;
   - a non-blocking variant using MPI_Irecv / MPI_Issend;
   - method == 10: one-sided MPI_Get from windows exposed by the senders.
   The `if` conditions selecting the first variants are not visible in this
   excerpt. Counts and displacements are in elements of the datatype; they are
   scaled to bytes by tsz for pointer arithmetic. tsz is declared on a line not
   visible here and presumably holds itsz - TODO confirm. */
74void myMPI_Alltoallv_new(
void *sendbuf,
int *sendcnt,
int *sdispls, MPI_Datatype sendtype,
void *recvbuf,
int *recvcnt,
int *rdispls,
75 MPI_Datatype recvtype, MPI_Comm comm,
int method)
77 int rank, nranks, itsz;
78 MPI_Comm_size(comm, &nranks);
79 MPI_Comm_rank(comm, &rank);
/* Size in bytes of one element of sendtype. */
80 MPI_Type_size(sendtype, &itsz);
/* NOTE(review): the lines controlled by this #ifdef are not visible in this excerpt. */
83#ifdef MPI_HYPERCUBE_ALLTOALL
/* Variant: delegate directly to the MPI library. */
88 MPI_Alltoallv(sendbuf, sendcnt, sdispls, sendtype, recvbuf, recvcnt, rdispls, recvtype, comm);
/* Pairwise-exchange variant. NOTE(review): the statement executed when the types
   differ is not visible here (the later variants call Terminate). */
91 if(sendtype != recvtype)
/* lptask is presumably grown to the smallest power of two >= nranks; the loop
   body is not visible - TODO confirm. */
94 while(lptask < nranks)
/* Local block: copied directly. The byte count uses recvcnt[rank], which assumes
   sendcnt[rank] == recvcnt[rank]. */
99 memcpy(
PCHAR(recvbuf) + tsz * rdispls[rank],
PCHAR(sendbuf) + tsz * sdispls[rank], tsz * recvcnt[rank]);
/* For each ngrp, rank <-> rank ^ ngrp is a perfect pairing, so every rank talks
   to exactly one partner per round.
   NOTE(review): when nranks is not a power of two, otask can be >= nranks. Any
   guard for that is on a missing line (between otask and the test below);
   without it sendcnt[otask] is read out of bounds. */
101 for(
int ngrp = 1; ngrp < lptask; ngrp++)
103 int otask = rank ^ ngrp;
/* Exchange only if data flows in at least one direction. `tag` is declared on
   a line not visible here. */
105 if(sendcnt[otask] > 0 || recvcnt[otask] > 0)
106 myMPI_Sendrecv(
PCHAR(sendbuf) + tsz * sdispls[otask], sendcnt[otask], sendtype, otask, tag,
107 PCHAR(recvbuf) + tsz * rdispls[otask], recvcnt[otask], recvtype, otask, tag, comm, MPI_STATUS_IGNORE);
/* Non-blocking variant: requires identical send and receive types. */
112 if(sendtype != recvtype)
113 Terminate(
"bad MPI communication types");
115 while(lptask < nranks)
/* Room for at most one receive and one send request per rank. */
119 MPI_Request *requests = (MPI_Request *)
Mem.mymalloc(
"requests", 2 * nranks *
sizeof(MPI_Request));
/* Local block: copied directly, only when non-empty. */
122 if(recvcnt[rank] > 0)
123 memcpy(
PCHAR(recvbuf) + tsz * rdispls[rank],
PCHAR(sendbuf) + tsz * sdispls[rank], tsz * recvcnt[rank]);
/* Post all receives before any send. n_requests is declared on a line not
   visible here. NOTE(review): the same otask >= nranks concern as above applies;
   a guard, if any, is on a missing line. */
125 for(
int ngrp = 1; ngrp < lptask; ngrp++)
127 int otask = rank ^ ngrp;
129 if(recvcnt[otask] > 0)
130 MPI_Irecv(
PCHAR(recvbuf) + tsz * rdispls[otask], recvcnt[otask], recvtype, otask, tag, comm, &requests[n_requests++]);
/* Then start synchronous-mode sends: an MPI_Issend completes only once the
   matching receive has started. */
133 for(
int ngrp = 1; ngrp < lptask; ngrp++)
135 int otask = rank ^ ngrp;
137 if(sendcnt[otask] > 0)
138 MPI_Issend(
PCHAR(sendbuf) + tsz * sdispls[otask], sendcnt[otask], sendtype, otask, tag, comm, &requests[n_requests++]);
/* Wait for all receives and sends, then release the request array. */
141 MPI_Waitall(n_requests, requests, MPI_STATUSES_IGNORE);
142 Mem.myfree(requests);
/* One-sided variant: each rank pulls its data from the senders' buffers. */
144 else if(method == 10)
146 if(sendtype != recvtype)
147 Terminate(
"bad MPI communication types");
/* disp_at_sender[tgt]: displacement, inside tgt's send buffer, of the block that
   tgt holds for this rank (i.e. tgt's sdispls[rank]). */
148 int *disp_at_sender = (
int *)
Mem.mymalloc(
"disp_at_sender", nranks *
sizeof(
int));
149 disp_at_sender[rank] = sdispls[rank];
/* Expose sdispls so receivers can read where their block starts.
   NOTE(review): sizeof(MPI_INT) is the size of the MPI_Datatype handle, not of
   int. Where the handle is a pointer (e.g. Open MPI) the window size and
   displacement unit are 8 bytes per entry, so the MPI_Get below at displacement
   `rank` reads the wrong bytes of this int array. sizeof(int) appears to be
   intended - confirm and fix. */
152 MPI_Win_create(sdispls, nranks *
sizeof(MPI_INT),
sizeof(MPI_INT), MPI_INFO_NULL, comm, &win);
153 MPI_Win_fence(0, win);
/* Fetch sdispls[rank] from every other rank that sends this rank anything. */
154 for(
int i = 1; i < nranks; ++i)
156 int tgt = (rank + i) % nranks;
157 if(recvcnt[tgt] != 0)
158 MPI_Get(&disp_at_sender[tgt], 1, MPI_INT, tgt, rank, 1, MPI_INT, win);
160 MPI_Win_fence(0, win);
/* Local block: copied directly. */
162 if(recvcnt[rank] > 0)
163 memcpy(
PCHAR(recvbuf) + tsz * rdispls[rank],
PCHAR(sendbuf) + tsz * sdispls[rank], tsz * recvcnt[rank]);
/* Expose the send buffer with a displacement unit of one element.
   NOTE(review): the window size assumes the block for rank nranks-1 is the last
   one in sendbuf (displacements increasing with rank). If not, the window is too
   small and remote reads can fall outside it.
   NOTE(review): `win` is reassigned here. An MPI_Win_free for the previous
   window, and for this one after the final fence, is not visible in this
   excerpt - confirm both exist. */
164 MPI_Win_create(sendbuf, (sdispls[nranks - 1] + sendcnt[nranks - 1]) * tsz, tsz, MPI_INFO_NULL, comm, &win);
165 MPI_Win_fence(0, win);
/* Pull recvcnt[tgt] elements from tgt's send buffer, starting at
   disp_at_sender[tgt], into this rank's slot for tgt. The remaining arguments of
   the MPI_Get call are on lines not visible here. */
166 for(
int i = 1; i < nranks; ++i)
168 int tgt = (rank + i) % nranks;
169 if(recvcnt[tgt] != 0)
170 MPI_Get(
PCHAR(recvbuf) + tsz * rdispls[tgt], recvcnt[tgt], sendtype, tgt, disp_at_sender[tgt], recvcnt[tgt], sendtype,
173 MPI_Win_fence(0, win);
175 Mem.myfree(disp_at_sender);
/* All-to-all-v exchange where counts and displacements are given as size_t
   element counts of `len` bytes each; all data is moved as MPI_BYTE.
   Without MPI_HYPERCUBE_ALLTOALL the counts and displacements are scaled to
   bytes and passed to MPI_Alltoallv. The later lines (partly missing from this
   excerpt, presumably the #else branch) exchange pairwise with thistask ^ ngrp.
   big_flag is not referenced in the visible lines. */
181void myMPI_Alltoallv(
void *sendb,
size_t *sendcounts,
size_t *sdispls,
void *recvb,
size_t *recvcounts,
size_t *rdispls,
int len,
182 int big_flag, MPI_Comm comm)
/* Byte-addressable views of the buffers. */
184 char *sendbuf = (
char *)sendb;
185 char *recvbuf = (
char *)recvb;
187#ifndef MPI_HYPERCUBE_ALLTOALL
/* ntask is declared on a line not visible here. */
191 MPI_Comm_size(comm, &ntask);
/* int-typed byte counts and offsets, as required by MPI_Alltoallv.
   NOTE(review): matching Mem.myfree calls are not visible in this excerpt. */
193 int *scount = (
int *)
Mem.mymalloc(
"scount", ntask *
sizeof(
int));
194 int *rcount = (
int *)
Mem.mymalloc(
"rcount", ntask *
sizeof(
int));
195 int *soff = (
int *)
Mem.mymalloc(
"soff", ntask *
sizeof(
int));
196 int *roff = (
int *)
Mem.mymalloc(
"roff", ntask *
sizeof(
int));
/* Convert element counts/displacements to bytes.
   NOTE(review): each product is computed in size_t and then stored into an int.
   Byte counts or offsets above INT_MAX do not fit and are silently narrowed. */
198 for(
int i = 0; i < ntask; i++)
200 scount[i] = sendcounts[i] * len;
201 rcount[i] = recvcounts[i] * len;
202 soff[i] = sdispls[i] * len;
203 roff[i] = rdispls[i] * len;
206 MPI_Alltoallv(sendbuf, scount, soff, MPI_BYTE, recvbuf, rcount, roff, MPI_BYTE, comm);
219 int ntask, thistask, ptask;
220 MPI_Comm_size(comm, &ntask);
221 MPI_Comm_rank(comm, &thistask);
/* Smallest ptask with (1 << ptask) >= ntask (loop body presumably empty - not visible). */
223 for(ptask = 0; ntask > (1 << ptask); ptask++)
/* ngrp starts at 0, so the first round pairs the rank with itself and the local
   block also goes through the exchange call.
   NOTE(review): for non-power-of-two ntask, target can be >= ntask. A guard, if
   any, is on a missing line; without it sendcounts[target] is read out of bounds.
   The name of the exchange call and its send arguments are on lines not visible
   here; each round uses tag TAG_PDATA + ngrp. */
226 for(
int ngrp = 0; ngrp < (1 << ptask); ngrp++)
228 int target = thistask ^ ngrp;
232 if(sendcounts[target] > 0 || recvcounts[target] > 0)
234 recvbuf + rdispls[target] * len, recvcounts[target] * len, MPI_BYTE, target,
TAG_PDATA + ngrp, comm,
/* Same exchange as the size_t variant, but with int element counts and
   displacements; each element is `len` bytes and data moves as MPI_BYTE.
   Without MPI_HYPERCUBE_ALLTOALL the values are scaled to bytes and passed to
   MPI_Alltoallv. The later lines (partly missing from this excerpt) exchange
   pairwise with thistask ^ ngrp. big_flag is not referenced in the visible lines. */
241void my_int_MPI_Alltoallv(
void *sendb,
int *sendcounts,
int *sdispls,
void *recvb,
int *recvcounts,
int *rdispls,
int len,
242 int big_flag, MPI_Comm comm)
/* Byte-addressable views of the buffers. */
244 char *sendbuf = (
char *)sendb;
245 char *recvbuf = (
char *)recvb;
247#ifndef MPI_HYPERCUBE_ALLTOALL
/* ntask is declared on a line not visible here. */
251 MPI_Comm_size(comm, &ntask);
/* int byte counts and offsets for MPI_Alltoallv.
   NOTE(review): matching Mem.myfree calls are not visible in this excerpt. */
253 int *scount = (
int *)
Mem.mymalloc(
"scount", ntask *
sizeof(
int));
254 int *rcount = (
int *)
Mem.mymalloc(
"rcount", ntask *
sizeof(
int));
255 int *soff = (
int *)
Mem.mymalloc(
"soff", ntask *
sizeof(
int));
256 int *roff = (
int *)
Mem.mymalloc(
"roff", ntask *
sizeof(
int));
/* Convert element counts/displacements to bytes.
   NOTE(review): these are int * int products. If a byte count or offset exceeds
   INT_MAX the multiplication overflows, which is undefined behavior. */
258 for(
int i = 0; i < ntask; i++)
260 scount[i] = sendcounts[i] * len;
261 rcount[i] = recvcounts[i] * len;
262 soff[i] = sdispls[i] * len;
263 roff[i] = rdispls[i] * len;
266 MPI_Alltoallv(sendbuf, scount, soff, MPI_BYTE, recvbuf, rcount, roff, MPI_BYTE, comm);
279 int ntask, thistask, ptask;
280 MPI_Comm_size(comm, &ntask);
281 MPI_Comm_rank(comm, &thistask);
/* Smallest ptask with (1 << ptask) >= ntask (loop body presumably empty - not visible). */
283 for(ptask = 0; ntask > (1 << ptask); ptask++)
/* ngrp starts at 0, so the rank is also paired with itself.
   NOTE(review): for non-power-of-two ntask, target can be >= ntask. A guard, if
   any, is on a missing line. The exchange call name and its send arguments are
   not visible; each round uses tag TAG_PDATA + ngrp. */
286 for(
int ngrp = 0; ngrp < (1 << ptask); ngrp++)
288 int target = thistask ^ ngrp;
292 if(sendcounts[target] > 0 || recvcounts[target] > 0)
294 recvbuf + rdispls[target] * len, recvcounts[target] * len, MPI_BYTE, target,
TAG_PDATA + ngrp, comm,
/* Fixed-size all-to-all exchange. Without MPI_HYPERCUBE_ALLTOALL it forwards to
   MPI_Alltoall. Otherwise it exchanges pairwise with thistask ^ ngrp via
   myMPI_Sendrecv and copies the local block with memcpy.
   The trailing `MPI_Comm comm` parameter and the end of the function are on
   lines not visible in this excerpt. */
301int myMPI_Alltoall(
const void *sendbuf,
int sendcount, MPI_Datatype sendtype,
void *recvbuf,
int recvcount, MPI_Datatype recvtype,
305#ifndef MPI_HYPERCUBE_ALLTOALL
306 return MPI_Alltoall(sendbuf, sendcount, sendtype, recvbuf, recvcount, recvtype, comm);
308 int ntask, ptask, thistask, size_sendtype, size_recvtype;
310 MPI_Comm_rank(comm, &thistask);
311 MPI_Comm_size(comm, &ntask);
/* Element sizes in bytes, used to address each rank's block. */
313 MPI_Type_size(sendtype, &size_sendtype);
314 MPI_Type_size(recvtype, &size_recvtype);
/* Smallest ptask with (1 << ptask) >= ntask (loop body presumably empty - not visible). */
316 for(ptask = 0; ntask > (1 << ptask); ptask++)
/* Exchange with partner thistask ^ ngrp, ngrp = 1 .. 2^ptask - 1, tag
   TAG_PDATA + ngrp.
   NOTE(review): for non-power-of-two ntask, recvtask can be >= ntask. A guard, if
   any, is on a missing line.
   NOTE(review): block offsets such as recvtask * sendcount * size_sendtype are
   int products and overflow once a buffer exceeds INT_MAX bytes.
   NOTE(review): (char *)sendbuf casts away const from the `const void *`
   parameter; this relies on myMPI_Sendrecv not writing to its send buffer. */
319 for(
int ngrp = 1; ngrp < (1 << ptask); ngrp++)
321 int recvtask = thistask ^ ngrp;
324 myMPI_Sendrecv((
char *)sendbuf + recvtask * sendcount * size_sendtype, sendcount, sendtype, recvtask,
TAG_PDATA + ngrp,
325 (
char *)recvbuf + recvtask * recvcount * size_recvtype, recvcount, recvtype, recvtask,
TAG_PDATA + ngrp, comm,
/* Local block: copy this rank's own block directly. The byte count is taken from
   the send side, which assumes sendcount * size_sendtype equals
   recvcount * size_recvtype. */
329 memcpy((
char *)recvbuf + thistask * recvcount * size_recvtype, (
char *)sendbuf + thistask * sendcount * size_sendtype,
330 sendcount * size_sendtype);
/* Signatures of the routines above, plus myMPI_Sendrecv, which is called above
   but whose definition is not in this excerpt.
   NOTE(review): these lines have no terminating semicolons and no bodies, so as
   written they are neither valid declarations nor definitions. Confirm whether
   they are meant to be prototypes. */
int myMPI_Sendrecv(void *sendbuf, size_t sendcount, MPI_Datatype sendtype, int dest, int sendtag, void *recvbuf, size_t recvcount, MPI_Datatype recvtype, int source, int recvtag, MPI_Comm comm, MPI_Status *status)
void my_int_MPI_Alltoallv(void *sendb, int *sendcounts, int *sdispls, void *recvb, int *recvcounts, int *rdispls, int len, int big_flag, MPI_Comm comm)
int myMPI_Alltoall(const void *sendbuf, int sendcount, MPI_Datatype sendtype, void *recvbuf, int recvcount, MPI_Datatype recvtype, MPI_Comm comm)
void myMPI_Alltoallv(void *sendb, size_t *sendcounts, size_t *sdispls, void *recvb, size_t *recvcounts, size_t *rdispls, int len, int big_flag, MPI_Comm comm)
int myMPI_Alltoallv_new_prep(int *sendcnt, int *recvcnt, int *rdispls, MPI_Comm comm, int method)
void myMPI_Alltoallv_new(void *sendbuf, int *sendcnt, int *sdispls, MPI_Datatype sendtype, void *recvbuf, int *recvcnt, int *rdispls, MPI_Datatype recvtype, MPI_Comm comm, int method)