diff --git a/src/cell.c b/src/cell.c
index e5812e0603923db3d7b2fc9b498c5b650384f768..a0ee50855ef822115313b00f32dd81fc3d504580 100644
--- a/src/cell.c
+++ b/src/cell.c
@@ -2343,7 +2343,6 @@ int cell_has_tasks(struct cell *c) {
 void cell_drift_part(struct cell *c, const struct engine *e, int force) {
 
   const float hydro_h_max = e->hydro_properties->h_max;
-  const double timeBase = e->timeBase;
   const integertime_t ti_old_part = c->ti_old_part;
   const integertime_t ti_current = e->ti_current;
   struct part *const parts = c->parts;
@@ -2384,7 +2383,19 @@ void cell_drift_part(struct cell *c, const struct engine *e, int force) {
   } else if (!c->split && force && ti_current > ti_old_part) {
 
     /* Drift from the last time the cell was drifted to the current time */
-    const double dt = (ti_current - ti_old_part) * timeBase;
+    double dt_drift, dt_kick_grav, dt_kick_hydro;
+    if (e->policy & engine_policy_cosmology) {
+      dt_drift =
+          cosmology_get_drift_factor(e->cosmology, ti_old_part, ti_current);
+      dt_kick_grav =
+          cosmology_get_grav_kick_factor(e->cosmology, ti_old_part, ti_current);
+      dt_kick_hydro = cosmology_get_hydro_kick_factor(e->cosmology, ti_old_part,
+                                                      ti_current);
+    } else {
+      dt_drift = (ti_current - ti_old_part) * e->time_base;
+      dt_kick_grav = (ti_current - ti_old_part) * e->time_base;
+      dt_kick_hydro = (ti_current - ti_old_part) * e->time_base;
+    }
 
     /* Loop over all the gas particles in the cell */
     const size_t nr_parts = c->count;
@@ -2395,7 +2406,8 @@ void cell_drift_part(struct cell *c, const struct engine *e, int force) {
       struct xpart *const xp = &xparts[k];
 
       /* Drift... */
-      drift_part(p, xp, dt, timeBase, ti_old_part, ti_current);
+      drift_part(p, xp, dt_drift, dt_kick_hydro, dt_kick_grav, ti_old_part,
+                 ti_current);
 
       /* Limit h to within the allowed range */
       p->h = min(p->h, hydro_h_max);
diff --git a/src/drift.h b/src/drift.h
index 6aa891b1f40a0e8d0cf35163655f1d1fcc50a14a..058e668cab3e598e3620b83933ccdc91218a1c22 100644
--- a/src/drift.h
+++ b/src/drift.h
@@ -68,14 +68,16 @@ __attribute__((always_inline)) INLINE static void drift_gpart(
  *
  * @param p The #part to drift.
  * @param xp The #xpart of the particle.
- * @param dt The drift time-step
- * @param timeBase The minimal allowed time-step size.
- * @param ti_old Integer start of time-step
- * @param ti_current Integer end of time-step
+ * @param dt_drift The drift time-step
+ * @param dt_kick_grav The kick time-step for gravity accelerations.
+ * @param dt_kick_hydro The kick time-step for hydro accelerations.
+ * @param ti_old Integer start of time-step (for debugging checks).
+ * @param ti_current Integer end of time-step (for debugging checks).
  */
 __attribute__((always_inline)) INLINE static void drift_part(
-    struct part *restrict p, struct xpart *restrict xp, double dt,
-    double timeBase, integertime_t ti_old, integertime_t ti_current) {
+    struct part *restrict p, struct xpart *restrict xp, double dt_drift,
+    double dt_kick_hydro, double dt_kick_grav, integertime_t ti_old,
+    integertime_t ti_current) {
 
 #ifdef SWIFT_DEBUG_CHECKS
   if (p->ti_drift != ti_old)
@@ -88,21 +90,25 @@ __attribute__((always_inline)) INLINE static void drift_part(
 #endif
 
   /* Drift... */
-  p->x[0] += xp->v_full[0] * dt;
-  p->x[1] += xp->v_full[1] * dt;
-  p->x[2] += xp->v_full[2] * dt;
+  p->x[0] += xp->v_full[0] * dt_drift;
+  p->x[1] += xp->v_full[1] * dt_drift;
+  p->x[2] += xp->v_full[2] * dt_drift;
 
   /* Predict velocities (for hydro terms) */
-  p->v[0] += p->a_hydro[0] * dt;
-  p->v[1] += p->a_hydro[1] * dt;
-  p->v[2] += p->a_hydro[2] * dt;
+  p->v[0] += p->a_hydro[0] * dt_kick_hydro;
+  p->v[1] += p->a_hydro[1] * dt_kick_hydro;
+  p->v[2] += p->a_hydro[2] * dt_kick_hydro;
+
+  p->v[0] += xp->a_grav[0] * dt_kick_grav;
+  p->v[1] += xp->a_grav[1] * dt_kick_grav;
+  p->v[2] += xp->a_grav[2] * dt_kick_grav;
 
   /* Predict the values of the extra fields */
-  hydro_predict_extra(p, xp, dt);
+  hydro_predict_extra(p, xp, dt_drift);
 
   /* Compute offsets since last cell construction */
   for (int k = 0; k < 3; k++) {
-    const float dx = xp->v_full[k] * dt;
+    const float dx = xp->v_full[k] * dt_drift;
     xp->x_diff[k] -= dx;
     xp->x_diff_sort[k] -= dx;
   }
diff --git a/src/hydro/Gadget2/hydro.h b/src/hydro/Gadget2/hydro.h
index 3b747a494c6436b8e11f8d4503f802751c044277..7af47b48924c4c013a3f4dd571838755a39861e3 100644
--- a/src/hydro/Gadget2/hydro.h
+++ b/src/hydro/Gadget2/hydro.h
@@ -513,6 +513,9 @@ __attribute__((always_inline)) INLINE static void hydro_first_init_part(
   xp->v_full[0] = p->v[0];
   xp->v_full[1] = p->v[1];
   xp->v_full[2] = p->v[2];
+  xp->a_grav[0] = 0.f;
+  xp->a_grav[1] = 0.f;
+  xp->a_grav[2] = 0.f;
   xp->entropy_full = p->entropy;
 
   hydro_reset_acceleration(p);
diff --git a/src/hydro/Gadget2/hydro_part.h b/src/hydro/Gadget2/hydro_part.h
index 55745b52f4459905b02adf4824c489510ee93f97..90f73571701b37b3377601655330d8d25f862a05 100644
--- a/src/hydro/Gadget2/hydro_part.h
+++ b/src/hydro/Gadget2/hydro_part.h
@@ -46,6 +46,9 @@ struct xpart {
   /* Velocity at the last full step. */
   float v_full[3];
 
+  /* Gravitational acceleration at the last full step. */
+  float a_grav[3];
+
   /* Entropy at the last full step. */
   float entropy_full;