diff --git a/src/engine.c b/src/engine.c index 95507065daf595d44f7909617885cf82a2309e22..dbb66869bce9fa409d0a49c6660d48b95f7093e8 100644 --- a/src/engine.c +++ b/src/engine.c @@ -1661,8 +1661,8 @@ int engine_estimate_nr_tasks(const struct engine *e) { * @param clean_smoothing_length_values Are we cleaning up the values of * the smoothing lengths before building the tasks ? */ -void engine_rebuild(struct engine *e, int repartitioned, - int clean_smoothing_length_values) { +void engine_rebuild(struct engine *e, const int repartitioned, + const int clean_smoothing_length_values) { const ticks tic = getticks(); @@ -1746,8 +1746,14 @@ void engine_rebuild(struct engine *e, int repartitioned, clocks_from_ticks(getticks() - tic2), clocks_getunit()); /* Re-compute the mesh forces */ - if ((e->policy & engine_policy_self_gravity) && e->s->periodic) + if ((e->policy & engine_policy_self_gravity) && e->s->periodic) { + + /* Re-allocate the PM grid if we freed it... */ + if (repartitioned) pm_mesh_allocate(e->mesh); + + /* ... and recompute */ pm_mesh_compute_potential(e->mesh, e->s, &e->threadpool, e->verbose); + } /* Re-compute the maximal RMS displacement constraint */ if (e->policy & engine_policy_cosmology) @@ -1884,6 +1890,10 @@ void engine_prepare(struct engine *e) { engine_drift_all(e, /*drift_mpole=*/0); drifted_all = 1; + /* Free the PM grid */ + if ((e->policy & engine_policy_self_gravity) && e->s->periodic) + pm_mesh_free(e->mesh); + /* And repartition */ engine_repartition(e); repartitioned = 1; diff --git a/src/mesh_gravity.c b/src/mesh_gravity.c index 0efec0bc80fc381d9bd6ff733cc546e1885c1d9e..303341d5356bb7f338e702e7e03cb027036b7fad 100644 --- a/src/mesh_gravity.c +++ b/src/mesh_gravity.c @@ -35,6 +35,7 @@ #include "gravity_properties.h" #include "kernel_long_gravity.h" #include "part.h" +#include "restart.h" #include "runner.h" #include "space.h" #include "threadpool.h" @@ -673,6 +674,39 @@ void pm_mesh_interpolate_forces(const struct pm_mesh* mesh, #endif } +/** + * @bried Allocates the potential grid to be ready for an FFT calculation + * + * @param mesh The #pm_mesh structure. + */ +void pm_mesh_allocate(struct pm_mesh* mesh) { + + if (mesh->potential != NULL) error("Mesh already allocated!"); + + const int N = mesh->N; + + /* Allocate the memory for the combined density and potential array */ + mesh->potential = (double*)fftw_malloc(sizeof(double) * N * N * N); + if (mesh->potential == NULL) + error("Error allocating memory for the long-range gravity mesh."); + memuse_log_allocation("fftw_mesh.potential", mesh->potential, 1, + sizeof(double) * N * N * N); +} + +/** + * @brief Frees the potential grid. + * + * @param mesh The #pm_mesh structure. + */ +void pm_mesh_free(struct pm_mesh* mesh) { + + if (mesh->potential) { + memuse_log_allocation("fftw_mesh.potential", mesh->potential, 0, 0); + free(mesh->potential); + } + mesh->potential = NULL; +} + /** * @brief Initialisses the mesh used for the long-range periodic forces * @@ -703,6 +737,7 @@ void pm_mesh_init(struct pm_mesh* mesh, const struct gravity_props* props, mesh->r_s_inv = 1. / mesh->r_s; mesh->r_cut_max = mesh->r_s * props->r_cut_max_ratio; mesh->r_cut_min = mesh->r_s * props->r_cut_min_ratio; + mesh->potential = NULL; if (mesh->N > 1290) error( @@ -720,12 +755,7 @@ void pm_mesh_init(struct pm_mesh* mesh, const struct gravity_props* props, } #endif - /* Allocate the memory for the combined density and potential array */ - mesh->potential = (double*)fftw_malloc(sizeof(double) * N * N * N); - if (mesh->potential == NULL) - error("Error allocating memory for the long-range gravity mesh."); - memuse_log_allocation("fftw_mesh.potential", mesh->potential, 1, - sizeof(double) * N * N * N); + pm_mesh_allocate(mesh); #else error("No FFTW library found. Cannot compute periodic long-range forces."); @@ -765,11 +795,7 @@ void pm_mesh_clean(struct pm_mesh* mesh) { fftw_cleanup_threads(); #endif - if (mesh->potential) { - memuse_log_allocation("fftw_mesh.potential", mesh->potential, 0, 0); - free(mesh->potential); - } - mesh->potential = 0; + pm_mesh_free(mesh); } /** diff --git a/src/mesh_gravity.h b/src/mesh_gravity.h index 1b2d997398ee6f3f665340cedb790c241e641cfa..e9c07a0de0327984686d65bb9738cde643a7cab8 100644 --- a/src/mesh_gravity.h +++ b/src/mesh_gravity.h @@ -24,7 +24,6 @@ /* Local headers */ #include "gravity_properties.h" -#include "restart.h" /* Forward declarations */ struct space; @@ -77,6 +76,9 @@ void pm_mesh_interpolate_forces(const struct pm_mesh *mesh, int gcount); void pm_mesh_clean(struct pm_mesh *mesh); +void pm_mesh_allocate(struct pm_mesh *mesh); +void pm_mesh_free(struct pm_mesh *mesh); + /* Dump/restore. */ void pm_mesh_struct_dump(const struct pm_mesh *p, FILE *stream); void pm_mesh_struct_restore(struct pm_mesh *p, FILE *stream);