12#ifndef PARALLEL_SORT_H
13#define PARALLEL_SORT_H
17#include "../data/mymalloc.h"
// Comparator over indices: two positions a and b are ordered by applying the
// user-supplied comparator to the elements found at begin + a and begin + b.
// NOTE(review): the enclosing struct/class header and the member declarations
// (begin, comp) are not visible in this excerpt.
21template <
typename It,
typename Comp>
// Stores the start iterator of the sequence and the element comparator.
29 IdxComp__(It begin_, Comp comp_) : begin(begin_), comp(comp_) {}
// Returns comp(begin[a], begin[b]); does not modify the comparator object.
30 bool operator()(std::size_t a, std::size_t b)
const {
return comp(*(begin + a), *(begin + b)); }
// buildIndex: takes an iterator range [begin, end), an output index array idx
// of element type T2, and a comparator comp.
// NOTE(review): only the signature and one loop header are visible here; the
// definition of num and the loop body are missing from this excerpt.
// Presumably idx is filled with 0..num-1 and then ordered via IdxComp__ --
// confirm against the full file.
36template <
typename It,
typename T2,
typename Comp>
37inline void buildIndex(It begin, It end, T2 *idx, Comp comp)
// Iterates over num index slots using the index type T2.
41 for(T2 i = 0; i < num; ++i)
// get_local_rank: searches the local array base[0..nmemb) for the position of
// `element` within the window delimited by left/right and writes the result
// through *loc.
//
//   element            value whose local rank is sought
//   tie_breaking_rank  global index compared against (mid + noffs_thistask)
//   base, nmemb        local array and its length
//   noffs_thistask     offset added to a local index before the tie-break comparison
//   left, right        current search window (long long)
//   loc                output: receives the computed position
//   comp               strict ordering predicate; comp(a, b) is used in both
//                      directions to derive a three-way result
//
// NOTE(review): several lines of the body (braces, assignments to *loc, the
// recursion/iteration that narrows the window) are not visible in this excerpt.
46template <
typename T,
typename Comp>
47void get_local_rank(
const T &element, std::size_t tie_breaking_rank,
const T *base,
size_t nmemb,
size_t noffs_thistask,
48 long long left,
long long right,
size_t *loc, Comp comp)
// Special case: the window still covers the whole local array.
// NOTE(review): nmemb is cast to int while left/right are long long; for nmemb
// above INT_MAX this would truncate -- confirm whether such sizes can occur.
53 if(left == 0 && right == (
int)nmemb + 1)
// Last local element orders before `element`.
55 if(comp(base[nmemb - 1], element))
// `element` orders before the first local element.
60 else if(comp(element, base[0]))
// Element at the upper end of the window orders before `element`:
// the result is the position just past (right - 1).
73 if(comp(base[right - 1], element))
74 *loc = (right - 1) + 1;
// `element` orders before the element at the lower end of the window.
75 else if(comp(element, base[left]))
// Midpoint of the inclusive window [left, right - 1].
81 long long mid = ((right - 1) + left) / 2;
// Three-way comparison built from comp: -1 if base[mid] orders before element,
// +1 if element orders before base[mid], 0 if neither.
83 int cmp = comp(base[mid], element) ? -1 : (comp(element, base[mid]) ? +1 : 0);
// Tie-break on the global index of mid (local index plus this task's offset);
// presumably applied only when cmp == 0 -- the guarding condition is not visible.
86 if(mid + noffs_thistask < tie_breaking_rank)
88 else if(mid + noffs_thistask > tie_breaking_rank)
// Window has shrunk to a single slot.
98 if((right - 1) == left)
// Window has shrunk to two adjacent slots.
113 if((right - 1) == left + 1)
// Reached only in a state the visible message describes as impossible.
116 Terminate(
"Can't be: -->left=%lld right=%lld\n", left, right);
// check_local_rank: debugging cross-check, compiled only when CHECK_LOCAL_RANK
// is defined. It walks all nmemb local elements, compares each with `element`
// (with an index-based tie-break), and terminates if the resulting count
// disagrees with the previously computed position `loc`.
// NOTE(review): the base/nmemb parameters, the declaration and updates of
// `count`, and the matching #endif are not visible in this excerpt.
129#ifdef CHECK_LOCAL_RANK
130template <
typename T,
typename Comp>
131inline void check_local_rank(
const T &element,
132 size_t tie_breaking_rank,
135 size_t noffs_thistask,
136 long long left,
long long right,
137 size_t loc, Comp comp)
// Linear scan over every local element.
141 for(
size_t i = 0; i < nmemb; i++)
// Three-way comparison built from comp: -1 / +1 / 0.
143 int cmp = comp(base[i], element) ? -1 : (comp(element, base[i]) ? +1 : 0);
// Tie-break on the global index (task offset plus local index).
147 if(noffs_thistask + i < tie_breaking_rank)
// The scan result must equal the position passed in as loc.
155 if(count != (
long long)loc)
156 Terminate(
"Inconsistency: loc=%lld count=%lld left=%lld right=%lld nmemb=%lld\n", (
long long)loc, count, left, right,
// Body of the distributed sort over the tasks of an MPI communicator.
// Visible stages: (1) split off a local communicator, (2) iteratively search
// for Local_NTask - 1 split elements whose global rank matches the desired
// boundaries, (3) redistribute the elements between tasks according to those
// split elements, (4) release the work buffers and the communicator.
// NOTE(review): the function signature line, the local sort, the assignment of
// Color, loop bodies and braces are not visible in this excerpt; comments below
// describe only what the visible lines establish.
161template <
typename T,
typename Comp>
// Upper bound on split-point search iterations (checked near the loop end).
164 const int MAX_ITER_PARALLEL_SORT = 500;
165 int ranks_not_found, Local_ThisTask, Local_NTask, Color, new_max_loc;
166 size_t tie_breaking_rank, new_tie_breaking_rank, rank;
167 MPI_Comm MPI_CommLocal;
// Number of local elements and byte size of one element.
170 size_t nmemb = end - begin;
171 size_t size =
sizeof(T);
// Build a sub-communicator from `comm`, grouped by Color and ordered by the
// rank in `comm`; then query rank and size inside it.
184 MPI_Comm_rank(comm, &thistask);
186 MPI_Comm_split(comm, Color, thistask, &MPI_CommLocal);
187 MPI_Comm_rank(MPI_CommLocal, &Local_ThisTask);
188 MPI_Comm_size(MPI_CommLocal, &Local_NTask);
// The parallel stage runs only with more than one task in the Color == 1 group.
190 if(Local_NTask > 1 && Color == 1)
// nlist: element count of every task; noffs: running offsets derived from it.
192 size_t *nlist = (
size_t *)
Mem.mymalloc(
"nlist", Local_NTask *
sizeof(
size_t));
193 size_t *noffs = (
size_t *)
Mem.mymalloc(
"noffs", Local_NTask *
sizeof(
size_t));
195 MPI_Allgather(&nmemb,
sizeof(
size_t), MPI_BYTE, nlist,
sizeof(
size_t), MPI_BYTE, MPI_CommLocal);
// Prefix sum: noffs[i] is the sum of the counts of tasks 0..i-1
// (initialisation of noffs[0] is not visible here).
198 for(
int i = 1; i < Local_NTask; i++)
199 noffs[i] = noffs[i - 1] + nlist[i - 1];
// Per-split-point state, each array holding Local_NTask entries.
201 T *element_guess = (T *)
Mem.mymalloc(
"element_guess", Local_NTask * size);
202 size_t *element_tie_breaking_rank = (
size_t *)
Mem.mymalloc(
"element_tie_breaking_rank", Local_NTask *
sizeof(
size_t));
203 size_t *desired_glob_rank = (
size_t *)
Mem.mymalloc(
"desired_glob_rank", Local_NTask *
sizeof(
size_t));
204 size_t *current_glob_rank = (
size_t *)
Mem.mymalloc(
"current_glob_rank", Local_NTask *
sizeof(
size_t));
205 size_t *current_loc_rank = (
size_t *)
Mem.mymalloc(
"current_loc_rank", Local_NTask *
sizeof(
size_t));
206 long long *range_left = (
long long *)
Mem.mymalloc(
"range_left", Local_NTask *
sizeof(
long long));
207 long long *range_right = (
long long *)
Mem.mymalloc(
"range_right", Local_NTask *
sizeof(
long long));
208 int *max_loc = (
int *)
Mem.mymalloc(
"max_loc", Local_NTask *
sizeof(
int));
// Scratch lists used for the collective exchanges below.
210 size_t *list = (
size_t *)
Mem.mymalloc(
"list", Local_NTask *
sizeof(
size_t));
// NOTE(review): declared as size_t* but sized with sizeof(long long); this is
// only equivalent where both types have the same size -- confirm.
211 size_t *range_len_list = (
size_t *)
Mem.mymalloc(
"range_len_list", Local_NTask *
sizeof(
long long));
213 T *median_element_list = (T *)
Mem.mymalloc(
"median_element_list", Local_NTask * size);
214 size_t *tie_breaking_rank_list = (
size_t *)
Mem.mymalloc(
"tie_breaking_rank_list", Local_NTask *
sizeof(
size_t));
215 int *index_list = (
int *)
Mem.mymalloc(
"index_list", Local_NTask *
sizeof(
int));
216 int *max_loc_list = (
int *)
Mem.mymalloc(
"max_loc_list", Local_NTask *
sizeof(
int));
// NOTE(review): same size_t* / sizeof(long long) mismatch as range_len_list.
217 size_t *source_range_len_list = (
size_t *)
Mem.mymalloc(
"source_range_len_list", Local_NTask *
sizeof(
long long));
218 size_t *source_tie_breaking_rank_list = (
size_t *)
Mem.mymalloc(
"source_tie_breaking_rank_list", Local_NTask *
sizeof(
long long));
219 T *source_median_element_list = (T *)
Mem.mymalloc(
"source_median_element_list", Local_NTask * size);
// Initialise the Local_NTask - 1 split points: split i should end up at the
// global rank where task i + 1 starts; the search window initially extends
// to nmemb on the right.
222 for(
int i = 0; i < Local_NTask - 1; i++)
224 desired_glob_rank[i] = noffs[i + 1];
225 current_glob_rank[i] = 0;
227 range_right[i] = nmemb;
// First guess: each task proposes the middle element of its window together
// with that element's global index as tie-breaker.
235 long long range_len = range_right[0] - range_left[0];
239 long long mid = (range_left[0] + range_right[0]) / 2;
240 median_element = begin[mid];
241 tie_breaking_rank = mid + noffs[Local_ThisTask];
// Collect all proposals on local task 0.
// NOTE(review): range_len is a long long gathered into a size_t* buffer.
244 MPI_Gather(&range_len,
sizeof(
long long), MPI_BYTE, range_len_list,
sizeof(
long long), MPI_BYTE, 0, MPI_CommLocal);
245 MPI_Gather(&median_element, size, MPI_BYTE, median_element_list, size, MPI_BYTE, 0, MPI_CommLocal);
246 MPI_Gather(&tie_breaking_rank,
sizeof(
size_t), MPI_BYTE, tie_breaking_rank_list,
sizeof(
size_t), MPI_BYTE, 0, MPI_CommLocal);
// Local task 0 selects the initial guess from the gathered proposals.
248 if(Local_ThisTask == 0)
250 for(
int j = 0; j < Local_NTask; j++)
254 int nleft = Local_NTask;
// Compact the lists: an entry with range length < 1 is overwritten by the
// current last entry (the adjustment of nleft/j is not visible here).
256 for(
int j = 0; j < nleft; j++)
258 if(range_len_list[j] < 1)
260 range_len_list[j] = range_len_list[nleft - 1];
261 if(range_len_list[nleft - 1] >= 1 && j != (nleft - 1))
263 median_element_list[j] = median_element_list[nleft - 1];
264 tie_breaking_rank_list[j] = tie_breaking_rank_list[nleft - 1];
265 max_loc_list[j] = max_loc_list[nleft - 1];
// Order the remaining proposals through an index array and take the entry at
// position `mid` of that ordering (the computation of this mid is not visible).
274 buildIndex(median_element_list, median_element_list + nleft, index_list, comp);
278 element_guess[0] = median_element_list[index_list[mid]];
279 element_tie_breaking_rank[0] = tie_breaking_rank_list[index_list[mid]];
280 max_loc[0] = max_loc_list[index_list[mid]];
// Distribute the chosen guess from local task 0 to all tasks.
283 MPI_Bcast(element_guess, size, MPI_BYTE, 0, MPI_CommLocal);
284 MPI_Bcast(&element_tie_breaking_rank[0],
sizeof(
size_t), MPI_BYTE, 0, MPI_CommLocal);
285 MPI_Bcast(&max_loc[0], 1, MPI_INT, 0, MPI_CommLocal);
// All remaining split points start from the same guess as split 0.
287 for(
int i = 1; i < Local_NTask - 1; i++)
289 element_guess[i] = element_guess[0];
290 element_tie_breaking_rank[i] = element_tie_breaking_rank[0];
291 max_loc[i] = max_loc[0];
// Search loop body: for every split point not yet at its desired global rank,
// determine the local rank of the current guess.
298 for(
int i = 0; i < Local_NTask - 1; i++)
300 if(current_glob_rank[i] != desired_glob_rank[i])
// NOTE(review): "¤t_loc_rank" on the continuation line looks like a
// mis-decoded "&current_loc_rank" (get_local_rank expects a size_t* there) --
// confirm against the pristine file.
302 get_local_rank(element_guess[i], element_tie_breaking_rank[i], begin, nmemb, noffs[Local_ThisTask], range_left[i],
303 range_right[i], ¤t_loc_rank[i], comp);
// Optional brute-force verification of the rank just computed.
305#ifdef CHECK_LOCAL_RANK
306 check_local_rank(element_guess[i], element_tie_breaking_rank[i], begin, nmemb, noffs[Local_ThisTask], range_left[i],
307 range_right[i], current_loc_rank[i], comp);
// Exchange local ranks so that each task receives, in `list`, one value from
// every task; the loop over j presumably accumulates them into `rank`
// (loop body not visible), which is then shared as current_glob_rank.
315 myMPI_Alltoall(current_loc_rank,
sizeof(
size_t), MPI_BYTE, list,
sizeof(
size_t), MPI_BYTE, MPI_CommLocal);
317 for(
int j = 0; j < Local_NTask; j++)
319 MPI_Allgather(&rank,
sizeof(
size_t), MPI_BYTE, current_glob_rank,
sizeof(
size_t), MPI_BYTE, MPI_CommLocal);
// Narrow the search window of each unresolved split point: a guess ranked too
// low moves the left bound up, a guess ranked too high moves the right bound down.
322 for(
int i = 0; i < Local_NTask - 1; i++)
324 if(current_glob_rank[i] != desired_glob_rank[i])
328 if(current_glob_rank[i] < desired_glob_rank[i])
330 range_left[i] = current_loc_rank[i];
// Extra handling on the task recorded in max_loc[i] (body not visible).
332 if(Local_ThisTask == max_loc[i])
336 if(current_glob_rank[i] > desired_glob_rank[i])
337 range_right[i] = current_loc_rank[i];
// Prepare new proposals: for each unresolved split point, the length of the
// narrowed window and, if non-empty, its middle element with global index.
342 for(
int i = 0; i < Local_NTask - 1; i++)
344 if(current_glob_rank[i] != desired_glob_rank[i])
349 source_range_len_list[i] = range_right[i] - range_left[i];
351 if(source_range_len_list[i] >= 1)
353 long long middle = (range_left[i] + range_right[i]) / 2;
354 source_median_element_list[i] = begin[middle];
355 source_tie_breaking_rank_list[i] = middle + noffs[Local_ThisTask];
// Send proposal i to task i, so each task gathers all proposals for "its" split.
360 myMPI_Alltoall(source_range_len_list,
sizeof(
long long), MPI_BYTE, range_len_list,
sizeof(
long long), MPI_BYTE, MPI_CommLocal);
361 myMPI_Alltoall(source_median_element_list, size, MPI_BYTE, median_element_list, size, MPI_BYTE, MPI_CommLocal);
362 myMPI_Alltoall(source_tie_breaking_rank_list,
sizeof(
size_t), MPI_BYTE, tie_breaking_rank_list,
sizeof(
size_t), MPI_BYTE,
// Tasks 0..Local_NTask-2 each own one split point and pick its next guess.
365 if(Local_ThisTask < Local_NTask - 1)
367 if(current_glob_rank[Local_ThisTask] !=
368 desired_glob_rank[Local_ThisTask])
370 for(
int j = 0; j < Local_NTask; j++)
374 int nleft = Local_NTask;
// Same compaction of empty proposals as in the initial-guess stage.
376 for(
int j = 0; j < nleft; j++)
378 if(range_len_list[j] < 1)
380 range_len_list[j] = range_len_list[nleft - 1];
381 if(range_len_list[nleft - 1] >= 1 && j != (nleft - 1))
383 median_element_list[j] = median_element_list[nleft - 1];
384 tie_breaking_rank_list[j] = tie_breaking_rank_list[nleft - 1];
385 max_loc_list[j] = max_loc_list[nleft - 1];
// Variant A: take the proposal coming from the largest remaining window
// (the condition selecting between variant A and B is not visible).
395 size_t max_range = 0, maxj = 0;
397 for(
int j = 0; j < nleft; j++)
398 if(range_len_list[j] > max_range)
400 max_range = range_len_list[j];
405 new_element_guess = median_element_list[maxj];
406 new_tie_breaking_rank = tie_breaking_rank_list[maxj];
407 new_max_loc = max_loc_list[maxj];
// Variant B: take the entry at position `mid` of the ordered proposals.
413 buildIndex(median_element_list, median_element_list + nleft, index_list, comp);
417 new_element_guess = median_element_list[index_list[mid]];
418 new_tie_breaking_rank = tie_breaking_rank_list[index_list[mid]];
419 new_max_loc = max_loc_list[index_list[mid]];
// Otherwise keep the existing guess for this task's split point.
425 new_element_guess = element_guess[Local_ThisTask];
426 new_tie_breaking_rank = element_tie_breaking_rank[Local_ThisTask];
427 new_max_loc = max_loc[Local_ThisTask];
// Publish every task's (new or retained) guess to all tasks.
431 MPI_Allgather(&new_element_guess, size, MPI_BYTE, element_guess, size, MPI_BYTE, MPI_CommLocal);
432 MPI_Allgather(&new_tie_breaking_rank,
sizeof(
size_t), MPI_BYTE, element_tie_breaking_rank,
sizeof(
size_t), MPI_BYTE,
434 MPI_Allgather(&new_max_loc, 1, MPI_INT, max_loc, 1, MPI_INT, MPI_CommLocal);
// Progress report from local task 0 during the last 100 permitted iterations.
438 if(iter > (MAX_ITER_PARALLEL_SORT - 100) && Local_ThisTask == 0)
440 printf(
"PSORT: iter=%d: ranks_not_found=%d Local_NTask=%d\n", iter, ranks_not_found, Local_NTask);
// Abort once the iteration cap is exceeded.
442 if(iter > MAX_ITER_PARALLEL_SORT)
443 Terminate(
"can't find the split points. That's odd");
// Repeat until every split point has reached its desired global rank.
446 while(ranks_not_found);
// Release the search-only buffers, in reverse order of their allocation above.
448 Mem.myfree(source_median_element_list);
449 Mem.myfree(source_tie_breaking_rank_list);
450 Mem.myfree(source_range_len_list);
451 Mem.myfree(max_loc_list);
452 Mem.myfree(index_list);
453 Mem.myfree(tie_breaking_rank_list);
454 Mem.myfree(median_element_list);
// Redistribution stage: per-destination counts and offsets.
459 size_t *send_count = (
size_t *)
Mem.mymalloc(
"send_count", Local_NTask *
sizeof(
size_t));
460 size_t *recv_count = (
size_t *)
Mem.mymalloc(
"recv_count", Local_NTask *
sizeof(
size_t));
461 size_t *send_offset = (
size_t *)
Mem.mymalloc(
"send_offset", Local_NTask *
sizeof(
size_t));
462 size_t *recv_offset = (
size_t *)
Mem.mymalloc(
"recv_offset", Local_NTask *
sizeof(
size_t));
464 for(
int i = 0; i < Local_NTask; i++)
// Assign every local element to a destination task: `target` is advanced
// while the element compares beyond split element_guess[target], with the
// global index used as tie-breaker (the advancing statements are not visible).
469 for(
size_t i = 0; i < nmemb; i++)
471 while(target < Local_NTask - 1)
473 int cmp = comp(begin[i], element_guess[target]) ? -1 : (comp(element_guess[target], begin[i]) ? +1 : 0);
476 if(i + noffs[Local_ThisTask] < element_tie_breaking_rank[target])
478 else if(i + noffs[Local_ThisTask] > element_tie_breaking_rank[target])
486 send_count[target]++;
// Tell every task how many elements it will receive from this one.
489 myMPI_Alltoall(send_count,
sizeof(
size_t), MPI_BYTE, recv_count,
sizeof(
size_t), MPI_BYTE, MPI_CommLocal);
// Total incoming element count and prefix-sum offsets for both directions.
495 for(
int j = 0; j < Local_NTask; j++)
497 nimport += recv_count[j];
501 send_offset[j] = send_offset[j - 1] + send_count[j - 1];
502 recv_offset[j] = recv_offset[j - 1] + recv_count[j - 1];
// Sanity check reported when the incoming total differs from nmemb
// (the guarding condition is not visible; inferred from the message text).
507 Terminate(
"nimport=%lld != nmemb=%lld", (
long long)nimport, (
long long)nmemb);
// Convert counts and offsets from elements to bytes.
509 for(
int j = 0; j < Local_NTask; j++)
511 send_count[j] *= size;
512 recv_count[j] *= size;
514 send_offset[j] *= size;
515 recv_offset[j] *= size;
// Receive into a temporary buffer, then copy the result back over the input.
518 T *basetmp = (T *)
Mem.mymalloc(
"basetmp", nmemb * size);
// Byte-wise exchange: len = sizeof(char), big_flag = 1.
521 myMPI_Alltoallv(begin, send_count, send_offset, basetmp, recv_count, recv_offset,
sizeof(
char), 1, MPI_CommLocal);
523 memcpy(
static_cast<void *
>(begin),
static_cast<void *
>(basetmp), nmemb * size);
// Release the remaining work buffers.
// NOTE(review): frees for basetmp, list, max_loc, noffs and nlist are not
// visible in this excerpt -- confirm they exist in the full file.
528 Mem.myfree(recv_offset);
529 Mem.myfree(send_offset);
530 Mem.myfree(recv_count);
531 Mem.myfree(send_count);
533 Mem.myfree(range_len_list);
536 Mem.myfree(range_right);
537 Mem.myfree(range_left);
538 Mem.myfree(current_loc_rank);
539 Mem.myfree(current_glob_rank);
540 Mem.myfree(desired_glob_rank);
541 Mem.myfree(element_tie_breaking_rank);
542 Mem.myfree(element_guess);
// Destroy the communicator created by MPI_Comm_split.
547 MPI_Comm_free(&MPI_CommLocal);
bool operator()(std::size_t a, std::size_t b) const
IdxComp__(It begin_, Comp comp_)
double timediff(double t0, double t1)
double mycxxsort(T *begin, T *end, Tcomp comp)
int myMPI_Alltoall(const void *sendbuf, int sendcount, MPI_Datatype sendtype, void *recvbuf, int recvcount, MPI_Datatype recvtype, MPI_Comm comm)
void myMPI_Alltoallv(void *sendbuf, size_t *sendcounts, size_t *sdispls, void *recvbuf, size_t *recvcounts, size_t *rdispls, int len, int big_flag, MPI_Comm comm)
void get_local_rank(const T &element, std::size_t tie_breaking_rank, const T *base, size_t nmemb, size_t noffs_thistask, long long left, long long right, size_t *loc, Comp comp)
double mycxxsort_parallel(T *begin, T *end, Comp comp, MPI_Comm comm)
void buildIndex(It begin, It end, T2 *idx, Comp comp)
void myflush(FILE *fstream)