diff --git a/src/engine.c b/src/engine.c index e5e7b78dcafd9c14a032dd3a4249e6b0a52fab4d..2d9d1b9d4ee027bbcec6e9fbd2d752dd37455860 100644 --- a/src/engine.c +++ b/src/engine.c @@ -1664,8 +1664,8 @@ int engine_estimate_nr_tasks(const struct engine *e) { * @param clean_smoothing_length_values Are we cleaning up the values of * the smoothing lengths before building the tasks ? */ -void engine_rebuild(struct engine *e, int repartitioned, - int clean_smoothing_length_values) { +void engine_rebuild(struct engine *e, const int repartitioned, + const int clean_smoothing_length_values) { const ticks tic = getticks(); @@ -1750,8 +1750,14 @@ void engine_rebuild(struct engine *e, int repartitioned, clocks_from_ticks(getticks() - tic2), clocks_getunit()); /* Re-compute the mesh forces */ - if ((e->policy & engine_policy_self_gravity) && e->s->periodic) + if ((e->policy & engine_policy_self_gravity) && e->s->periodic) { + + /* Re-allocate the PM grid if we freed it... */ + if (repartitioned) pm_mesh_allocate(e->mesh); + + /* ... and recompute */ pm_mesh_compute_potential(e->mesh, e->s, &e->threadpool, e->verbose); + } /* Re-compute the maximal RMS displacement constraint */ if (e->policy & engine_policy_cosmology) @@ -1888,6 +1894,10 @@ void engine_prepare(struct engine *e) { engine_drift_all(e, /*drift_mpole=*/0); drifted_all = 1; + /* Free the PM grid */ + if ((e->policy & engine_policy_self_gravity) && e->s->periodic) + pm_mesh_free(e->mesh); + /* And repartition */ engine_repartition(e); repartitioned = 1; diff --git a/src/mesh_gravity.c b/src/mesh_gravity.c index cbfb9ef7a2970b222bf67a915bd5044bedaaaae3..5c657a2362b1544b38dc8f5708910d1c29e6651e 100644 --- a/src/mesh_gravity.c +++ b/src/mesh_gravity.c @@ -36,6 +36,7 @@ #include "gravity_properties.h" #include "kernel_long_gravity.h" #include "part.h" +#include "restart.h" #include "runner.h" #include "space.h" #include "threadpool.h" @@ -674,6 +675,39 @@ void pm_mesh_interpolate_forces(const struct pm_mesh* mesh, #endif } +/** + * @bried Allocates the potential grid to be ready for an FFT calculation + * + * @param mesh The #pm_mesh structure. + */ +void pm_mesh_allocate(struct pm_mesh* mesh) { + + if (mesh->potential != NULL) error("Mesh already allocated!"); + + const int N = mesh->N; + + /* Allocate the memory for the combined density and potential array */ + mesh->potential = (double*)fftw_malloc(sizeof(double) * N * N * N); + if (mesh->potential == NULL) + error("Error allocating memory for the long-range gravity mesh."); + memuse_log_allocation("fftw_mesh.potential", mesh->potential, 1, + sizeof(double) * N * N * N); +} + +/** + * @brief Frees the potential grid. + * + * @param mesh The #pm_mesh structure. + */ +void pm_mesh_free(struct pm_mesh* mesh) { + + if (mesh->potential) { + memuse_log_allocation("fftw_mesh.potential", mesh->potential, 0, 0); + free(mesh->potential); + } + mesh->potential = NULL; +} + /** * @brief Initialisses the mesh used for the long-range periodic forces * @@ -704,6 +738,7 @@ void pm_mesh_init(struct pm_mesh* mesh, const struct gravity_props* props, mesh->r_s_inv = 1. / mesh->r_s; mesh->r_cut_max = mesh->r_s * props->r_cut_max_ratio; mesh->r_cut_min = mesh->r_s * props->r_cut_min_ratio; + mesh->potential = NULL; if (mesh->N > 1290) error( @@ -721,12 +756,7 @@ void pm_mesh_init(struct pm_mesh* mesh, const struct gravity_props* props, } #endif - /* Allocate the memory for the combined density and potential array */ - mesh->potential = (double*)fftw_malloc(sizeof(double) * N * N * N); - if (mesh->potential == NULL) - error("Error allocating memory for the long-range gravity mesh."); - memuse_log_allocation("fftw_mesh.potential", mesh->potential, 1, - sizeof(double) * N * N * N); + pm_mesh_allocate(mesh); #else error("No FFTW library found. Cannot compute periodic long-range forces."); @@ -766,11 +796,7 @@ void pm_mesh_clean(struct pm_mesh* mesh) { fftw_cleanup_threads(); #endif - if (mesh->potential) { - memuse_log_allocation("fftw_mesh.potential", mesh->potential, 0, 0); - free(mesh->potential); - } - mesh->potential = 0; + pm_mesh_free(mesh); } /** diff --git a/src/mesh_gravity.h b/src/mesh_gravity.h index 1b2d997398ee6f3f665340cedb790c241e641cfa..e9c07a0de0327984686d65bb9738cde643a7cab8 100644 --- a/src/mesh_gravity.h +++ b/src/mesh_gravity.h @@ -24,7 +24,6 @@ /* Local headers */ #include "gravity_properties.h" -#include "restart.h" /* Forward declarations */ struct space; @@ -77,6 +76,9 @@ void pm_mesh_interpolate_forces(const struct pm_mesh *mesh, int gcount); void pm_mesh_clean(struct pm_mesh *mesh); +void pm_mesh_allocate(struct pm_mesh *mesh); +void pm_mesh_free(struct pm_mesh *mesh); + /* Dump/restore. */ void pm_mesh_struct_dump(const struct pm_mesh *p, FILE *stream); void pm_mesh_struct_restore(struct pm_mesh *p, FILE *stream);