diff --git a/src/engine.c b/src/engine.c
index 95507065daf595d44f7909617885cf82a2309e22..dbb66869bce9fa409d0a49c6660d48b95f7093e8 100644
--- a/src/engine.c
+++ b/src/engine.c
@@ -1661,8 +1661,8 @@ int engine_estimate_nr_tasks(const struct engine *e) {
  * @param clean_smoothing_length_values Are we cleaning up the values of
  * the smoothing lengths before building the tasks ?
  */
-void engine_rebuild(struct engine *e, int repartitioned,
-                    int clean_smoothing_length_values) {
+void engine_rebuild(struct engine *e, const int repartitioned,
+                    const int clean_smoothing_length_values) {
 
   const ticks tic = getticks();
 
@@ -1746,8 +1746,14 @@ void engine_rebuild(struct engine *e, int repartitioned,
             clocks_from_ticks(getticks() - tic2), clocks_getunit());
 
   /* Re-compute the mesh forces */
-  if ((e->policy & engine_policy_self_gravity) && e->s->periodic)
+  if ((e->policy & engine_policy_self_gravity) && e->s->periodic) {
+
+    /* Re-allocate the PM grid if we freed it... */
+    if (repartitioned) pm_mesh_allocate(e->mesh);
+
+    /* ... and recompute */
     pm_mesh_compute_potential(e->mesh, e->s, &e->threadpool, e->verbose);
+  }
 
   /* Re-compute the maximal RMS displacement constraint */
   if (e->policy & engine_policy_cosmology)
@@ -1884,6 +1890,10 @@ void engine_prepare(struct engine *e) {
     engine_drift_all(e, /*drift_mpole=*/0);
     drifted_all = 1;
 
+    /* Free the PM grid */
+    if ((e->policy & engine_policy_self_gravity) && e->s->periodic)
+      pm_mesh_free(e->mesh);
+
     /* And repartition */
     engine_repartition(e);
     repartitioned = 1;
diff --git a/src/mesh_gravity.c b/src/mesh_gravity.c
index 0efec0bc80fc381d9bd6ff733cc546e1885c1d9e..303341d5356bb7f338e702e7e03cb027036b7fad 100644
--- a/src/mesh_gravity.c
+++ b/src/mesh_gravity.c
@@ -35,6 +35,7 @@
 #include "gravity_properties.h"
 #include "kernel_long_gravity.h"
 #include "part.h"
+#include "restart.h"
 #include "runner.h"
 #include "space.h"
 #include "threadpool.h"
@@ -673,6 +674,39 @@ void pm_mesh_interpolate_forces(const struct pm_mesh* mesh,
 #endif
 }
 
+/**
+ * @bried Allocates the potential grid to be ready for an FFT calculation
+ *
+ * @param mesh The #pm_mesh structure.
+ */
+void pm_mesh_allocate(struct pm_mesh* mesh) {
+
+  if (mesh->potential != NULL) error("Mesh already allocated!");
+
+  const int N = mesh->N;
+
+  /* Allocate the memory for the combined density and potential array */
+  mesh->potential = (double*)fftw_malloc(sizeof(double) * N * N * N);
+  if (mesh->potential == NULL)
+    error("Error allocating memory for the long-range gravity mesh.");
+  memuse_log_allocation("fftw_mesh.potential", mesh->potential, 1,
+                        sizeof(double) * N * N * N);
+}
+
+/**
+ * @brief Frees the potential grid.
+ *
+ * @param mesh The #pm_mesh structure.
+ */
+void pm_mesh_free(struct pm_mesh* mesh) {
+
+  if (mesh->potential) {
+    memuse_log_allocation("fftw_mesh.potential", mesh->potential, 0, 0);
+    free(mesh->potential);
+  }
+  mesh->potential = NULL;
+}
+
 /**
  * @brief Initialisses the mesh used for the long-range periodic forces
  *
@@ -703,6 +737,7 @@ void pm_mesh_init(struct pm_mesh* mesh, const struct gravity_props* props,
   mesh->r_s_inv = 1. / mesh->r_s;
   mesh->r_cut_max = mesh->r_s * props->r_cut_max_ratio;
   mesh->r_cut_min = mesh->r_s * props->r_cut_min_ratio;
+  mesh->potential = NULL;
 
   if (mesh->N > 1290)
     error(
@@ -720,12 +755,7 @@ void pm_mesh_init(struct pm_mesh* mesh, const struct gravity_props* props,
   }
 #endif
 
-  /* Allocate the memory for the combined density and potential array */
-  mesh->potential = (double*)fftw_malloc(sizeof(double) * N * N * N);
-  if (mesh->potential == NULL)
-    error("Error allocating memory for the long-range gravity mesh.");
-  memuse_log_allocation("fftw_mesh.potential", mesh->potential, 1,
-                        sizeof(double) * N * N * N);
+  pm_mesh_allocate(mesh);
 
 #else
   error("No FFTW library found. Cannot compute periodic long-range forces.");
@@ -765,11 +795,7 @@ void pm_mesh_clean(struct pm_mesh* mesh) {
   fftw_cleanup_threads();
 #endif
 
-  if (mesh->potential) {
-    memuse_log_allocation("fftw_mesh.potential", mesh->potential, 0, 0);
-    free(mesh->potential);
-  }
-  mesh->potential = 0;
+  pm_mesh_free(mesh);
 }
 
 /**
diff --git a/src/mesh_gravity.h b/src/mesh_gravity.h
index 1b2d997398ee6f3f665340cedb790c241e641cfa..e9c07a0de0327984686d65bb9738cde643a7cab8 100644
--- a/src/mesh_gravity.h
+++ b/src/mesh_gravity.h
@@ -24,7 +24,6 @@
 
 /* Local headers */
 #include "gravity_properties.h"
-#include "restart.h"
 
 /* Forward declarations */
 struct space;
@@ -77,6 +76,9 @@ void pm_mesh_interpolate_forces(const struct pm_mesh *mesh,
                                 int gcount);
 void pm_mesh_clean(struct pm_mesh *mesh);
 
+void pm_mesh_allocate(struct pm_mesh *mesh);
+void pm_mesh_free(struct pm_mesh *mesh);
+
 /* Dump/restore. */
 void pm_mesh_struct_dump(const struct pm_mesh *p, FILE *stream);
 void pm_mesh_struct_restore(struct pm_mesh *p, FILE *stream);