GADGET-4
parallel_sort.h
/*******************************************************************************
 * \copyright This file is part of the GADGET4 N-body/SPH code developed
 * \copyright by Volker Springel. Copyright (C) 2014-2020 by Volker Springel
 * \copyright (vspringel@mpa-garching.mpg.de) and all contributing authors.
 *******************************************************************************/

#ifndef PARALLEL_SORT_H
#define PARALLEL_SORT_H

#include "cxxsort.h"

#include "../data/mymalloc.h"

//#define CHECK_LOCAL_RANK

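/*! \brief Helper functor that compares two array indices by comparing the
 *  elements they refer to.
 *
 *  Used by buildIndex() below to sort an index table without moving the
 *  underlying elements themselves.
 */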
template <typename It, typename Comp>
class IdxComp__
{
 private:
  It begin;
  Comp comp;

 public:
  IdxComp__(It begin_, Comp comp_) : begin(begin_), comp(comp_) {}
  bool operator()(std::size_t a, std::size_t b) const { return comp(*(begin + a), *(begin + b)); }
};

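/*! \brief Fills idx[0 .. end-begin-1] with 0, 1, 2, ... and then sorts these indices
 *  such that begin[idx[0]], begin[idx[1]], ... appear in ascending order according
 *  to comp. The input range itself is left untouched.
 *
 *  A minimal usage sketch (the names 'keys' and 'order' are placeholders for this
 *  example, not part of the code base):
 *
 *    double keys[4] = {3.0, 1.0, 4.0, 2.0};
 *    int order[4];
 *    buildIndex(keys, keys + 4, order, std::less<double>());
 *    // order now holds {1, 3, 0, 2}
 */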
template <typename It, typename T2, typename Comp>
inline void buildIndex(It begin, It end, T2 *idx, Comp comp)
{
  using namespace std;
  T2 num = end - begin;
  for(T2 i = 0; i < num; ++i)
    idx[i] = i;
  mycxxsort(idx, idx + num, IdxComp__<It, Comp>(begin, comp));
}

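/*! \brief Determines how many elements of the locally sorted array base[0..nmemb-1]
 *  compare smaller than the given element, restricting the bisection to the currently
 *  allowed index range between left and right. Ties are broken with the element's
 *  original global rank (tie_breaking_rank), so that equal keys still have a unique,
 *  well-defined order. The result is returned in *loc.
 */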
template <typename T, typename Comp>
void get_local_rank(const T &element, std::size_t tie_breaking_rank, const T *base, size_t nmemb, size_t noffs_thistask,
                    long long left, long long right, size_t *loc, Comp comp)
{
  if(right < left)
    Terminate("right < left");

  if(left == 0 && right == (int)nmemb + 1)
    {
      if(comp(base[nmemb - 1], element))
        {
          *loc = nmemb;
          return;
        }
      else if(comp(element, base[0]))
        {
          *loc = 0;
          return;
        }
    }

  if(right == left) /* looks like we already converged to the proper rank */
    {
      *loc = left;
    }
  else
    {
      if(comp(base[right - 1], element)) /* the last element is smaller, hence all elements are on the left */
        *loc = (right - 1) + 1;
      else if(comp(element, base[left])) /* the first element is already larger, hence no element is on the left */
        *loc = left;
      else
        {
          while(right > left)
            {
              long long mid = ((right - 1) + left) / 2;

              int cmp = comp(base[mid], element) ? -1 : (comp(element, base[mid]) ? +1 : 0);
              if(cmp == 0)
                {
                  if(mid + noffs_thistask < tie_breaking_rank)
                    cmp = -1;
                  else if(mid + noffs_thistask > tie_breaking_rank)
                    cmp = +1;
                }

              if(cmp == 0) /* element has exactly been found */
                {
                  *loc = mid;
                  break;
                }

              if((right - 1) == left) /* element is not on this CPU */
                {
                  if(cmp < 0)
                    *loc = mid + 1;
                  else
                    *loc = mid;
                  break;
                }

              if(cmp < 0)
                {
                  left = mid + 1;
                }
              else
                {
                  if((right - 1) == left + 1)
                    {
                      if(mid != left)
                        Terminate("Can't be: -->left=%lld right=%lld\n", left, right);

                      *loc = left;
                      break;
                    }

                  right = mid;
                }
            }
        }
    }
}

#ifdef CHECK_LOCAL_RANK
template <typename T, typename Comp>
inline void check_local_rank(const T &element,               /* element of which we want the rank */
                             size_t tie_breaking_rank,       /* the initial global rank of this element (needed for breaking ties) */
                             const T *base,                  /* base address of local data */
                             size_t nmemb,                   /* number and size of local data */
                             size_t noffs_thistask,          /* cumulative length of data on lower tasks */
                             long long left, long long right, /* range of elements on local task that may hold the element */
                             size_t loc, Comp comp)          /* user-specified comparison function */
{
  long long count = 0;

  for(size_t i = 0; i < nmemb; i++)
    {
      int cmp = comp(base[i], element) ? -1 : (comp(element, base[i]) ? +1 : 0);

      if(cmp == 0)
        {
          if(noffs_thistask + i < tie_breaking_rank)
            cmp = -1;
        }

      if(cmp < 0)
        count++;
    }

  if(count != (long long)loc)
    Terminate("Inconsistency: loc=%lld count=%lld left=%lld right=%lld nmemb=%lld\n", (long long)loc, count, left, right,
              (long long)nmemb);
}
#endif

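/*! \brief Parallel sort of the distributed array [begin, end) over all MPI tasks in
 *  the communicator comm, using the comparison function comp. On return, the data is
 *  globally sorted across the tasks while the number of elements stored on each task
 *  is left unchanged. The function returns the wall-clock time spent in the sort.
 *
 *  The algorithm first sorts the local data, then iteratively determines
 *  Local_NTask-1 global splitter elements (a median-of-medians guess that is refined
 *  with a parallel bisection), exchanges the data with an all-to-all communication,
 *  and finally sorts the received elements locally once more.
 *
 *  A minimal usage sketch (the names 'P', 'Nlocal', 'compare_key' and 'Communicator'
 *  are placeholders assumed for this example, not part of the code base):
 *
 *    double t = mycxxsort_parallel(P, P + Nlocal, compare_key, Communicator);
 */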
template <typename T, typename Comp>
inline double mycxxsort_parallel(T *begin, T *end, Comp comp, MPI_Comm comm)
{
  const int MAX_ITER_PARALLEL_SORT = 500;
  int ranks_not_found, Local_ThisTask, Local_NTask, Color, new_max_loc;
  size_t tie_breaking_rank, new_tie_breaking_rank, rank;
  MPI_Comm MPI_CommLocal;

  double ta = Logs.second();
  size_t nmemb = end - begin;
  size_t size = sizeof(T);
  /* do a serial sort of the local data up front */
  mycxxsort(begin, end, comp);

  /* we create a communicator that contains just those tasks with nmemb > 0. This makes
   * it easier to deal with CPUs that do not hold any data.
   */
  if(nmemb)
    Color = 1;
  else
    Color = 0;

  int thistask;
  MPI_Comm_rank(comm, &thistask);

  MPI_Comm_split(comm, Color, thistask, &MPI_CommLocal);
  MPI_Comm_rank(MPI_CommLocal, &Local_ThisTask);
  MPI_Comm_size(MPI_CommLocal, &Local_NTask);

  if(Local_NTask > 1 && Color == 1)
    {
      size_t *nlist = (size_t *)Mem.mymalloc("nlist", Local_NTask * sizeof(size_t));
      size_t *noffs = (size_t *)Mem.mymalloc("noffs", Local_NTask * sizeof(size_t));

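      /* gather the local element counts from all tasks and turn them into cumulative
       * offsets, so that noffs[i] is the global rank of the first element held by task i */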
      MPI_Allgather(&nmemb, sizeof(size_t), MPI_BYTE, nlist, sizeof(size_t), MPI_BYTE, MPI_CommLocal);

      noffs[0] = 0;
      for(int i = 1; i < Local_NTask; i++)
        noffs[i] = noffs[i - 1] + nlist[i - 1];

      T *element_guess = (T *)Mem.mymalloc("element_guess", Local_NTask * size);
      size_t *element_tie_breaking_rank = (size_t *)Mem.mymalloc("element_tie_breaking_rank", Local_NTask * sizeof(size_t));
      size_t *desired_glob_rank = (size_t *)Mem.mymalloc("desired_glob_rank", Local_NTask * sizeof(size_t));
      size_t *current_glob_rank = (size_t *)Mem.mymalloc("current_glob_rank", Local_NTask * sizeof(size_t));
      size_t *current_loc_rank = (size_t *)Mem.mymalloc("current_loc_rank", Local_NTask * sizeof(size_t));
      long long *range_left = (long long *)Mem.mymalloc("range_left", Local_NTask * sizeof(long long));
      long long *range_right = (long long *)Mem.mymalloc("range_right", Local_NTask * sizeof(long long));
      int *max_loc = (int *)Mem.mymalloc("max_loc", Local_NTask * sizeof(int));

      size_t *list = (size_t *)Mem.mymalloc("list", Local_NTask * sizeof(size_t));
      size_t *range_len_list = (size_t *)Mem.mymalloc("range_len_list", Local_NTask * sizeof(long long));
      T median_element;
      T *median_element_list = (T *)Mem.mymalloc("median_element_list", Local_NTask * size);
      size_t *tie_breaking_rank_list = (size_t *)Mem.mymalloc("tie_breaking_rank_list", Local_NTask * sizeof(size_t));
      int *index_list = (int *)Mem.mymalloc("index_list", Local_NTask * sizeof(int));
      int *max_loc_list = (int *)Mem.mymalloc("max_loc_list", Local_NTask * sizeof(int));
      size_t *source_range_len_list = (size_t *)Mem.mymalloc("source_range_len_list", Local_NTask * sizeof(long long));
      size_t *source_tie_breaking_rank_list = (size_t *)Mem.mymalloc("source_tie_breaking_rank_list", Local_NTask * sizeof(long long));
      T *source_median_element_list = (T *)Mem.mymalloc("source_median_element_list", Local_NTask * size);
      T new_element_guess;

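      /* split point i has to end up at global rank noffs[i + 1], i.e. at the first
       * element that belongs to task i + 1 in the final sorted order */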
      for(int i = 0; i < Local_NTask - 1; i++)
        {
          desired_glob_rank[i] = noffs[i + 1];
          current_glob_rank[i] = 0;
          range_left[i] = 0;      /* first element that it can be */
          range_right[i] = nmemb; /* first element that it can not be */
        }

      /* now we determine the first split element guess, which is the same for all divisions in the first iteration */

      /* find the median of each processor, and then take the median among those values.
       * This should work reasonably well even for extremely skewed distributions
       */
      long long range_len = range_right[0] - range_left[0];

      if(range_len >= 1)
        {
          long long mid = (range_left[0] + range_right[0]) / 2;
          median_element = begin[mid];
          tie_breaking_rank = mid + noffs[Local_ThisTask];
        }

      MPI_Gather(&range_len, sizeof(long long), MPI_BYTE, range_len_list, sizeof(long long), MPI_BYTE, 0, MPI_CommLocal);
      MPI_Gather(&median_element, size, MPI_BYTE, median_element_list, size, MPI_BYTE, 0, MPI_CommLocal);
      MPI_Gather(&tie_breaking_rank, sizeof(size_t), MPI_BYTE, tie_breaking_rank_list, sizeof(size_t), MPI_BYTE, 0, MPI_CommLocal);

      if(Local_ThisTask == 0)
        {
          for(int j = 0; j < Local_NTask; j++)
            max_loc_list[j] = j;

          /* eliminate the elements that are undefined because the corresponding CPU has zero range left */
          int nleft = Local_NTask;

          for(int j = 0; j < nleft; j++)
            {
              if(range_len_list[j] < 1)
                {
                  range_len_list[j] = range_len_list[nleft - 1];
                  if(range_len_list[nleft - 1] >= 1 && j != (nleft - 1))
                    {
                      median_element_list[j] = median_element_list[nleft - 1];
                      tie_breaking_rank_list[j] = tie_breaking_rank_list[nleft - 1];
                      max_loc_list[j] = max_loc_list[nleft - 1];
                    }

                  nleft--;
                  j--;
                }
            }

          /* do a serial sort of the remaining elements (indirectly, so that we have the order of tie breaking list as well) */
          buildIndex(median_element_list, median_element_list + nleft, index_list, comp);

          /* now select the median of the medians */
          int mid = nleft / 2;
          element_guess[0] = median_element_list[index_list[mid]];
          element_tie_breaking_rank[0] = tie_breaking_rank_list[index_list[mid]];
          max_loc[0] = max_loc_list[index_list[mid]];
        }

      MPI_Bcast(element_guess, size, MPI_BYTE, 0, MPI_CommLocal);
      MPI_Bcast(&element_tie_breaking_rank[0], sizeof(size_t), MPI_BYTE, 0, MPI_CommLocal);
      MPI_Bcast(&max_loc[0], 1, MPI_INT, 0, MPI_CommLocal);

      for(int i = 1; i < Local_NTask - 1; i++)
        {
          element_guess[i] = element_guess[0];
          element_tie_breaking_rank[i] = element_tie_breaking_rank[0];
          max_loc[i] = max_loc[0];
        }

      int iter = 0;

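      /* iteratively refine the split element guesses: each task determines the local
       * rank of every guess, the local ranks are summed to global ranks, and the
       * bisection brackets are narrowed until every guess sits at its desired global
       * rank (or the iteration limit is hit) */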
      do
        {
          for(int i = 0; i < Local_NTask - 1; i++)
            {
              if(current_glob_rank[i] != desired_glob_rank[i])
                {
                  get_local_rank(element_guess[i], element_tie_breaking_rank[i], begin, nmemb, noffs[Local_ThisTask], range_left[i],
                                 range_right[i], &current_loc_rank[i], comp);

#ifdef CHECK_LOCAL_RANK
                  check_local_rank(element_guess[i], element_tie_breaking_rank[i], begin, nmemb, noffs[Local_ThisTask], range_left[i],
                                   range_right[i], current_loc_rank[i], comp);
#endif
                }
            }

          /* now compute the global ranks by summing the local ranks */
          /* Note: the last element in current_loc_rank is not defined. It will be summed by the last processor, and stored in the last
           * element of current_glob_rank */
          myMPI_Alltoall(current_loc_rank, sizeof(size_t), MPI_BYTE, list, sizeof(size_t), MPI_BYTE, MPI_CommLocal);
          rank = 0;
          for(int j = 0; j < Local_NTask; j++)
            rank += list[j];
          MPI_Allgather(&rank, sizeof(size_t), MPI_BYTE, current_glob_rank, sizeof(size_t), MPI_BYTE, MPI_CommLocal);

          ranks_not_found = 0;
          for(int i = 0; i < Local_NTask - 1; i++)
            {
              if(current_glob_rank[i] != desired_glob_rank[i]) /* here we're not yet done */
                {
                  ranks_not_found++;

                  if(current_glob_rank[i] < desired_glob_rank[i])
                    {
                      range_left[i] = current_loc_rank[i];

                      if(Local_ThisTask == max_loc[i])
                        range_left[i]++;
                    }

                  if(current_glob_rank[i] > desired_glob_rank[i])
                    range_right[i] = current_loc_rank[i];
                }
            }

          /* now we need to determine new element guesses */
          for(int i = 0; i < Local_NTask - 1; i++)
            {
              if(current_glob_rank[i] != desired_glob_rank[i]) /* here we're not yet done */
                {
                  /* find the median of each processor, and then take the median among those values.
                   * This should work reasonably well even for extremely skewed distributions
                   */
                  source_range_len_list[i] = range_right[i] - range_left[i];

                  if(source_range_len_list[i] >= 1)
                    {
                      long long middle = (range_left[i] + range_right[i]) / 2;
                      source_median_element_list[i] = begin[middle];
                      source_tie_breaking_rank_list[i] = middle + noffs[Local_ThisTask];
                    }
                }
            }

          myMPI_Alltoall(source_range_len_list, sizeof(long long), MPI_BYTE, range_len_list, sizeof(long long), MPI_BYTE, MPI_CommLocal);
          myMPI_Alltoall(source_median_element_list, size, MPI_BYTE, median_element_list, size, MPI_BYTE, MPI_CommLocal);
          myMPI_Alltoall(source_tie_breaking_rank_list, sizeof(size_t), MPI_BYTE, tie_breaking_rank_list, sizeof(size_t), MPI_BYTE,
                         MPI_CommLocal);

          if(Local_ThisTask < Local_NTask - 1)
            {
              if(current_glob_rank[Local_ThisTask] !=
                 desired_glob_rank[Local_ThisTask]) /* in this case we're not yet done for this split point */
                {
                  for(int j = 0; j < Local_NTask; j++)
                    max_loc_list[j] = j;

                  /* eliminate the elements that are undefined because the corresponding CPU has zero range left */
                  int nleft = Local_NTask;

                  for(int j = 0; j < nleft; j++)
                    {
                      if(range_len_list[j] < 1)
                        {
                          range_len_list[j] = range_len_list[nleft - 1];
                          if(range_len_list[nleft - 1] >= 1 && j != (nleft - 1))
                            {
                              median_element_list[j] = median_element_list[nleft - 1];
                              tie_breaking_rank_list[j] = tie_breaking_rank_list[nleft - 1];
                              max_loc_list[j] = max_loc_list[nleft - 1];
                            }

                          nleft--;
                          j--;
                        }
                    }

                  if((iter & 1))
                    {
                      size_t max_range = 0, maxj = 0;

                      for(int j = 0; j < nleft; j++)
                        if(range_len_list[j] > max_range)
                          {
                            max_range = range_len_list[j];
                            maxj = j;
                          }

                      /* now select the median element from the task which has the largest range */
                      new_element_guess = median_element_list[maxj];
                      new_tie_breaking_rank = tie_breaking_rank_list[maxj];
                      new_max_loc = max_loc_list[maxj];
                    }
                  else
                    {
                      /* do a serial sort of the remaining elements (indirectly, so that we have the order of tie breaking list as
                       * well) */
                      buildIndex(median_element_list, median_element_list + nleft, index_list, comp);

                      /* now select the median of the medians */
                      int mid = nleft / 2;
                      new_element_guess = median_element_list[index_list[mid]];
                      new_tie_breaking_rank = tie_breaking_rank_list[index_list[mid]];
                      new_max_loc = max_loc_list[index_list[mid]];
                    }
                }
              else
                {
                  /* in order to preserve existing guesses */
                  new_element_guess = element_guess[Local_ThisTask];
                  new_tie_breaking_rank = element_tie_breaking_rank[Local_ThisTask];
                  new_max_loc = max_loc[Local_ThisTask];
                }
            }

          MPI_Allgather(&new_element_guess, size, MPI_BYTE, element_guess, size, MPI_BYTE, MPI_CommLocal);
          MPI_Allgather(&new_tie_breaking_rank, sizeof(size_t), MPI_BYTE, element_tie_breaking_rank, sizeof(size_t), MPI_BYTE,
                        MPI_CommLocal);
          MPI_Allgather(&new_max_loc, 1, MPI_INT, max_loc, 1, MPI_INT, MPI_CommLocal);

          iter++;

          if(iter > (MAX_ITER_PARALLEL_SORT - 100) && Local_ThisTask == 0)
            {
              printf("PSORT: iter=%d: ranks_not_found=%d Local_NTask=%d\n", iter, ranks_not_found, Local_NTask);
              myflush(stdout);
              if(iter > MAX_ITER_PARALLEL_SORT)
                Terminate("can't find the split points. That's odd");
            }
        }
      while(ranks_not_found);

      Mem.myfree(source_median_element_list);
      Mem.myfree(source_tie_breaking_rank_list);
      Mem.myfree(source_range_len_list);
      Mem.myfree(max_loc_list);
      Mem.myfree(index_list);
      Mem.myfree(tie_breaking_rank_list);
      Mem.myfree(median_element_list);

      /* At this point we have found all the elements corresponding to the desired split points */
      /* we can now go ahead and determine how many elements of the local CPU have to go to each other CPU */

      size_t *send_count = (size_t *)Mem.mymalloc("send_count", Local_NTask * sizeof(size_t));
      size_t *recv_count = (size_t *)Mem.mymalloc("recv_count", Local_NTask * sizeof(size_t));
      size_t *send_offset = (size_t *)Mem.mymalloc("send_offset", Local_NTask * sizeof(size_t));
      size_t *recv_offset = (size_t *)Mem.mymalloc("recv_offset", Local_NTask * sizeof(size_t));

      for(int i = 0; i < Local_NTask; i++)
        send_count[i] = 0;

      int target = 0;

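      /* since the local data and the split elements are both sorted, a single sweep with a
       * moving 'target' pointer suffices to count how many local elements go to each task */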
      for(size_t i = 0; i < nmemb; i++)
        {
          while(target < Local_NTask - 1)
            {
              int cmp = comp(begin[i], element_guess[target]) ? -1 : (comp(element_guess[target], begin[i]) ? +1 : 0);
              if(cmp == 0)
                {
                  if(i + noffs[Local_ThisTask] < element_tie_breaking_rank[target])
                    cmp = -1;
                  else if(i + noffs[Local_ThisTask] > element_tie_breaking_rank[target])
                    cmp = +1;
                }
              if(cmp >= 0)
                target++;
              else
                break;
            }
          send_count[target]++;
        }

      myMPI_Alltoall(send_count, sizeof(size_t), MPI_BYTE, recv_count, sizeof(size_t), MPI_BYTE, MPI_CommLocal);

      size_t nimport = 0;

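      /* turn the exchanged counts into send/receive displacements and add up how many elements we will receive */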
      recv_offset[0] = 0;
      send_offset[0] = 0;
      for(int j = 0; j < Local_NTask; j++)
        {
          nimport += recv_count[j];

          if(j > 0)
            {
              send_offset[j] = send_offset[j - 1] + send_count[j - 1];
              recv_offset[j] = recv_offset[j - 1] + recv_count[j - 1];
            }
        }

      if(nimport != nmemb)
        Terminate("nimport=%lld != nmemb=%lld", (long long)nimport, (long long)nmemb);

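      /* convert the element counts and offsets into byte counts for the byte-wise Alltoallv below */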
      for(int j = 0; j < Local_NTask; j++)
        {
          send_count[j] *= size;
          recv_count[j] *= size;

          send_offset[j] *= size;
          recv_offset[j] *= size;
        }

      T *basetmp = (T *)Mem.mymalloc("basetmp", nmemb * size);

      /* exchange the data */
      myMPI_Alltoallv(begin, send_count, send_offset, basetmp, recv_count, recv_offset, sizeof(char), 1, MPI_CommLocal);

      memcpy(static_cast<void *>(begin), static_cast<void *>(basetmp), nmemb * size);
      Mem.myfree(basetmp);

      mycxxsort(begin, begin + nmemb, comp);

      Mem.myfree(recv_offset);
      Mem.myfree(send_offset);
      Mem.myfree(recv_count);
      Mem.myfree(send_count);

      Mem.myfree(range_len_list);
      Mem.myfree(list);
      Mem.myfree(max_loc);
      Mem.myfree(range_right);
      Mem.myfree(range_left);
      Mem.myfree(current_loc_rank);
      Mem.myfree(current_glob_rank);
      Mem.myfree(desired_glob_rank);
      Mem.myfree(element_tie_breaking_rank);
      Mem.myfree(element_guess);
      Mem.myfree(noffs);
      Mem.myfree(nlist);
    }

  MPI_Comm_free(&MPI_CommLocal);

  double tb = Logs.second();
  return Logs.timediff(ta, tb);
}

#endif