diff --git a/src/engine.c b/src/engine.c
index 47bee144b6f6d245003da2ca2955502ba1dc9c68..ec94c19a006ae686e1b4a6fd0d4262c09fbf2297 100644
--- a/src/engine.c
+++ b/src/engine.c
@@ -4207,7 +4207,7 @@ void engine_init_particles(struct engine *e, int flag_entropy_ICs,
   s->l_x2 = l_x * l_x;
 
   ticks tic = getticks();
-  fof_search_serial(s);
+  //fof_search_serial(s);
   message("Serial FOF search took: %.3f %s.",
           clocks_from_ticks(getticks() - tic), clocks_getunit());
 
diff --git a/src/fof.c b/src/fof.c
index fbe773cf5f0b1c6538fb62db3165778c9a549115..4fc8e42f5044a967b7aed32ea4b7addcd767dce9 100644
--- a/src/fof.c
+++ b/src/fof.c
@@ -57,23 +57,23 @@ __attribute__((always_inline)) INLINE static double cell_min_dist(
   /* Find the shortest distance between cells, remembering to account for
    * boundary conditions. */
   double dx[3], r2 = 0.0f;
-  dx[0] = min3(abs(nearest(cix - cjx, dim[0])),
-               abs(nearest(cix - (cjx + cj->width[0]), dim[0])),
-               abs(nearest((cix + ci->width[0]) - cjx, dim[0])));
+  dx[0] = min3(fabs(nearest(cix - cjx, dim[0])),
+               fabs(nearest(cix - (cjx + cj->width[0]), dim[0])),
+               fabs(nearest((cix + ci->width[0]) - cjx, dim[0])));
   dx[0] = min(
-      dx[0], abs(nearest((cix + ci->width[0]) - (cjx + cj->width[0]), dim[0])));
+      dx[0], fabs(nearest((cix + ci->width[0]) - (cjx + cj->width[0]), dim[0])));
 
-  dx[1] = min3(abs(nearest(ciy - cjy, dim[1])),
-               abs(nearest(ciy - (cjy + cj->width[1]), dim[1])),
-               abs(nearest((ciy + ci->width[1]) - cjy, dim[1])));
+  dx[1] = min3(fabs(nearest(ciy - cjy, dim[1])),
+               fabs(nearest(ciy - (cjy + cj->width[1]), dim[1])),
+               fabs(nearest((ciy + ci->width[1]) - cjy, dim[1])));
   dx[1] = min(
-      dx[1], abs(nearest((ciy + ci->width[1]) - (cjy + cj->width[1]), dim[1])));
+      dx[1], fabs(nearest((ciy + ci->width[1]) - (cjy + cj->width[1]), dim[1])));
 
-  dx[2] = min3(abs(nearest(ciz - cjz, dim[2])),
-               abs(nearest(ciz - (cjz + cj->width[2]), dim[2])),
-               abs(nearest((ciz + ci->width[2]) - cjz, dim[2])));
+  dx[2] = min3(fabs(nearest(ciz - cjz, dim[2])),
+               fabs(nearest(ciz - (cjz + cj->width[2]), dim[2])),
+               fabs(nearest((ciz + ci->width[2]) - cjz, dim[2])));
   dx[2] = min(
-      dx[2], abs(nearest((ciz + ci->width[2]) - (cjz + cj->width[2]), dim[2])));
+      dx[2], fabs(nearest((ciz + ci->width[2]) - (cjz + cj->width[2]), dim[2])));
 
   for (int k = 0; k < 3; k++) r2 += dx[k] * dx[k];
 
@@ -158,6 +158,58 @@ static void rec_fof_search(struct cell *ci, const int cid, struct space *s,
   }
 }
 
+static void rec_fof_search_ci_cj(struct cell *ci, struct cell *cj, struct space *s,
+                           int *pid, int *num_groups, const double *dim,
+                           const double search_r2) {
+
+  const double cix = ci->loc[0];
+  const double ciy = ci->loc[1];
+  const double ciz = ci->loc[2];
+
+  /* Find the shortest distance between cells, remembering to account for
+   * boundary conditions. */
+  const double r2 = cell_min_dist(ci, cj, cix, ciy, ciz, dim);
+
+  if (r2 > search_r2)
+    return;
+  /* Recurse on both cells if they are both split.
+  else if (ci->split && cj->split) {
+    for (int k = 0; k < 8; k++) {
+      if (ci->progeny[k] != NULL) {
+
+        for (int l = 0; l < 8; l++)
+          if (cj->progeny[l] != NULL)
+            rec_fof_search_ci_cj(ci->progeny[k], cj->progeny[l], s, pid, num_groups, dim, search_r2);
+      }
+    }
+  }
+  /* Perform FOF search between pairs of cells that are within the linking
+   * length and not the same cell. */
+  else if (ci != cj)
+    fof_search_pair_cells(s, ci, cj, pid, num_groups);
+
+}
+
+static void rec_fof_search_self_ci(struct cell *ci, struct space *s,
+                           int *pid, int *num_groups, const double *dim,
+                           const double search_r2) {
+
+  if (ci->split) {
+    for (int k = 0; k < 8; k++) {
+      if (ci->progeny[k] != NULL) {
+
+        rec_fof_search_self_ci(ci->progeny[k], s, pid, num_groups, dim, search_r2);
+
+        for (int l = k+1; l < 8; l++)
+          if (ci->progeny[l] != NULL)
+            rec_fof_search_ci_cj(ci->progeny[k], ci->progeny[l], s, pid, num_groups, dim, search_r2);
+      }
+    }
+  }
+  else 
+    fof_search_cell(s, ci, pid, num_groups);
+}
+
 /* Perform naive N^2 FOF search on gravity particles using the Union-Find
  * algorithm.*/
 void fof_search_serial(struct space *s) {
@@ -427,11 +479,17 @@ void fof_search_tree_serial(struct space *s) {
 
     struct cell *restrict c = &s->cells_top[cid];
 
-    message("Searching top-level cell: %ld.", cid);
-    fflush(stdout);
+    /* Perform FOF search on local particles within the cell. */
+    rec_fof_search_self_ci(c, s, pid, &num_groups, dim, search_r2);
+
+    /* Loop over all top-level cells skipping over the cells already searched.
+    */
+    for (int cjd = cid + 1; cjd < s->nr_cells; cjd++) {
+
+      struct cell *restrict cj = &s->cells_top[cjd];
 
-    /* Recursively perform FOF search on all other cells in top-level grid. */
-    rec_fof_search(c, cid, s, pid, &num_groups, dim, search_r2);
+      rec_fof_search_ci_cj(c, cj, s, pid, &num_groups, dim, search_r2);
+    }   
   }
 
   /* Calculate the total number of particles in each group. */