diff --git a/src/fof.c b/src/fof.c
index 38c23109a5c92cecba31814395a39d873e8e739f..29b1a7324c3bdd2806142843ac1f379a7fac9d26 100644
--- a/src/fof.c
+++ b/src/fof.c
@@ -324,8 +324,8 @@ __attribute__((always_inline)) INLINE static size_t fof_find(
 
 /* Updates the root and checks that its value has not been changed since being
  * read. */
-__attribute__((always_inline)) INLINE static size_t update_root(
-    volatile size_t *address, size_t y) {
+__attribute__((always_inline)) INLINE static int update_root(
+    volatile size_t *address, const size_t y) {
 
   size_t *size_t_ptr = (size_t *)address;
 
@@ -677,130 +677,136 @@ void fof_search_pair_cells(const struct fof_props *props, const struct space *s,
 
 /* Perform a FOF search between a local and foreign cell using the Union-Find
  * algorithm. Store any links found between particles.*/
-void fof_search_pair_cells_foreign(const struct fof_props *props,
-                                   const struct space *s, const double dim[3],
-                                   const double l_x2, struct cell *restrict ci,
-                                   struct cell *restrict cj,
-                                   int *restrict link_count,
-                                   struct fof_mpi **group_links,
-                                   int *restrict group_links_size) {
+void fof_search_pair_cells_foreign(
+    const struct fof_props *props, const struct space *s, const double dim[3],
+    const double l_x2, const struct cell *restrict ci,
+    const struct cell *restrict cj, int *restrict link_count,
+    struct fof_mpi **group_links, int *restrict group_links_size) {
+
 #ifdef WITH_MPI
   const size_t count_i = ci->grav.count;
   const size_t count_j = cj->grav.count;
-  struct gpart *gparts_i = ci->grav.parts;
-  struct gpart *gparts_j = cj->grav.parts;
+  const struct gpart *gparts_i = ci->grav.parts;
+  const struct gpart *gparts_j = cj->grav.parts;
 
   /* Get local pointers */
   size_t *group_index = props->group_index;
   size_t *group_size = props->group_size;
 
+  /* Local copies of the in/out arguments to avoid repeated dereferencing */
+  struct fof_mpi *local_group_links = *group_links;
+  int local_link_count = *link_count;
+
   /* Make a list of particle offsets into the global gparts array. */
   size_t *const offset_i = group_index + (ptrdiff_t)(gparts_i - s->gparts);
 
-  /* Account for boundary conditions.*/
-  double shift[3] = {0.0, 0.0, 0.0};
+#ifdef SWIFT_DEBUG_CHECKS
 
   /* Check whether cells are local to the node. */
   const int ci_local = (ci->nodeID == engine_rank);
   const int cj_local = (cj->nodeID == engine_rank);
 
-#ifdef SWIFT_DEBUG_CHECKS
   if ((ci_local && cj_local) || (!ci_local && !cj_local))
     error(
         "FOF search of foreign cells called on two local cells or two foreign "
         "cells.");
+
+  if (!ci_local) {
+    error("Cell ci, is not local.");
+  }
 #endif
 
   /* Get the relative distance between the pairs, wrapping. */
   const int periodic = s->periodic;
-  double diff[3];
-
-  if (ci_local) {
+  double shift[3] = {0.0, 0.0, 0.0};
 
-    for (int k = 0; k < 3; k++) {
-      diff[k] = cj->loc[k] - ci->loc[k];
-      if (periodic && diff[k] < -dim[k] / 2)
-        shift[k] = dim[k];
-      else if (periodic && diff[k] > dim[k] / 2)
-        shift[k] = -dim[k];
-      else
-        shift[k] = 0.0;
-      diff[k] += shift[k];
-    }
+  for (int k = 0; k < 3; k++) {
+    const double diff = cj->loc[k] - ci->loc[k];
+    if (periodic && diff < -dim[k] / 2)
+      shift[k] = dim[k];
+    else if (periodic && diff > dim[k] / 2)
+      shift[k] = -dim[k];
+    else
+      shift[k] = 0.0;
+  }
 
-    /* Loop over particles and find which particles belong in the same group. */
-    for (size_t i = 0; i < count_i; i++) {
+  /* Loop over particles and find which particles belong in the same group. */
+  for (size_t i = 0; i < count_i; i++) {
 
-      struct gpart *pi = &gparts_i[i];
-      const double pix = pi->x[0] - shift[0];
-      const double piy = pi->x[1] - shift[1];
-      const double piz = pi->x[2] - shift[2];
+    const struct gpart *pi = &gparts_i[i];
+    const double pix = pi->x[0] - shift[0];
+    const double piy = pi->x[1] - shift[1];
+    const double piz = pi->x[2] - shift[2];
 
-      /* Find the root of pi. */
-      const size_t root_i =
-          fof_find_global(offset_i[i] - node_offset, group_index, s->nr_gparts);
+    /* Find the root of pi. */
+    const size_t root_i =
+        fof_find_global(offset_i[i] - node_offset, group_index, s->nr_gparts);
 
-      for (size_t j = 0; j < count_j; j++) {
+    for (size_t j = 0; j < count_j; j++) {
 
-        struct gpart *pj = &gparts_j[j];
-        const double pjx = pj->x[0];
-        const double pjy = pj->x[1];
-        const double pjz = pj->x[2];
+      const struct gpart *pj = &gparts_j[j];
+      const double pjx = pj->x[0];
+      const double pjy = pj->x[1];
+      const double pjz = pj->x[2];
 
-        /* Compute pairwise distance, remembering to account for boundary
-         * conditions. */
-        float dx[3], r2 = 0.0f;
-        dx[0] = pix - pjx;
-        dx[1] = piy - pjy;
-        dx[2] = piz - pjz;
+      /* Compute pairwise distance, remembering to account for boundary
+       * conditions. */
+      float dx[3], r2 = 0.0f;
+      dx[0] = pix - pjx;
+      dx[1] = piy - pjy;
+      dx[2] = piz - pjz;
 
-        for (int k = 0; k < 3; k++) r2 += dx[k] * dx[k];
+      for (int k = 0; k < 3; k++) r2 += dx[k] * dx[k];
 
-        /* Hit or miss? */
-        if (r2 < l_x2) {
+      /* Hit or miss? */
+      if (r2 < l_x2) {
 
-          int found = 0;
+        int found = 0;
 
-          /* Check that the links have not already been added to the list. */
-          for (int l = 0; l < *link_count; l++) {
-            if ((*group_links)[l].group_i == root_i &&
-                (*group_links)[l].group_j == pj->group_id) {
-              found = 1;
-              break;
-            }
+        /* Check that the links have not already been added to the list. */
+        for (int l = 0; l < local_link_count; l++) {
+          if ((local_group_links)[l].group_i == root_i &&
+              (local_group_links)[l].group_j == pj->group_id) {
+            found = 1;
+            break;
           }
+        }
 
-          if (!found) {
+        if (!found) {
 
-            /* If the group_links array is not big enough re-allocate it. */
-            if (*link_count + 1 > *group_links_size) {
+          /* If the group_links array is not big enough re-allocate it. */
+          if (local_link_count + 1 > *group_links_size) {
 
-              int new_size = 2 * (*group_links_size);
+            const int new_size = 2 * (*group_links_size);
 
-              *group_links_size = new_size;
+            *group_links_size = new_size;
 
-              (*group_links) = (struct fof_mpi *)realloc(
-                  *group_links, new_size * sizeof(struct fof_mpi));
+            (*group_links) = (struct fof_mpi *)realloc(
+                *group_links, new_size * sizeof(struct fof_mpi));
 
-              message("Re-allocating local group links from %d to %d elements.",
-                      *link_count, new_size);
-            }
+            /* Refresh the local copy: realloc() may have moved the array */
+            local_group_links = *group_links;
 
-            /* Store the particle group properties for communication. */
-            (*group_links)[*link_count].group_i = root_i;
-            (*group_links)[*link_count].group_i_size =
-                group_size[root_i - node_offset];
+            message("Re-allocating local group links from %d to %d elements.",
+                    local_link_count, new_size);
+          }
 
-            (*group_links)[*link_count].group_j = pj->group_id;
-            (*group_links)[*link_count].group_j_size = pj->group_size;
+          /* Store the particle group properties for communication. */
+          local_group_links[local_link_count].group_i = root_i;
+          local_group_links[local_link_count].group_i_size =
+              group_size[root_i - node_offset];
 
-            (*link_count)++;
-          }
+          local_group_links[local_link_count].group_j = pj->group_id;
+          local_group_links[local_link_count].group_j_size = pj->group_size;
+
+          local_link_count++;
         }
       }
     }
-  } else
-    error("Cell ci, is not local.");
+  }
+
+  /* Update the returned values */
+  *link_count = local_link_count;
 
 #else
   error("Calling MPI function in non-MPI mode.");
@@ -863,17 +869,22 @@ void rec_fof_search_pair(const struct fof_props *props, const struct space *s,
 
 /* Recurse on a pair of cells (one local, one foreign) and perform a FOF search
  * between cells that are within range. */
-static void rec_fof_search_pair_foreign(
-    const struct fof_props *props, const struct space *s, const double dim[3],
-    const double search_r2, struct cell *ci, struct cell *cj, int *link_count,
-    struct fof_mpi **group_links, int *group_links_size) {
+void rec_fof_search_pair_foreign(const struct fof_props *props,
+                                 const struct space *s, const double dim[3],
+                                 const double search_r2, const struct cell *ci,
+                                 const struct cell *cj,
+                                 int *restrict link_count,
+                                 struct fof_mpi **group_links,
+                                 int *restrict group_links_size) {
+
+#ifdef SWIFT_DEBUG_CHECKS
+  if (ci == cj) error("Pair FOF called on same cell!!!");
+#endif
 
   /* Find the shortest distance between cells, remembering to account for
    * boundary conditions. */
   const double r2 = cell_min_dist(ci, cj, dim);
 
-  if (ci == cj) error("Pair FOF called on same cell!!!");
-
   /* Return if cells are out of range of each other. */
   if (r2 > search_r2) return;
 
@@ -1010,8 +1021,9 @@ void fof_calc_group_size_mapper(void *map_data, int num_elements,
 }
 
 /* Mapper function to atomically update the group mass array. */
-void fof_update_group_mass_mapper(hashmap_key_t key, hashmap_value_t *value,
-                                  void *data) {
+static INLINE void fof_update_group_mass_mapper(hashmap_key_t key,
+                                                hashmap_value_t *value,
+                                                void *data) {
 
   double *group_mass = (double *)data;
 
@@ -1089,8 +1101,9 @@ void fof_unpack_group_mass_mapper(hashmap_key_t key, hashmap_value_t *value,
  */
 void fof_calc_group_mass(struct fof_props *props, const struct space *s,
                          const size_t num_groups_local,
-                         const size_t num_groups_prev, size_t *num_on_node,
-                         size_t *first_on_node, double *group_mass) {
+                         const size_t num_groups_prev,
+                         size_t *restrict num_on_node,
+                         size_t *restrict first_on_node, double *group_mass) {
 
   const size_t nr_gparts = s->nr_gparts;
   struct gpart *gparts = s->gparts;