12#include "gadgetconfig.h"
14#if defined(PMGRID) || defined(NGENIC)
24#include "../data/allvars.h"
25#include "../data/dtypes.h"
26#include "../data/mymalloc.h"
27#include "../main/simulation.h"
28#include "../mpi_utils/mpi_utils.h"
30#include "../system/system.h"
38#ifndef FFT_COLUMN_BASED
45 plan->slab_to_task = (
int *)
Mem.mymalloc_movable(&plan->slab_to_task,
"slab_to_task", NgridX *
sizeof(
int));
47 for(
int task = 0; task <
NTask; task++)
53 for(
int i = start; i < start + n; i++)
54 plan->slab_to_task[i] = task;
57 MPI_Allreduce(&plan->nslab_x, &plan->largest_x_slab, 1, MPI_INT, MPI_MAX,
Communicator);
58 MPI_Allreduce(&plan->nslab_y, &plan->largest_y_slab, 1, MPI_INT, MPI_MAX,
Communicator);
60 plan->slabs_x_per_task = (
int *)
Mem.mymalloc_movable(&plan->slabs_x_per_task,
"slabs_x_per_task",
NTask *
sizeof(
int));
61 MPI_Allgather(&plan->nslab_x, 1, MPI_INT, plan->slabs_x_per_task, 1, MPI_INT,
Communicator);
63 plan->first_slab_x_of_task = (
int *)
Mem.mymalloc_movable(&plan->first_slab_x_of_task,
"first_slab_x_of_task",
NTask *
sizeof(
int));
64 MPI_Allgather(&plan->slabstart_x, 1, MPI_INT, plan->first_slab_x_of_task, 1, MPI_INT,
Communicator);
66 plan->slabs_y_per_task = (
int *)
Mem.mymalloc_movable(&plan->slabs_y_per_task,
"slabs_y_per_task",
NTask *
sizeof(
int));
67 MPI_Allgather(&plan->nslab_y, 1, MPI_INT, plan->slabs_y_per_task, 1, MPI_INT,
Communicator);
69 plan->first_slab_y_of_task = (
int *)
Mem.mymalloc_movable(&plan->first_slab_y_of_task,
"first_slab_y_of_task",
NTask *
sizeof(
int));
70 MPI_Allgather(&plan->slabstart_y, 1, MPI_INT, plan->first_slab_y_of_task, 1, MPI_INT,
Communicator);
72 plan->NgridX = NgridX;
73 plan->NgridY = NgridY;
74 plan->NgridZ = NgridZ;
76 int Ngridz = NgridZ / 2 + 1;
78 plan->Ngridz = Ngridz;
79 plan->Ngrid2 = 2 * Ngridz;
84 Mem.myfree(plan->first_slab_y_of_task);
85 Mem.myfree(plan->slabs_y_per_task);
86 Mem.myfree(plan->first_slab_x_of_task);
87 Mem.myfree(plan->slabs_x_per_task);
88 Mem.myfree(plan->slab_to_task);
103 int n, prod, task, flag_big = 0, flag_big_all = 0;
105 prod =
NTask * plan->nslab_x;
107 for(n = 0; n < prod; n++)
110 int task = n %
NTask;
114 for(y = plan->first_slab_y_of_task[task]; y < plan->first_slab_y_of_task[task] + plan->slabs_y_per_task[task]; y++)
115 memcpy(scratch + ((
size_t)plan->NgridZ) * (plan->first_slab_y_of_task[task] * plan->nslab_x +
116 x * plan->slabs_y_per_task[task] + (y - plan->first_slab_y_of_task[task])),
117 field + ((
size_t)plan->Ngrid2) * (plan->NgridY * x + y), plan->NgridZ *
sizeof(fft_real));
120 size_t *scount = (
size_t *)
Mem.mymalloc(
"scount",
NTask *
sizeof(
size_t));
121 size_t *rcount = (
size_t *)
Mem.mymalloc(
"rcount",
NTask *
sizeof(
size_t));
122 size_t *soff = (
size_t *)
Mem.mymalloc(
"soff",
NTask *
sizeof(
size_t));
123 size_t *roff = (
size_t *)
Mem.mymalloc(
"roff",
NTask *
sizeof(
size_t));
125 for(task = 0; task <
NTask; task++)
127 scount[task] = plan->nslab_x * plan->slabs_y_per_task[task] * (plan->NgridZ *
sizeof(fft_real));
128 rcount[task] = plan->nslab_y * plan->slabs_x_per_task[task] * (plan->NgridZ *
sizeof(fft_real));
130 soff[task] = plan->first_slab_y_of_task[task] * plan->nslab_x * (plan->NgridZ *
sizeof(fft_real));
131 roff[task] = plan->first_slab_x_of_task[task] * plan->nslab_y * (plan->NgridZ *
sizeof(fft_real));
137 MPI_Allreduce(&flag_big, &flag_big_all, 1, MPI_INT, MPI_MAX,
Communicator);
158 int n, prod, task, flag_big = 0, flag_big_all = 0;
160 size_t *scount = (
size_t *)
Mem.mymalloc(
"scount",
NTask *
sizeof(
size_t));
161 size_t *rcount = (
size_t *)
Mem.mymalloc(
"rcount",
NTask *
sizeof(
size_t));
162 size_t *soff = (
size_t *)
Mem.mymalloc(
"soff",
NTask *
sizeof(
size_t));
163 size_t *roff = (
size_t *)
Mem.mymalloc(
"roff",
NTask *
sizeof(
size_t));
165 for(task = 0; task <
NTask; task++)
167 rcount[task] = plan->nslab_x * plan->slabs_y_per_task[task] * (plan->NgridZ *
sizeof(fft_real));
168 scount[task] = plan->nslab_y * plan->slabs_x_per_task[task] * (plan->NgridZ *
sizeof(fft_real));
170 roff[task] = plan->first_slab_y_of_task[task] * plan->nslab_x * (plan->NgridZ *
sizeof(fft_real));
171 soff[task] = plan->first_slab_x_of_task[task] * plan->nslab_y * (plan->NgridZ *
sizeof(fft_real));
177 MPI_Allreduce(&flag_big, &flag_big_all, 1, MPI_INT, MPI_MAX,
Communicator);
186 prod =
NTask * plan->nslab_x;
188 for(n = 0; n < prod; n++)
191 int task = n %
NTask;
194 for(y = plan->first_slab_y_of_task[task]; y < plan->first_slab_y_of_task[task] + plan->slabs_y_per_task[task]; y++)
195 memcpy(field + ((
size_t)plan->Ngrid2) * (plan->NgridY * x + y),
196 scratch + ((
size_t)plan->NgridZ) * (plan->first_slab_y_of_task[task] * plan->nslab_x +
197 x * plan->slabs_y_per_task[task] + (y - plan->first_slab_y_of_task[task])),
198 plan->NgridZ *
sizeof(fft_real));
213void pm_mpi_fft::my_slab_transpose(
void *av,
void *bv,
int *sx,
int *firstx,
int *sy,
int *firsty,
int nx,
int ny,
int nz,
int mode)
215 char *a = (
char *)av;
216 char *b = (
char *)bv;
218 size_t *scount = (
size_t *)
Mem.mymalloc(
"scount",
NTask *
sizeof(
size_t));
219 size_t *rcount = (
size_t *)
Mem.mymalloc(
"rcount",
NTask *
sizeof(
size_t));
220 size_t *soff = (
size_t *)
Mem.mymalloc(
"soff",
NTask *
sizeof(
size_t));
221 size_t *roff = (
size_t *)
Mem.mymalloc(
"roff",
NTask *
sizeof(
size_t));
222 int i, n, prod, flag_big = 0, flag_big_all = 0;
224 for(i = 0; i <
NTask; i++)
226 scount[i] = sy[i] * sx[
ThisTask] * ((size_t)nz);
227 rcount[i] = sy[
ThisTask] * sx[i] * ((size_t)nz);
228 soff[i] = firsty[i] * sx[
ThisTask] * ((size_t)nz);
229 roff[i] = sy[
ThisTask] * firstx[i] * ((size_t)nz);
238 MPI_Allreduce(&flag_big, &flag_big_all, 1, MPI_INT, MPI_MAX,
Communicator);
244 for(n = 0; n < prod; n++)
250 for(j = 0; j < sy[i]; j++)
251 memcpy(b + (k * sy[i] + j + firsty[i] * sx[
ThisTask]) * (nz *
sizeof(fft_complex)),
252 a + (k * ny + (firsty[i] + j)) * (nz *
sizeof(fft_complex)), nz *
sizeof(fft_complex));
260 for(n = 0; n < prod; n++)
266 for(k = 0; k < sx[i]; k++)
267 memcpy(b + (j * nx + k + firstx[i]) * (nz *
sizeof(fft_complex)),
268 a + ((k + firstx[i]) * sy[
ThisTask] + j) * (nz *
sizeof(fft_complex)), nz *
sizeof(fft_complex));
275 for(n = 0; n < prod; n++)
281 for(k = 0; k < sx[i]; k++)
282 memcpy(b + ((k + firstx[i]) * sy[
ThisTask] + j) * (nz *
sizeof(fft_complex)),
283 a + (j * nx + k + firstx[i]) * (nz *
sizeof(fft_complex)), nz *
sizeof(fft_complex));
291 for(n = 0; n < prod; n++)
297 for(j = 0; j < sy[i]; j++)
298 memcpy(b + (k * ny + (firsty[i] + j)) * (nz *
sizeof(fft_complex)),
299 a + (k * sy[i] + j + firsty[i] * sx[
ThisTask]) * (nz *
sizeof(fft_complex)), nz *
sizeof(fft_complex));
314 int slabsx = plan->slabs_x_per_task[
ThisTask];
315 int slabsy = plan->slabs_y_per_task[
ThisTask];
317 int ngridx = plan->NgridX;
318 int ngridy = plan->NgridY;
319 int ngridz = plan->Ngridz;
320 int ngridz2 = 2 * ngridz;
322 size_t ngridx_long = ngridx;
323 size_t ngridy_long = ngridy;
324 size_t ngridz_long = ngridz;
325 size_t ngridz2_long = ngridz2;
327 fft_real *data_real = (fft_real *)data;
328 fft_complex *data_complex = (fft_complex *)data, *workspace_complex = (fft_complex *)workspace;
333 prod = slabsx * ngridy;
334 for(n = 0; n < prod; n++)
336 FFTW(execute_dft_r2c)(plan->forward_plan_zdir, data_real + n * ngridz2_long, workspace_complex + n * ngridz_long);
340 prod = slabsx * ngridz;
341 for(n = 0; n < prod; n++)
347 (plan->forward_plan_ydir, workspace_complex + i * ngridz * ngridy_long + j, data_complex + i * ngridz * ngridy_long + j);
353 my_slab_transpose(data_complex, workspace_complex, plan->slabs_x_per_task, plan->first_slab_x_of_task, plan->slabs_y_per_task,
354 plan->first_slab_y_of_task, ngridx, ngridy, ngridz, 0);
359 prod = slabsy * ngridz;
360 for(n = 0; n < prod; n++)
366 (plan->forward_plan_xdir, workspace_complex + i * ngridz * ngridx_long + j, data_complex + i * ngridz * ngridx_long + j);
373 prod = slabsy * ngridz;
375 for(n = 0; n < prod; n++)
381 (plan->backward_plan_xdir, data_complex + i * ngridz * ngridx_long + j, workspace_complex + i * ngridz * ngridx_long + j);
384 my_slab_transpose(workspace_complex, data_complex, plan->slabs_x_per_task, plan->first_slab_x_of_task, plan->slabs_y_per_task,
385 plan->first_slab_y_of_task, ngridx, ngridy, ngridz, 1);
387 prod = slabsx * ngridz;
389 for(n = 0; n < prod; n++)
395 (plan->backward_plan_ydir, data_complex + i * ngridz * ngridy_long + j, workspace_complex + i * ngridz * ngridy_long + j);
398 prod = slabsx * ngridy;
400 for(n = 0; n < prod; n++)
402 FFTW(execute_dft_c2r)(plan->backward_plan_zdir, workspace_complex + n * ngridz_long, data_real + n * ngridz2_long);
413 plan->NgridX = NgridX;
414 plan->NgridY = NgridY;
415 plan->NgridZ = NgridZ;
417 int Ngridz = NgridZ / 2 + 1;
419 plan->Ngridz = Ngridz;
420 plan->Ngrid2 = 2 * Ngridz;
426 plan->lastcol_XY = plan->firstcol_XY + plan->ncol_XY - 1;
427 plan->lastcol_XZ = plan->firstcol_XZ + plan->ncol_XZ - 1;
428 plan->lastcol_ZY = plan->firstcol_ZY + plan->ncol_ZY - 1;
433 plan->second_transposed_ncells = ((size_t)plan->NgridX) * plan->second_transposed_ncol;
435 plan->max_datasize = ((size_t)plan->Ngrid2) * plan->ncol_XY;
436 plan->max_datasize = std::max<size_t>(plan->max_datasize, 2 * ((
size_t)plan->NgridY) * plan->transposed_ncol);
437 plan->max_datasize = std::max<size_t>(plan->max_datasize, 2 * ((
size_t)plan->NgridX) * plan->second_transposed_ncol);
438 plan->max_datasize = std::max<size_t>(plan->max_datasize, ((
size_t)plan->ncol_XZ) * plan->NgridY);
439 plan->max_datasize = std::max<size_t>(plan->max_datasize, ((
size_t)plan->ncol_ZY) * plan->NgridX);
441 plan->fftsize = plan->max_datasize;
443 plan->offsets_send_A = (
size_t *)
Mem.mymalloc_movable_clear(&plan->offsets_send_A,
"offsets_send_A",
NTask *
sizeof(
size_t));
444 plan->offsets_recv_A = (
size_t *)
Mem.mymalloc_movable_clear(&plan->offsets_recv_A,
"offsets_recv_A",
NTask *
sizeof(
size_t));
445 plan->offsets_send_B = (
size_t *)
Mem.mymalloc_movable_clear(&plan->offsets_send_B,
"offsets_send_B",
NTask *
sizeof(
size_t));
446 plan->offsets_recv_B = (
size_t *)
Mem.mymalloc_movable_clear(&plan->offsets_recv_B,
"offsets_recv_B",
NTask *
sizeof(
size_t));
447 plan->offsets_send_C = (
size_t *)
Mem.mymalloc_movable_clear(&plan->offsets_send_C,
"offsets_send_C",
NTask *
sizeof(
size_t));
448 plan->offsets_recv_C = (
size_t *)
Mem.mymalloc_movable_clear(&plan->offsets_recv_C,
"offsets_recv_C",
NTask *
sizeof(
size_t));
449 plan->offsets_send_D = (
size_t *)
Mem.mymalloc_movable_clear(&plan->offsets_send_D,
"offsets_send_D",
NTask *
sizeof(
size_t));
450 plan->offsets_recv_D = (
size_t *)
Mem.mymalloc_movable_clear(&plan->offsets_recv_D,
"offsets_recv_D",
NTask *
sizeof(
size_t));
452 plan->count_send_A = (
size_t *)
Mem.mymalloc_movable_clear(&plan->count_send_A,
"count_send_A",
NTask *
sizeof(
size_t));
453 plan->count_recv_A = (
size_t *)
Mem.mymalloc_movable_clear(&plan->count_recv_A,
"count_recv_A",
NTask *
sizeof(
size_t));
454 plan->count_send_B = (
size_t *)
Mem.mymalloc_movable_clear(&plan->count_send_B,
"count_send_B",
NTask *
sizeof(
size_t));
455 plan->count_recv_B = (
size_t *)
Mem.mymalloc_movable_clear(&plan->count_recv_B,
"count_recv_B",
NTask *
sizeof(
size_t));
456 plan->count_send_C = (
size_t *)
Mem.mymalloc_movable_clear(&plan->count_send_C,
"count_send_C",
NTask *
sizeof(
size_t));
457 plan->count_recv_C = (
size_t *)
Mem.mymalloc_movable_clear(&plan->count_recv_C,
"count_recv_C",
NTask *
sizeof(
size_t));
458 plan->count_send_D = (
size_t *)
Mem.mymalloc_movable_clear(&plan->count_send_D,
"count_send_D",
NTask *
sizeof(
size_t));
459 plan->count_recv_D = (
size_t *)
Mem.mymalloc_movable_clear(&plan->count_recv_D,
"count_recv_D",
NTask *
sizeof(
size_t));
460 plan->count_send_13 = (
size_t *)
Mem.mymalloc_movable_clear(&plan->count_send_13,
"count_send_13",
NTask *
sizeof(
size_t));
461 plan->count_recv_13 = (
size_t *)
Mem.mymalloc_movable_clear(&plan->count_recv_13,
"count_recv_13",
NTask *
sizeof(
size_t));
462 plan->count_send_23 = (
size_t *)
Mem.mymalloc_movable_clear(&plan->count_send_23,
"count_send_23",
NTask *
sizeof(
size_t));
463 plan->count_recv_23 = (
size_t *)
Mem.mymalloc_movable_clear(&plan->count_recv_23,
"count_recv_23",
NTask *
sizeof(
size_t));
464 plan->count_send_13back =
465 (
size_t *)
Mem.mymalloc_movable_clear(&plan->count_send_13back,
"count_send_13back",
NTask *
sizeof(
size_t));
466 plan->count_recv_13back =
467 (
size_t *)
Mem.mymalloc_movable_clear(&plan->count_recv_13back,
"count_recv_13back",
NTask *
sizeof(
size_t));
468 plan->count_send_23back =
469 (
size_t *)
Mem.mymalloc_movable_clear(&plan->count_send_23back,
"count_send_23back",
NTask *
sizeof(
size_t));
470 plan->count_recv_23back =
471 (
size_t *)
Mem.mymalloc_movable_clear(&plan->count_recv_23back,
"count_recv_23back",
NTask *
sizeof(
size_t));
473 int dimA[3] = {plan->NgridX, plan->NgridY, plan->Ngridz};
474 int permA[3] = {0, 2, 1};
476 my_fft_column_remap(NULL, dimA, plan->firstcol_XY, plan->ncol_XY, NULL, permA, plan->transposed_firstcol, plan->transposed_ncol,
477 plan->offsets_send_A, plan->offsets_recv_A, plan->count_send_A, plan->count_recv_A, 1);
479 int dimB[3] = {plan->NgridX, plan->Ngridz, plan->NgridY};
480 int permB[3] = {2, 1, 0};
482 my_fft_column_remap(NULL, dimB, plan->transposed_firstcol, plan->transposed_ncol, NULL, permB, plan->second_transposed_firstcol,
483 plan->second_transposed_ncol, plan->offsets_send_B, plan->offsets_recv_B, plan->count_send_B, plan->count_recv_B,
486 int dimC[3] = {plan->NgridY, plan->Ngridz, plan->NgridX};
487 int permC[3] = {2, 1, 0};
489 my_fft_column_remap(NULL, dimC, plan->second_transposed_firstcol, plan->second_transposed_ncol, NULL, permC,
490 plan->transposed_firstcol, plan->transposed_ncol, plan->offsets_send_C, plan->offsets_recv_C, plan->count_send_C,
491 plan->count_recv_C, 1);
493 int dimD[3] = {plan->NgridX, plan->Ngridz, plan->NgridY};
494 int permD[3] = {0, 2, 1};
496 my_fft_column_remap(NULL, dimD, plan->transposed_firstcol, plan->transposed_ncol, NULL, permD, plan->firstcol_XY, plan->ncol_XY,
497 plan->offsets_send_D, plan->offsets_recv_D, plan->count_send_D, plan->count_recv_D, 1);
499 int dim23[3] = {plan->NgridX, plan->NgridY, plan->Ngrid2};
500 int perm23[3] = {0, 2, 1};
502 my_fft_column_transpose(NULL, dim23, plan->firstcol_XY, plan->ncol_XY, NULL, perm23, plan->firstcol_XZ, plan->ncol_XZ,
503 plan->count_send_23, plan->count_recv_23, 1);
505 int dim23back[3] = {plan->NgridX, plan->Ngrid2, plan->NgridY};
506 int perm23back[3] = {0, 2, 1};
508 my_fft_column_transpose(NULL, dim23back, plan->firstcol_XZ, plan->ncol_XZ, NULL, perm23back, plan->firstcol_XY, plan->ncol_XY,
509 plan->count_send_23back, plan->count_recv_23back, 1);
511 int dim13[3] = {plan->NgridX, plan->NgridY, plan->Ngrid2};
512 int perm13[3] = {2, 1, 0};
514 my_fft_column_transpose(NULL, dim13, plan->firstcol_XY, plan->ncol_XY, NULL, perm13, plan->firstcol_ZY, plan->ncol_ZY,
515 plan->count_send_13, plan->count_recv_13, 1);
517 int dim13back[3] = {plan->Ngrid2, plan->NgridY, plan->NgridX};
518 int perm13back[3] = {2, 1, 0};
520 my_fft_column_transpose(NULL, dim13back, plan->firstcol_ZY, plan->ncol_ZY, NULL, perm13back, plan->firstcol_XY, plan->ncol_XY,
521 plan->count_send_13back, plan->count_recv_13back, 1);
526 Mem.myfree(plan->count_recv_23back);
527 Mem.myfree(plan->count_send_23back);
528 Mem.myfree(plan->count_recv_13back);
529 Mem.myfree(plan->count_send_13back);
530 Mem.myfree(plan->count_recv_23);
531 Mem.myfree(plan->count_send_23);
532 Mem.myfree(plan->count_recv_13);
533 Mem.myfree(plan->count_send_13);
534 Mem.myfree(plan->count_recv_D);
535 Mem.myfree(plan->count_send_D);
536 Mem.myfree(plan->count_recv_C);
537 Mem.myfree(plan->count_send_C);
538 Mem.myfree(plan->count_recv_B);
539 Mem.myfree(plan->count_send_B);
540 Mem.myfree(plan->count_recv_A);
541 Mem.myfree(plan->count_send_A);
543 Mem.myfree(plan->offsets_recv_D);
544 Mem.myfree(plan->offsets_send_D);
545 Mem.myfree(plan->offsets_recv_C);
546 Mem.myfree(plan->offsets_send_C);
547 Mem.myfree(plan->offsets_recv_B);
548 Mem.myfree(plan->offsets_send_B);
549 Mem.myfree(plan->offsets_recv_A);
550 Mem.myfree(plan->offsets_send_A);
555 int dim23[3] = {plan->NgridX, plan->NgridY, plan->Ngrid2};
556 int perm23[3] = {0, 2, 1};
558 my_fft_column_transpose(data, dim23, plan->firstcol_XY, plan->ncol_XY, out, perm23, plan->firstcol_XZ, plan->ncol_XZ,
559 plan->count_send_23, plan->count_recv_23, 0);
564 int dim23back[3] = {plan->NgridX, plan->Ngrid2, plan->NgridY};
565 int perm23back[3] = {0, 2, 1};
567 my_fft_column_transpose(data, dim23back, plan->firstcol_XZ, plan->ncol_XZ, out, perm23back, plan->firstcol_XY, plan->ncol_XY,
568 plan->count_send_23back, plan->count_recv_23back, 0);
573 int dim13[3] = {plan->NgridX, plan->NgridY, plan->Ngrid2};
574 int perm13[3] = {2, 1, 0};
576 my_fft_column_transpose(data, dim13, plan->firstcol_XY, plan->ncol_XY, out, perm13, plan->firstcol_ZY, plan->ncol_ZY,
577 plan->count_send_13, plan->count_recv_13, 0);
582 int dim13back[3] = {plan->Ngrid2, plan->NgridY, plan->NgridX};
583 int perm13back[3] = {2, 1, 0};
585 my_fft_column_transpose(data, dim13back, plan->firstcol_ZY, plan->ncol_ZY, out, perm13back, plan->firstcol_XY, plan->ncol_XY,
586 plan->count_send_13back, plan->count_recv_13back, 0);
592 fft_real *data_real = (fft_real *)data, *workspace_real = (fft_real *)workspace;
593 fft_complex *data_complex = (fft_complex *)data, *workspace_complex = (fft_complex *)workspace;
598 for(n = 0; n < plan->ncol_XY; n++)
599 FFTW(execute_dft_r2c)(plan->forward_plan_zdir, data_real + n * plan->Ngrid2, workspace_complex + n * plan->Ngridz);
601 int dimA[3] = {plan->NgridX, plan->NgridY, plan->Ngridz};
602 int permA[3] = {0, 2, 1};
604 my_fft_column_remap(workspace_complex, dimA, plan->firstcol_XY, plan->ncol_XY, data_complex, permA, plan->transposed_firstcol,
605 plan->transposed_ncol, plan->offsets_send_A, plan->offsets_recv_A, plan->count_send_A, plan->count_recv_A,
609 for(n = 0; n < plan->transposed_ncol; n++)
610 FFTW(execute_dft)(plan->forward_plan_ydir, data_complex + n * plan->NgridY, workspace_complex + n * plan->NgridY);
612 int dimB[3] = {plan->NgridX, plan->Ngridz, plan->NgridY};
613 int permB[3] = {2, 1, 0};
615 my_fft_column_remap(workspace_complex, dimB, plan->transposed_firstcol, plan->transposed_ncol, data_complex, permB,
616 plan->second_transposed_firstcol, plan->second_transposed_ncol, plan->offsets_send_B, plan->offsets_recv_B,
617 plan->count_send_B, plan->count_recv_B, 0);
620 for(n = 0; n < plan->second_transposed_ncol; n++)
621 FFTW(execute_dft)(plan->forward_plan_xdir, data_complex + n * plan->NgridX, workspace_complex + n * plan->NgridX);
628 for(n = 0; n < plan->second_transposed_ncol; n++)
629 FFTW(execute_dft)(plan->backward_plan_xdir, data_complex + n * plan->NgridX, workspace_complex + n * plan->NgridX);
631 int dimC[3] = {plan->NgridY, plan->Ngridz, plan->NgridX};
632 int permC[3] = {2, 1, 0};
634 my_fft_column_remap(workspace_complex, dimC, plan->second_transposed_firstcol, plan->second_transposed_ncol, data_complex, permC,
635 plan->transposed_firstcol, plan->transposed_ncol, plan->offsets_send_C, plan->offsets_recv_C,
636 plan->count_send_C, plan->count_recv_C, 0);
639 for(n = 0; n < plan->transposed_ncol; n++)
640 FFTW(execute_dft)(plan->backward_plan_ydir, data_complex + n * plan->NgridY, workspace_complex + n * plan->NgridY);
642 int dimD[3] = {plan->NgridX, plan->Ngridz, plan->NgridY};
643 int permD[3] = {0, 2, 1};
645 my_fft_column_remap(workspace_complex, dimD, plan->transposed_firstcol, plan->transposed_ncol, data_complex, permD,
646 plan->firstcol_XY, plan->ncol_XY, plan->offsets_send_D, plan->offsets_recv_D, plan->count_send_D,
647 plan->count_recv_D, 0);
650 for(n = 0; n < plan->ncol_XY; n++)
651 FFTW(execute_dft_c2r)(plan->backward_plan_zdir, data_complex + n * plan->Ngridz, workspace_real + n * plan->Ngrid2);
655void pm_mpi_fft::my_fft_column_remap(fft_complex *data,
int Ndims[3],
656 int in_firstcol,
int in_ncol,
657 fft_complex *out,
int perm[3],
int out_firstcol,
int out_ncol,
size_t *offset_send,
658 size_t *offset_recv,
size_t *count_send,
size_t *count_recv,
size_t just_count_flag)
660 int j, target, origin, ngrp, recvTask, perm_rev[3], xyz[3], uvw[3];
661 size_t nimport, nexport;
664 for(j = 0; j < 3; j++)
665 perm_rev[j] = perm[j];
667 if(!(perm_rev[perm[0]] == 0 && perm_rev[perm[1]] == 1 && perm_rev[perm[2]] == 2))
669 for(j = 0; j < 3; j++)
670 perm_rev[j] = perm[perm[j]];
672 if(!(perm_rev[perm[0]] == 0 && perm_rev[perm[1]] == 1 && perm_rev[perm[2]] == 2))
676 int in_colums = Ndims[0] * Ndims[1];
677 int in_avg = (in_colums - 1) /
NTask + 1;
678 int in_exc =
NTask * in_avg - in_colums;
679 int in_tasklastsection =
NTask - in_exc;
680 int in_pivotcol = in_tasklastsection * in_avg;
682 int out_colums = Ndims[perm[0]] * Ndims[perm[1]];
683 int out_avg = (out_colums - 1) /
NTask + 1;
684 int out_exc =
NTask * out_avg - out_colums;
685 int out_tasklastsection =
NTask - out_exc;
686 int out_pivotcol = out_tasklastsection * out_avg;
688 size_t i, ncells = ((size_t)in_ncol) * Ndims[2];
690 xyz[0] = in_firstcol / Ndims[1];
691 xyz[1] = in_firstcol % Ndims[1];
694 memset(count_send, 0,
NTask *
sizeof(
size_t));
697 for(i = 0; i < ncells; i++)
700 uvw[0] = xyz[perm[0]];
701 uvw[1] = xyz[perm[1]];
702 uvw[2] = xyz[perm[2]];
704 int newcol = Ndims[perm[1]] * uvw[0] + uvw[1];
705 if(newcol < out_pivotcol)
706 target = newcol / out_avg;
708 target = (newcol - out_pivotcol) / (out_avg - 1) + out_tasklastsection;
713 count_send[target]++;
716 size_t off = offset_send[target] + count_send[target]++;
717 out[off][0] = data[i][0];
718 out[off][1] = data[i][1];
721 if(xyz[2] == Ndims[2])
725 if(xyz[1] == Ndims[1])
737 for(j = 0, nimport = 0, nexport = 0, offset_send[0] = 0, offset_recv[0] = 0; j <
NTask; j++)
739 nexport += count_send[j];
740 nimport += count_recv[j];
744 offset_send[j] = offset_send[j - 1] + count_send[j - 1];
745 offset_recv[j] = offset_recv[j - 1] + count_recv[j - 1];
749 if(nexport != ncells)
750 Terminate(
"nexport=%lld != ncells=%lld", (
long long)nexport, (
long long)ncells);
757 for(ngrp = 0; ngrp < (1 <<
PTask); ngrp++)
763 if(count_send[recvTask] > 0 || count_recv[recvTask] > 0)
764 myMPI_Sendrecv(&out[offset_send[recvTask]], count_send[recvTask] *
sizeof(fft_complex), MPI_BYTE, recvTask,
TAG_DENS_A,
765 &data[offset_recv[recvTask]], count_recv[recvTask] *
sizeof(fft_complex), MPI_BYTE, recvTask,
768 nimport += count_recv[recvTask];
775 int first[3], last[3];
777 first[0] = out_firstcol / Ndims[perm[1]];
778 first[1] = out_firstcol % Ndims[perm[1]];
781 last[0] = (out_firstcol + out_ncol - 1) / Ndims[perm[1]];
782 last[1] = (out_firstcol + out_ncol - 1) % Ndims[perm[1]];
783 last[2] = Ndims[perm[2]] - 1;
785 if(first[1] + out_ncol >= Ndims[perm[1]])
788 last[1] = Ndims[perm[1]] - 1;
793 int xyz_first[3], xyz_last[3];
795 for(j = 0; j < 3; j++)
797 xyz_first[j] = first[perm_rev[j]];
798 xyz_last[j] = last[perm_rev[j]];
801 memset(count_recv, 0,
NTask *
sizeof(
size_t));
806 for(xyz[0] = xyz_first[0]; xyz[0] <= xyz_last[0]; xyz[0]++)
807 for(xyz[1] = xyz_first[1]; xyz[1] <= xyz_last[1]; xyz[1]++)
808 for(xyz[2] = xyz_first[2]; xyz[2] <= xyz_last[2]; xyz[2]++)
811 uvw[0] = xyz[perm[0]];
812 uvw[1] = xyz[perm[1]];
813 uvw[2] = xyz[perm[2]];
815 int col = uvw[0] * Ndims[perm[1]] + uvw[1];
817 if(col >= out_firstcol && col < out_firstcol + out_ncol)
820 int newcol = Ndims[1] * xyz[0] + xyz[1];
821 if(newcol < in_pivotcol)
822 origin = newcol / in_avg;
824 origin = (newcol - in_pivotcol) / (in_avg - 1) + in_tasklastsection;
826 size_t index = ((size_t)Ndims[perm[2]]) * (col - out_firstcol) + uvw[2];
829 size_t off = offset_recv[origin] + count_recv[origin]++;
830 out[index][0] = data[off][0];
831 out[index][1] = data[off][1];
839 int fi = out_firstcol % Ndims[perm[1]];
840 int la = (out_firstcol + out_ncol - 1) % Ndims[perm[1]];
842 Terminate(
"count=%lld nimport=%lld ncol=%d fi=%d la=%d first=%d last=%d\n", (
long long)count, (
long long)nimport, out_ncol,
843 fi, la, first[1], last[1]);
848void pm_mpi_fft::my_fft_column_transpose(fft_real *data,
int Ndims[3],
849 int in_firstcol,
int in_ncol,
850 fft_real *out,
int perm[3],
int out_firstcol,
int out_ncol,
size_t *count_send,
851 size_t *count_recv,
size_t just_count_flag)
855 for(
int j = 0; j < 3; j++)
856 perm_rev[j] = perm[j];
858 if(!(perm_rev[perm[0]] == 0 && perm_rev[perm[1]] == 1 && perm_rev[perm[2]] == 2))
860 for(
int j = 0; j < 3; j++)
861 perm_rev[j] = perm[perm[j]];
863 if(!(perm_rev[perm[0]] == 0 && perm_rev[perm[1]] == 1 && perm_rev[perm[2]] == 2))
867 int in_colums = Ndims[0] * Ndims[1];
868 int in_avg = (in_colums - 1) /
NTask + 1;
869 int in_exc =
NTask * in_avg - in_colums;
870 int in_tasklastsection =
NTask - in_exc;
871 int in_pivotcol = in_tasklastsection * in_avg;
873 int out_colums = Ndims[perm[0]] * Ndims[perm[1]];
874 int out_avg = (out_colums - 1) /
NTask + 1;
875 int out_exc =
NTask * out_avg - out_colums;
876 int out_tasklastsection =
NTask - out_exc;
877 int out_pivotcol = out_tasklastsection * out_avg;
880 memset(count_send, 0,
NTask *
sizeof(
size_t));
883 for(
int ngrp = 0; ngrp < (1 <<
PTask); ngrp++)
890 if(count_send[target] == 0 && count_recv[target] == 0 && just_count_flag == 0)
895 source_first[0] = in_firstcol / Ndims[1];
896 source_first[1] = in_firstcol % Ndims[1];
900 source_last[0] = (in_firstcol + in_ncol - 1) / Ndims[1];
901 source_last[1] = (in_firstcol + in_ncol - 1) % Ndims[1];
902 source_last[2] = Ndims[2] - 1;
904 if(source_first[1] + in_ncol >= Ndims[1])
907 source_last[1] = Ndims[1] - 1;
912 int target_first_col = 0;
913 int long target_num_col = 0;
915 if(target < out_tasklastsection)
917 target_first_col = target * out_avg;
918 target_num_col = out_avg;
922 target_first_col = (target - out_tasklastsection) * (out_avg - 1) + out_pivotcol;
923 target_num_col = (out_avg - 1);
927 int first[3], last[3];
929 first[0] = target_first_col / Ndims[perm[1]];
930 first[1] = target_first_col % Ndims[perm[1]];
933 last[0] = (target_first_col + target_num_col - 1) / Ndims[perm[1]];
934 last[1] = (target_first_col + target_num_col - 1) % Ndims[perm[1]];
935 last[2] = Ndims[perm[2]] - 1;
937 if(first[1] + target_num_col >= Ndims[perm[1]])
940 last[1] = Ndims[perm[1]] - 1;
944 int xyz_first[3], xyz_last[3];
946 for(
int j = 0; j < 3; j++)
948 xyz_first[j] = first[perm_rev[j]];
949 xyz_last[j] = last[perm_rev[j]];
953 int xyz_start[3], xyz_end[3];
954 for(
int j = 0; j < 3; j++)
956 xyz_start[j] = std::max<int>(xyz_first[j], source_first[j]);
957 xyz_end[j] = std::min<int>(xyz_last[j], source_last[j]);
961 for(
int j = 0; j < 3; j++)
962 xyz[j] = xyz_start[j];
966 int flip_in_firstcol = 0;
967 int flip_in_ncol = 0;
969 if(target < in_tasklastsection)
971 flip_in_firstcol = target * in_avg;
972 flip_in_ncol = in_avg;
976 flip_in_firstcol = (target - in_tasklastsection) * (in_avg - 1) + in_pivotcol;
977 flip_in_ncol = (in_avg - 1);
981 int flip_source_first[3];
982 flip_source_first[0] = flip_in_firstcol / Ndims[1];
983 flip_source_first[1] = flip_in_firstcol % Ndims[1];
984 flip_source_first[2] = 0;
986 int flip_source_last[3];
987 flip_source_last[0] = (flip_in_firstcol + flip_in_ncol - 1) / Ndims[1];
988 flip_source_last[1] = (flip_in_firstcol + flip_in_ncol - 1) % Ndims[1];
989 flip_source_last[2] = Ndims[2] - 1;
991 if(flip_source_first[1] + flip_in_ncol >= Ndims[1])
993 flip_source_first[1] = 0;
994 flip_source_last[1] = Ndims[1] - 1;
999 int flip_first_col = 0;
1000 int flip_num_col = 0;
1004 flip_first_col =
ThisTask * out_avg;
1005 flip_num_col = out_avg;
1009 flip_first_col = (
ThisTask - out_tasklastsection) * (out_avg - 1) + out_pivotcol;
1010 flip_num_col = (out_avg - 1);
1014 int flip_first[3], flip_last[3];
1016 flip_first[0] = flip_first_col / Ndims[perm[1]];
1017 flip_first[1] = flip_first_col % Ndims[perm[1]];
1020 flip_last[0] = (flip_first_col + flip_num_col - 1) / Ndims[perm[1]];
1021 flip_last[1] = (flip_first_col + flip_num_col - 1) % Ndims[perm[1]];
1022 flip_last[2] = Ndims[perm[2]] - 1;
1024 if(flip_first[1] + flip_num_col >= Ndims[perm[1]])
1027 flip_last[1] = Ndims[perm[1]] - 1;
1031 int abc_first[3], abc_last[3];
1033 for(
int j = 0; j < 3; j++)
1035 abc_first[j] = flip_first[perm_rev[j]];
1036 abc_last[j] = flip_last[perm_rev[j]];
1040 int abc_start[3], abc_end[3];
1041 for(
int j = 0; j < 3; j++)
1043 abc_start[j] = std::max<int>(abc_first[j], flip_source_first[j]);
1044 abc_end[j] = std::min<int>(abc_last[j], flip_source_last[j]);
1049 for(
int j = 0; j < 3; j++)
1050 abc[j] = abc_start[j];
1052 size_t tot_count_send = 0;
1053 size_t tot_count_recv = 0;
1056 size_t parnter_freebytes;
1060 size_t freeb = std::min<size_t>(parnter_freebytes,
Mem.
FreeBytes);
1062 size_t limit = 0.5 * freeb / (
sizeof(fft_real) +
sizeof(fft_real));
1070 size_t limit_send = count_send[target] - tot_count_send;
1071 size_t limit_recv = count_recv[target] - tot_count_recv;
1075 limit_send = SIZE_MAX;
1076 limit_recv = SIZE_MAX;
1080 if(limit_send > limit)
1083 if(limit_recv > limit)
1087 fft_real *buffer_send = NULL;
1088 fft_real *buffer_recv = NULL;
1090 if(just_count_flag == 0)
1092 buffer_send = (fft_real *)
Mem.mymalloc(
"buffer_send", limit_send *
sizeof(fft_real));
1093 buffer_recv = (fft_real *)
Mem.mymalloc(
"buffer_recv", limit_recv *
sizeof(fft_real));
1099 while(count < limit_send && xyz[0] <= xyz_end[0] && xyz[1] <= xyz_end[1] && xyz[2] <= xyz_end[2])
1102 int col_old = xyz[0] * Ndims[1] + xyz[1];
1104 if(col_old >= in_firstcol && col_old < in_firstcol + in_ncol)
1108 uvw[0] = xyz[perm[0]];
1109 uvw[1] = xyz[perm[1]];
1110 uvw[2] = xyz[perm[2]];
1112 int col_new = uvw[0] * Ndims[perm[1]] + uvw[1];
1114 if(col_new >= target_first_col && col_new < target_first_col + target_num_col)
1119 count_send[target]++;
1122 long long source_cell = (Ndims[1] * xyz[0] + xyz[1] - in_firstcol) * Ndims[2] + xyz[2];
1124 buffer_send[count++] = data[source_cell];
1131 if(xyz[2] > xyz_end[2])
1133 xyz[2] = xyz_start[2];
1135 if(xyz[1] > xyz_end[1])
1137 xyz[1] = xyz_start[1];
1143 if(just_count_flag == 0)
1149 while(count < limit_recv && abc[0] <= abc_end[0] && abc[1] <= abc_end[1] && abc[2] <= abc_end[2])
1152 int col_old = abc[0] * Ndims[1] + abc[1];
1154 if(col_old >= flip_in_firstcol && col_old < flip_in_firstcol + flip_in_ncol)
1158 uvw[0] = abc[perm[0]];
1159 uvw[1] = abc[perm[1]];
1160 uvw[2] = abc[perm[2]];
1162 int col_new = uvw[0] * Ndims[perm[1]] + uvw[1];
1164 if(col_new >= flip_first_col && col_new < flip_first_col + flip_num_col)
1168 long long target_cell = (Ndims[perm[1]] * uvw[0] + uvw[1] - flip_first_col) * Ndims[perm[2]] + uvw[2];
1170 out[target_cell] = buffer_recv[count++];
1176 if(abc[2] > abc_end[2])
1178 abc[2] = abc_start[2];
1180 if(abc[1] > abc_end[1])
1182 abc[1] = abc_start[1];
1188 Mem.myfree(buffer_recv);
1189 Mem.myfree(buffer_send);
1197 Terminate(
"high number of iterations: limit=%lld", (
long long)limit);
1199 while(tot_count_send < count_send[target] || tot_count_recv < count_recv[target]);
void my_slab_based_fft_free(fft_plan *plan)
void my_column_based_fft_free(fft_plan *plan)
void my_column_based_fft_init(fft_plan *plan, int NgridX, int NgridY, int NgridZ)
void my_column_based_fft(fft_plan *plan, void *data, void *workspace, int forward)
void my_fft_swap13back(fft_plan *plan, fft_real *data, fft_real *out)
void my_fft_swap13(fft_plan *plan, fft_real *data, fft_real *out)
void my_slab_transposeB(fft_plan *plan, fft_real *field, fft_real *scratch)
void my_slab_based_fft(fft_plan *plan, void *data, void *workspace, int forward)
void my_fft_swap23(fft_plan *plan, fft_real *data, fft_real *out)
void my_fft_swap23back(fft_plan *plan, fft_real *data, fft_real *out)
void my_slab_transposeA(fft_plan *plan, fft_real *field, fft_real *scratch)
void my_slab_based_fft_init(fft_plan *plan, int NgridX, int NgridY, int NgridZ)
#define MPI_MESSAGE_SIZELIMIT_IN_BYTES
int myMPI_Alltoall(const void *sendbuf, int sendcount, MPI_Datatype sendtype, void *recvbuf, int recvcount, MPI_Datatype recvtype, MPI_Comm comm)
int myMPI_Sendrecv(void *sendbuf, size_t sendcount, MPI_Datatype sendtype, int dest, int sendtag, void *recvbuf, size_t recvcount, MPI_Datatype recvtype, int source, int recvtag, MPI_Comm comm, MPI_Status *status)
void myMPI_Alltoallv(void *sendbuf, size_t *sendcounts, size_t *sdispls, void *recvbuf, size_t *recvcounts, size_t *rdispls, int len, int big_flag, MPI_Comm comm)
void subdivide_evenly(long long N, int pieces, int index_bin, long long *first, int *count)