diff --git a/src/fof.c b/src/fof.c
index d22592cfcdba8799fd74695a747be991873f9a39..15a0012fdc1f97be3933ddc549edb9a40a412af3 100644
--- a/src/fof.c
+++ b/src/fof.c
@@ -258,6 +258,9 @@ void fof_search_cell(struct space *s, struct cell *c, int *pid,
     const double piz = pi->x[2];
     const size_t offset_i = pi->offset;
 
+    /* Find the root of pi. */
+    int root_i = fof_find(offset_i, pid);
+
     for (size_t j = i + 1; j < count; j++) {
 
       struct gpart *pj = &gparts[j];
@@ -278,8 +281,7 @@ void fof_search_cell(struct space *s, struct cell *c, int *pid,
         r2 += dx[k] * dx[k];
       }
 
-      /* Find the roots of pi and pj. */
-      const int root_i = fof_find(offset_i, pid);
+      /* Find the root of pj. */
       const int root_j = fof_find(offset_j, pid);
 
       /* Skip particles in the same group. */
@@ -288,8 +290,13 @@ void fof_search_cell(struct space *s, struct cell *c, int *pid,
       /* Hit or miss? */
       if (r2 < l_x2) {
 
-        if (root_j < root_i)
+        /* If the root ID of pj is lower than pi's root ID set pi's root to point to pj's. 
+         * Otherwise set pj's to root to point to pi's.*/
+        if (root_j < root_i) {
           pid[root_i] = root_j;
+          /* Update root_i on the fly. */
+          root_i = root_j;
+        }
         else
           pid[root_j] = root_i;
 
@@ -336,13 +343,15 @@ void fof_search_pair_cells(struct space *s, struct cell *ci, struct cell *cj,
     const double piz = pi->x[2] - shift[2];
     const size_t offset_i = pi->offset;
 
+    /* Find the root of pi. */
+    int root_i = fof_find(offset_i, pid);
+    
     for (size_t j = 0; j < count_j; j++) {
 
       struct gpart *pj = &gparts_j[j];
       const size_t offset_j = pj->offset;
 
-      /* Find the roots of pi and pj. */
-      const int root_i = fof_find(offset_i, pid);
+      /* Find the root of pj. */
       const int root_j = fof_find(offset_j, pid);
 
       /* Skip particles in the same group. */
@@ -364,8 +373,13 @@ void fof_search_pair_cells(struct space *s, struct cell *ci, struct cell *cj,
       /* Hit or miss? */
       if (r2 < l_x2) {
 
-        if (root_j < root_i)
+        /* If the root ID of pj is lower than pi's root ID set pi's root to point to pj's. 
+         * Otherwise set pj's to root to point to pi's.*/
+        if (root_j < root_i) {
           pid[root_i] = root_j;
+          /* Update root_i on the fly. */
+          root_i = root_j;
+        }
         else
           pid[root_j] = root_i;