diff --git a/examples/SmoothedMetallicity/makeIC.py b/examples/SmoothedMetallicity/makeIC.py
index d076d1c15df24c6c2f12d5f87ed9ff7952357e88..86679d5efe897b9dfae7db94b36d74bb047661e6 100644
--- a/examples/SmoothedMetallicity/makeIC.py
+++ b/examples/SmoothedMetallicity/makeIC.py
@@ -26,12 +26,16 @@ import numpy as np
 gamma = 5./3.      # Gas adiabatic index
 rho0 = 1.          # Background density
 P0 = 1.e-6         # Background pressure
+Nelem = 9          # Gear: 9, EAGLE: 9
 low_metal = -6     # Low iron fraction
 high_metal = -5    # high iron fraction
 sigma_metal = 0.1  # relative standard deviation for the metallicities
-Nelem = 9          # Gear: 9, EAGLE: 9
 fileName = "smoothed_metallicity.hdf5"
 
+# shift all metals in order to obtain nicer plots
+low_metal = [low_metal] * Nelem + np.linspace(0, 3, Nelem)
+high_metal = [high_metal] * Nelem + np.linspace(0, 3, Nelem)
+
 # ---------------------------------------------------
 glass = h5py.File("glassCube_32.hdf5", "r")
 
@@ -57,13 +61,11 @@ u[:] = P0 / (rho0 * (gamma - 1))
 # set metallicities
 select = pos[:, 0] < 0.5
 nber = sum(select)
-sigma = abs(sigma_metal*low_metal)
-Z[select, :] = low_metal + np.random.normal(loc=0., scale=sigma,
-                                            size=(nber, Nelem))
+Z[select, :] = low_metal * (1 + np.random.normal(loc=0., scale=sigma_metal,
+                                                 size=(nber, Nelem)))
 nber = numPart - nber
-sigma = abs(sigma_metal*high_metal)
-Z[np.logical_not(select), :] = high_metal + np.random.normal(
-    loc=0., scale=sigma, size=(nber, Nelem))
+Z[np.logical_not(select), :] = high_metal * (1 + np.random.normal(
+    loc=0., scale=sigma_metal, size=(nber, Nelem)))
 
 # --------------------------------------------------
 
diff --git a/examples/SmoothedMetallicity/plotSolution.py b/examples/SmoothedMetallicity/plotSolution.py
index 183056dfd2b965075a0f4bf9540f6b14490b74b3..e5bca3dfb7fe1e43c836733894c9e297cdd468ca 100644
--- a/examples/SmoothedMetallicity/plotSolution.py
+++ b/examples/SmoothedMetallicity/plotSolution.py
@@ -29,17 +29,19 @@ matplotlib.use("Agg")
 import matplotlib.pyplot as plt
 
 # Parameters
-high_metal = -5    # High metal abundance
 low_metal = -6     # low metal abundance
+high_metal = -5    # High metal abundance
 sigma_metal = 0.1  # relative standard deviation for Z
 
+Nelem = 9
+# shift all metals in order to obtain nicer plots
+low_metal = [low_metal] * Nelem + np.linspace(0, 3, Nelem)
+high_metal = [high_metal] * Nelem + np.linspace(0, 3, Nelem)
+
 # ---------------------------------------------------------------
 # Don't touch anything after this.
 # ---------------------------------------------------------------
 
-Nelem = 4           # Number of element
-
-
 # Plot parameters
 params = {
     'axes.labelsize': 10,
@@ -81,10 +83,16 @@ git = sim["Code"].attrs["Git Revision"]
 
 pos = sim["/PartType0/Coordinates"][:, :]
 d = pos[:, 0] - boxSize / 2
-metal = sim["/PartType0/SmoothedElementAbundance"][:, :]
+smooth_metal = sim["/PartType0/SmoothedElementAbundance"][:, :]
+metal = sim["/PartType0/ElementAbundance"][:, :]
 h = sim["/PartType0/SmoothingLength"][:]
 h = np.mean(h)
 
+if (Nelem != metal.shape[1]):
+    print("Unexpected number of element, please check makeIC.py"
+          " and plotSolution.py")
+    exit(1)
+
 N = 1000
 d_a = np.linspace(-boxSize / 2., boxSize / 2., N)
 
@@ -125,14 +133,14 @@ def calc_a(d, high_metal, low_metal, std_dev, h):
     for i in range(Nelem):
         # compute low metallicity side
         s = d < -h
-        a[s, i] = low_metal
+        a[s, i] = low_metal[i]
         # compute high metallicity side
         s = d > h
-        a[s, i] = high_metal
+        a[s, i] = high_metal[i]
 
         # compute non constant parts
-        m = (high_metal - low_metal) / (2.0 * h)
-        c = (high_metal + low_metal) / 2.0
+        m = (high_metal[i] - low_metal[i]) / (2.0 * h)
+        c = (high_metal[i] + low_metal[i]) / 2.0
         # compute left linear part
         s = d < - boxSize / 2.0 + h
         a[s, i] = - m * (d[s] + boxSize / 2.0) + c
@@ -155,43 +163,32 @@ sol, sig = calc_a(d_a, high_metal, low_metal, sigma_metal, h)
 plt.figure()
 
 # Metallicity --------------------------------
-e = 0
 plt.subplot(221)
-plt.plot(d, metal[:, e], '.', color='r', ms=0.5, alpha=0.2)
-plt.plot(d_a, sol[:, e], '--', color='b', alpha=0.8, lw=1.2)
-plt.fill_between(d_a, sig[:, e, 0], sig[:, e, 1], facecolor="b",
-                 interpolate=True, alpha=0.5)
-plt.xlabel("${\\rm{Distance}}~r$", labelpad=0)
-plt.ylabel("${\\rm{Metallicity}}~Z$", labelpad=0)
-plt.xlim(-0.5, 0.5)
-plt.ylim(low_metal-1, high_metal+1)
+for e in range(Nelem):
+    plt.plot(metal[:, e], smooth_metal[:, e], '.', ms=0.5, alpha=0.2)
 
-# Metallicity --------------------------------
-e = 1
-plt.subplot(222)
-plt.plot(d, metal[:, e], '.', color='r', ms=0.5, alpha=0.2)
-plt.plot(d_a, sol[:, e], '--', color='b', alpha=0.8, lw=1.2)
-plt.fill_between(d_a, sig[:, e, 0], sig[:, e, 1], facecolor="b",
-                 interpolate=True, alpha=0.5)
-plt.xlabel("${\\rm{Distance}}~r$", labelpad=0)
-plt.ylabel("${\\rm{Metallicity}}~Z$", labelpad=0)
-plt.xlim(-0.5, 0.5)
-plt.ylim(low_metal-1, high_metal+1)
+xmin, xmax = metal.min(), metal.max()
+ymin, ymax = smooth_metal.min(), smooth_metal.max()
+x = max(xmin, ymin)
+y = min(xmax, ymax)
+plt.plot([x, y], [x, y], "--k", lw=1.0)
+plt.xlabel("${\\rm{Metallicity}}~Z_\\textrm{part}$", labelpad=0)
+plt.ylabel("${\\rm{Smoothed~Metallicity}}~Z_\\textrm{sm}$", labelpad=0)
 
 # Metallicity --------------------------------
-e = 2
+e = 0
 plt.subplot(223)
-plt.plot(d, metal[:, e], '.', color='r', ms=0.5, alpha=0.2)
+plt.plot(d, smooth_metal[:, e], '.', color='r', ms=0.5, alpha=0.2)
 plt.plot(d_a, sol[:, e], '--', color='b', alpha=0.8, lw=1.2)
 plt.fill_between(d_a, sig[:, e, 0], sig[:, e, 1], facecolor="b",
                  interpolate=True, alpha=0.5)
 plt.xlabel("${\\rm{Distance}}~r$", labelpad=0)
-plt.ylabel("${\\rm{Metallicity}}~Z$", labelpad=0)
+plt.ylabel("${\\rm{Smoothed~Metallicity}}~Z_\\textrm{sm}$", labelpad=0)
 plt.xlim(-0.5, 0.5)
-plt.ylim(low_metal-1, high_metal+1)
+plt.ylim(low_metal[e]-1, high_metal[e]+1)
 
 # Information -------------------------------------
-plt.subplot(224, frameon=False)
+plt.subplot(222, frameon=False)
 
 plt.text(-0.49, 0.9, "Smoothed Metallicity in 3D at $t=%.2f$" % time,
          fontsize=10)
@@ -207,4 +204,5 @@ plt.ylim(0, 1)
 plt.xticks([])
 plt.yticks([])
 
+plt.tight_layout()
 plt.savefig("SmoothedMetallicity.png", dpi=200)