diff --git a/src/fof.c b/src/fof.c
index b1412c8a193eee929856b6710b49f0314269be12..a40635620dbcf22e8ccb19971907e57a63964640 100644
--- a/src/fof.c
+++ b/src/fof.c
@@ -1980,6 +1980,51 @@ void fof_dump_group_data(const struct fof_props *props,
   fclose(file);
 }
 
+struct mapper_data { /* Shared input of fof_set_outgoing_root_mapper(). */
+  size_t *group_index;        /* Per-gpart array searched for roots. */
+  size_t *group_size;         /* Read at index (root - node_offset). */
+  size_t nr_gparts;           /* Forwarded to fof_find_global(). */
+  struct gpart *space_gparts; /* Base pointer for cell gpart offsets. */
+};
+
+/**
+ * @brief Threadpool mapper that stores the FOF root and the size of its
+ * group on every gpart of a chunk of out-going cells.
+ *
+ * @param map_data Chunk of the array of out-going #cell pointers.
+ * @param num_elements Number of cell pointers in this chunk.
+ * @param extra_data The #mapper_data shared by all the chunks.
+ */
+void fof_set_outgoing_root_mapper(void *map_data, int num_elements,
+                                  void *extra_data) {
+#ifndef WITH_MPI
+  error("Calling MPI function in non-MPI mode");
+#else
+  struct cell **cells = (struct cell **)map_data;
+  const struct mapper_data *args = (struct mapper_data *)extra_data;
+
+  for (int c = 0; c < num_elements; ++c) {
+    struct cell *cell = cells[c];
+    struct gpart *cell_gparts = cell->grav.parts;
+
+    /* Where this cell's first gpart sits relative to args->space_gparts. */
+    const ptrdiff_t first = cell_gparts - args->space_gparts;
+
+    for (int p = 0; p < cell->grav.count; ++p) {
+
+      /* The stored entry is shifted by node_offset before the search. */
+      const size_t root = fof_find_global(
+          args->group_index[first + p] - node_offset, args->group_index,
+          args->nr_gparts);
+
+      /* The root is shifted the same way to look up its group size. */
+      cell_gparts[p].group_size = args->group_size[root - node_offset];
+      cell_gparts[p].group_id = root;
+    }
+  }
+#endif
+}
+
 /**
  * @brief Search foreign cells for links and communicate any found to the
  * appropriate node.
@@ -2087,26 +2132,33 @@ void fof_search_foreign_cells(struct fof_props *props, const struct space *s) {
   }
 
   /* Set the root of outgoing particles. */
-  for (int i = 0; i < e->nr_proxies; i++) {
 
+  /* Allocate array of outgoing cells and populate it */
+  struct cell **local_cells = malloc(num_cells_out * sizeof(struct cell *));
+  int count = 0;
+  for (int i = 0; i < e->nr_proxies; i++) {
     for (int j = 0; j < e->proxies[i].nr_cells_out; j++) {
 
-      struct cell *restrict local_cell = e->proxies[i].cells_out[j];
-      struct gpart *gparts = local_cell->grav.parts;
-
-      /* Make a list of particle offsets into the global gparts array. */
-      size_t *const offset = group_index + (ptrdiff_t)(gparts - s->gparts);
+      /* Only include gravity cells. */
+      if (e->proxies[i].cells_out_type[j] & proxy_cell_type_gravity) {
 
-      /* Set each particle's root and group properties found in the local FOF.*/
-      for (int k = 0; k < local_cell->grav.count; k++) {
-        const size_t root =
-            fof_find_global(offset[k] - node_offset, group_index, nr_gparts);
-        gparts[k].group_id = root;
-        gparts[k].group_size = group_size[root - node_offset];
+        local_cells[count] = e->proxies[i].cells_out[j];
+        ++count;
       }
     }
   }
 
+  /* Now set the roots */
+  struct mapper_data data;
+  data.group_index = group_index;
+  data.group_size = group_size;
+  data.nr_gparts = nr_gparts;
+  data.space_gparts = s->gparts;
+  threadpool_map(&e->threadpool, fof_set_outgoing_root_mapper, local_cells,
+                 num_cells_out, sizeof(struct cell **), 0, &data);
+
+  free(local_cells);
+
   if (verbose)
     message(
         "Finding local/foreign cell pairs and initialising particle roots "
@@ -2508,7 +2560,7 @@ void fof_search_tree(struct fof_props *props,
       max_group_size_local = group_size[i];
 #endif
   }
-    
+
   message(
       "Calculating the total no. of local groups took: (FOF SCALING): %.3f %s.",
       clocks_from_ticks(getticks() - tic_num_groups_calc), clocks_getunit());
@@ -2577,8 +2629,9 @@ void fof_search_tree(struct fof_props *props,
   const size_t num_groups_prev = (size_t)(ngsum - nglocal);
 #endif /* WITH_MPI */
 
-  message("Finding the total no. of groups took: (FOF SCALING): %.3f %s.",       
-            clocks_from_ticks(getticks() - tic_num_groups_calc), clocks_getunit());
+  message("Finding the total no. of groups took: (FOF SCALING): %.3f %s.",
+          clocks_from_ticks(getticks() - tic_num_groups_calc),
+          clocks_getunit());
 
   /* Set default group ID for all particles */
   for (size_t i = 0; i < nr_gparts; i++) gparts[i].group_id = group_id_default;