diff --git a/README b/Readme.md
similarity index 100%
rename from README
rename to Readme.md
diff --git a/test/cooling.yml b/examples/cooling_rate/cooling.yml
similarity index 80%
rename from test/cooling.yml
rename to examples/cooling_rate/cooling.yml
index 96555e68c238897b81b856a624468385f0f9af90..54c4cf3c2ecdfe33e16cb28ebe16243206ecffac 100644
--- a/test/cooling.yml
+++ b/examples/cooling_rate/cooling.yml
@@ -37,10 +37,16 @@ InitialConditions:
 
 # Cooling with Grackle 2.0
 GrackleCooling:
-  GrackleCloudyTable: CloudyData_UVB=HM2012.h5 # Name of the Cloudy Table
-  UVbackground: 1 # Enable or not the UV background
-  GrackleRedshift: 0 # Redshift to use (-1 means time based redshift)
-  GrackleHSShieldingDensityThreshold: 1.1708e-26 # self shielding threshold in internal system of units
+  CloudyTable: CloudyData_UVB=HM2012.h5 # Name of the Cloudy Table
+  WithUVbackground: 1 # Enable or not the UV background
+  Redshift: 0 # Redshift to use (-1 means time based redshift)
+  WithMetalCooling: 1 # Enable or not the metal cooling
+  ProvideVolumetricHeatingRates: 0 # User provide volumetric heating rates
+  ProvideSpecificHeatingRates: 0 # User provide specific heating rates
+  SelfShieldingMethod: 0 # Grackle (<= 3) or Gear self shielding method
+  OutputMode: 0 # Write in output corresponding primordial chemistry mode
+  ConvergenceLimit: 1e-3
+  MaxSteps: 100000
 
 Gravity:
   eta:          0.025    # Constant dimensionless multiplier for time integration.
diff --git a/examples/cooling_rate/generate_grackle_data.py b/examples/cooling_rate/generate_grackle_data.py
new file mode 100644
index 0000000000000000000000000000000000000000..3ebb364bf9d227113b781cbf068ea2c86290219f
--- /dev/null
+++ b/examples/cooling_rate/generate_grackle_data.py
@@ -0,0 +1,190 @@
+#!/usr/bin/env python
+########################################################################
+#
+# Cooling rate data generation. This code is a modified version
+# of a code developed by the Grackle Team.
+#
+#
+# Copyright (c) 2013-2016, Grackle Development Team.
+#
+# Distributed under the terms of the Enzo Public Licence.
+#
+# The full license is in the file LICENSE, distributed with this
+# software.
+########################################################################
+
+"""
+Generate (or update) a hdf5 file containing
+the grackle data for each primordial chemistry
+"""
+
+import numpy as np
+from h5py import File
+from copy import deepcopy
+
+from pygrackle import \
+    chemistry_data, \
+    setup_fluid_container
+
+from pygrackle.utilities.physical_constants import \
+    mass_hydrogen_cgs, \
+    sec_per_Myr
+
+filename = "grackle.hdf5"
+
+debug = 1
+
+
+def generate_data(primordial_chemistry):
+    current_redshift = 0.
+
+    # Set solver parameters
+    my_chemistry = chemistry_data()
+    my_chemistry.use_grackle = 1
+    my_chemistry.with_radiative_cooling = 1
+    my_chemistry.primordial_chemistry = primordial_chemistry
+    my_chemistry.metal_cooling = 1
+    my_chemistry.UVbackground = 1
+    my_chemistry.self_shielding_method = 0
+    my_chemistry.H2_self_shielding = 0
+    my_chemistry.grackle_data_file = "CloudyData_UVB=HM2012.h5"
+
+    # Set units
+    my_chemistry.comoving_coordinates = 0  # proper units
+    my_chemistry.a_units = 1.0
+    my_chemistry.a_value = 1.0 / (1.0 + current_redshift) / \
+        my_chemistry.a_units
+    my_chemistry.length_units = 3.085e21
+    my_chemistry.density_units = 1.989e43 / my_chemistry.length_units**3
+    my_chemistry.velocity_units = 20725573.785998672
+    my_chemistry.time_units = my_chemistry.length_units / \
+        my_chemistry.velocity_units
+
+    density = mass_hydrogen_cgs
+    dt = 1e-8 * sec_per_Myr / my_chemistry.time_units
+    # Call convenience function for setting up a fluid container.
+    # This container holds the solver parameters, units, and fields.
+    temperature = np.logspace(1, 9, 1000)
+    fc = setup_fluid_container(my_chemistry,
+                               metal_mass_fraction=0.01295,
+                               temperature=temperature,
+                               density=density,
+                               converge=True,
+                               tolerance=1e-5,
+                               max_iterations=1e8)
+
+    f, data = write_input(
+        fc, my_chemistry, temperature, dt, primordial_chemistry)
+
+    old = deepcopy(fc)
+    fc.solve_chemistry(dt)
+    rate = (fc["energy"] - old["energy"]) / dt
+
+    write_output(f, data, fc, rate)
+
+
+def write_output(f, data, fc, rate):
+    data.create_group("Output")
+    tmp = data["Output"]
+    # Rate
+    dset = tmp.create_dataset("Rate", rate.shape,
+                              dtype=rate.dtype)
+    dset[:] = rate
+    # Energy
+    dset = tmp.create_dataset("Energy", fc["energy"].shape,
+                              dtype=fc["energy"].dtype)
+    dset[:] = fc["energy"]
+
+    f.close()
+
+
+def write_input(fc, my_chemistry, temperature, dt, primordial_chemistry):
+    f = File(filename, "a")
+    data_name = "PrimordialChemistry%i" % primordial_chemistry
+    if data_name in f:
+        print("Updating Dataset")
+        data = f[data_name]
+        f.pop(data_name)
+    else:
+        print("Creating Dataset")
+
+    data = f.create_group(data_name)
+
+    # Units
+    data.create_group("Units")
+    tmp = data["Units"]
+    tmp.attrs["Length"] = my_chemistry.length_units
+    tmp.attrs["Density"] = my_chemistry.density_units
+    tmp.attrs["Velocity"] = my_chemistry.velocity_units
+    tmp.attrs["Time"] = my_chemistry.time_units
+
+    # Parameters
+    data.create_group("Params")
+    tmp = data["Params"]
+    tmp.attrs["MetalCooling"] = my_chemistry.metal_cooling
+    tmp.attrs["UVBackground"] = my_chemistry.UVbackground
+    tmp.attrs["SelfShieldingMethod"] = my_chemistry.self_shielding_method
+    tmp.attrs["TimeStep"] = dt
+
+    # Inputs
+    data.create_group("Input")
+    tmp = data["Input"]
+    # energy
+    dset = tmp.create_dataset("Energy", fc["energy"].shape,
+                              dtype=fc["energy"].dtype)
+    dset[:] = fc["energy"]
+    # density
+    dset = tmp.create_dataset("Density", fc["density"].shape,
+                              dtype=fc["density"].dtype)
+    dset[:] = fc["density"]
+    # Temperature
+    dset = tmp.create_dataset("Temperature", temperature.shape,
+                              dtype=temperature.dtype)
+    dset[:] = temperature
+
+    write_fractions(tmp, fc, primordial_chemistry)
+
+    return f, data
+
+
+def write_fractions(tmp, fc, primordial_chemistry):
+    fields = get_fields(primordial_chemistry)
+    for i in fields:
+        dset = tmp.create_dataset(i, fc[i].shape,
+                                  dtype=fc[i].dtype)
+        dset[:] = fc[i] / fc["density"]
+
+
+def get_fields(primordial_chemistry):
+    fields = [
+        "metal"
+    ]
+    if primordial_chemistry > 0:
+        fields.extend([
+            "HI",
+            "HII",
+            "HeI",
+            "HeII",
+            "HeIII",
+            "de"])
+
+    if primordial_chemistry > 1:
+        fields.extend([
+            "HM",
+            "H2I",
+            "H2II"
+            ])
+
+    if primordial_chemistry > 2:
+        fields.extend([
+            "DI",
+            "DII",
+            "HDI"
+        ])
+    return fields
+
+
+if __name__ == "__main__":
+    for i in range(4):
+        print("Computing Primordial Chemistry %i" % i)
+        generate_data(i)
diff --git a/examples/cooling_rate/plot_cooling.py b/examples/cooling_rate/plot_cooling.py
new file mode 100644
index 0000000000000000000000000000000000000000..1cf9aae0de8a26b35591506e3ab289d17933c53a
--- /dev/null
+++ b/examples/cooling_rate/plot_cooling.py
@@ -0,0 +1,289 @@
+#!/usr/bin/env python3
+
+from pyswiftsim import wrapper
+
+from copy import deepcopy
+import numpy as np
+import matplotlib.pyplot as plt
+from astropy import units
+from os.path import isfile
+from h5py import File
+
+plt.rc('text', usetex=True)
+
+# PARAMETERS
+
+# grackle primordial chemistry
+primordial_chemistry = 1
+
+# reference data
+grackle_filename = "grackle.hdf5"
+compute_equilibrium = True
+
+# swift params filename
+filename = "cooling.yml"
+
+# if grackle_filename does not exist
+# use following values
+
+# density in atom / cm3
+N_rho = 1
+# with N_rho > 1, the code is not implemented to deal
+# with the reference
+if N_rho == 1:
+    default_density = np.array([1.])
+else:
+    default_density = np.logspace(-3, 1, N_rho)
+
+# temperature in K
+N_T = 10
+default_temperature = np.logspace(1, 9, N_T)
+
+# time step in s
+default_dt = units.Myr * 1e-8
+default_dt = default_dt.to("s") / units.s
+
+# adiabatic index
+gamma = 5. / 3.
+
+# SCRIPT
+
+
+def get_fields(primordial_chemistry):
+    fields = [
+        "metal"
+    ]
+    if primordial_chemistry > 0:
+        fields.extend([
+            "HI",
+            "HII",
+            "HeI",
+            "HeII",
+            "HeIII",
+            "de"])
+
+    if primordial_chemistry > 1:
+        fields.extend([
+            "HM",
+            "H2I",
+            "H2II"
+            ])
+
+    if primordial_chemistry > 2:
+        fields.extend([
+            "DI",
+            "DII",
+            "HDI"
+        ])
+    return fields
+
+
+def generate_default_initial_condition(us, pconst):
+    print("Generating default initial conditions")
+    d = {}
+    # generate grid
+    rho, T = np.meshgrid(default_density, default_temperature)
+    rho = deepcopy(rho.reshape(-1))
+    T = T.reshape(-1)
+    d["temperature"] = T
+
+    # Deal with units
+    rho *= us.UnitLength_in_cgs**3 * pconst.const_proton_mass
+    d["density"] = rho
+
+    energy = pconst.const_boltzmann_k * T / us.UnitTemperature_in_cgs
+    energy /= (gamma - 1.) * pconst.const_proton_mass
+    d["energy"] = energy
+
+    dt = default_dt / us.UnitTime_in_cgs
+    d["time_step"] = dt
+
+    return d
+
+
+def read_grackle_data(filename, us, primordial_chemistry):
+    print("Reading initial conditions")
+    f = File(filename, "r")
+    data = f["PrimordialChemistry%i" % primordial_chemistry]
+
+    # read units
+    tmp = data["Units"].attrs
+
+    u_len = tmp["Length"] / us.UnitLength_in_cgs
+    u_den = tmp["Density"] * us.UnitLength_in_cgs**3 / us.UnitMass_in_cgs
+    u_time = tmp["Time"] / us.UnitTime_in_cgs
+
+    # read input
+    tmp = data["Input"]
+
+    energy = tmp["Energy"][:] * u_len**2 / u_time**2
+    d["energy"] = energy
+
+    T = tmp["Temperature"][:] / us.UnitTemperature_in_cgs
+    d["temperature"] = T
+
+    density = tmp["Density"][:] * u_den
+    d["density"] = density
+
+    dt = data["Params"].attrs["TimeStep"] * u_time
+    d["time_step"] = dt
+
+    # read fractions
+    for i in get_fields(primordial_chemistry):
+        d[i] = tmp[i][:]
+
+    # read output
+    tmp = data["Output"]
+
+    energy = tmp["Energy"][:] * u_len**2 / u_time**2
+    d["out_energy"] = energy
+
+    rate = tmp["Rate"][:] * u_len**2 / (u_time**3)
+    d["rate"] = rate
+
+    f.close()
+    return d
+
+
+def initialize_swift(filename):
+    print("Initialization of SWIFT")
+    d = {}
+
+    # parse swift params
+    params = wrapper.parserReadFile(filename)
+    d["params"] = params
+
+    # init units
+    us, pconst = wrapper.unitSystemInit(params)
+    d["us"] = us
+    d["pconst"] = pconst
+
+    # init cooling
+    cooling = wrapper.coolingInit(params, us, pconst)
+    d["cooling"] = cooling
+    return d
+
+
+def plot_solution(rate, data, us):
+    print("Plotting solution")
+    energy = data["energy"]
+    rho = data["density"]
+    T = data["temperature"]
+
+    ref = False
+    if "rate" in data:
+        ref = True
+        grackle_rate = data["rate"]
+
+    # change units => cgs
+    rho *= us.UnitMass_in_cgs / us.UnitLength_in_cgs**3
+
+    T *= us.UnitTemperature_in_cgs
+
+    energy *= us.UnitLength_in_cgs**2 / us.UnitTime_in_cgs**2
+
+    rate *= us.UnitLength_in_cgs**2 / us.UnitTime_in_cgs**3
+
+    if ref:
+        grackle_rate *= us.UnitLength_in_cgs**2 / us.UnitTime_in_cgs**3
+
+    # lambda cooling
+    lam = rate * rho
+
+    if ref:
+        grackle_lam = grackle_rate * rho
+
+    # do plot
+    if N_rho == 1:
+        # plot Lambda vs T
+        plt.figure()
+        plt.loglog(T, np.abs(lam),
+                   label="SWIFT, %s" % wrapper.configGetCooling())
+        if ref:
+            # plot reference
+            plt.loglog(T, np.abs(grackle_lam),
+                       label="Grackle, Prim. Chem. %i" % primordial_chemistry)
+            plt.legend()
+            plt.xlabel("Temperature [K]")
+            plt.ylabel("$\\Lambda$ [erg s$^{-1}$ cm$^{3}$]")
+
+            # plot error vs T
+            plt.figure()
+            plt.plot(T, (lam - grackle_lam) / grackle_lam)
+            plt.gca().set_xscale("log")
+            plt.xlabel("Temperature [K]")
+            plt.ylabel(
+                r"$(\Lambda - \Lambda_\textrm{ref}) / \Lambda_\textrm{ref}$")
+
+        plt.show()
+
+    else:
+        shape = [N_rho, N_T]
+        cooling_time = energy / rate
+        cooling_length = np.sqrt(gamma * (gamma-1.) * energy) * cooling_time
+
+        cooling_length = np.log10(np.abs(cooling_length) / units.kpc.to('cm'))
+
+        # reshape
+        rho = rho.reshape(shape)
+        T = T.reshape(shape)
+        energy = energy.reshape(shape)
+        cooling_length = cooling_length.reshape(shape)
+
+        _min = -7
+        _max = 7
+        N_levels = 100
+        levels = np.linspace(_min, _max, N_levels)
+        plt.figure()
+        plt.contourf(rho, T, cooling_length, levels)
+        plt.xlabel("Density [atom/cm3]")
+        plt.ylabel("Temperature [K]")
+
+        ax = plt.gca()
+        ax.set_xscale("log")
+        ax.set_yscale("log")
+
+        cbar = plt.colorbar()
+        tc = np.arange(_min, _max, 1.5)
+        cbar.set_ticks(tc)
+        cbar.set_ticklabels(tc)
+
+        plt.show()
+
+
+if __name__ == "__main__":
+
+    d = initialize_swift(filename)
+    pconst = d["pconst"]
+    us = d["us"]
+    params = d["params"]
+    cooling = d["cooling"]
+
+    if isfile(grackle_filename):
+        d = read_grackle_data(grackle_filename, us, primordial_chemistry)
+    else:
+        d = generate_default_initial_condition(us, pconst)
+
+    # du / dt
+    print("Computing cooling...")
+    rate = np.zeros(d["density"].shape)
+
+    if compute_equilibrium:
+        rate = wrapper.coolingRate(pconst, us, cooling,
+                                   d["density"].astype(np.float32),
+                                   d["energy"].astype(np.float32),
+                                   d["time_step"])
+    else:
+        fields = get_fields(primordial_chemistry)
+        N = len(fields)
+        frac = np.zeros([len(d["density"]), N])
+        for i in range(N):
+            frac[:, i] = d[fields[i]]
+
+        rate = wrapper.coolingRate(pconst, us, cooling,
+                                   d["density"].astype(np.float32),
+                                   d["energy"].astype(np.float32),
+                                   d["time_step"],
+                                   frac.astype(np.float32))
+
+    plot_solution(rate, d, us)
diff --git a/examples/cooling_rate/run.sh b/examples/cooling_rate/run.sh
new file mode 100644
index 0000000000000000000000000000000000000000..3564de07f4f0751a51f025fc018bf5ed5c07e64c
--- /dev/null
+++ b/examples/cooling_rate/run.sh
@@ -0,0 +1,17 @@
+#!/bin/bash
+
+# Get the Grackle cooling table
+if [ ! -e CloudyData_UVB=HM2012.h5 ]
+then
+    echo "Fetching the Cloudy tables required by Grackle..."
+    ./getCoolingTable.sh
+fi
+
+# Generate Grackle data if not present
+if [ ! -e grackle.hdf5 ]
+then
+    echo "Generating Grackle Data..."
+    ./generate_grackle_data.py
+fi
+
+./plot_cooling.py
diff --git a/pyswiftsim/structure.py b/pyswiftsim/structure.py
index 2e89c0c5e6a3427d0ec38b0ede0d8903160e4376..78a35b781e6062e2f133bbcdccb5d726ae67b440 100644
--- a/pyswiftsim/structure.py
+++ b/pyswiftsim/structure.py
@@ -1,13 +1,13 @@
 from pyswiftsim import wrapper
 
 import struct
-import numpy
 from ctypes import *
 
 PARSER_MAX_LINE_SIZE = 256
 PARSER_MAX_NO_OF_PARAMS = 256
 PARSER_MAX_NO_OF_SECTIONS = 64
 
+
 ######################################################################
 #                                                                    #
 #                         SwiftStruct                                #
@@ -33,10 +33,10 @@ class SwiftStruct(struct.Struct):
     """
     def __init__(self, struct_format, data, parent):
         super().__init__(struct_format)
-        self.parent = parent # parent for ArrayStruct
+        self.parent = parent  # parent for ArrayStruct
 
         if isinstance(data, bytes):
-            self.data = data # bytes string data
+            self.data = data  # bytes string data
         elif isinstance(data, dict):
             tmp = []
             for i in self.struct_name:
@@ -46,7 +46,7 @@ class SwiftStruct(struct.Struct):
         else:
             raise ValueError("Data should be either bytes or dict, "
                              "received ", type(data))
-        
+
     @property
     def struct_name(self):
         """
@@ -71,7 +71,6 @@ class SwiftStruct(struct.Struct):
         """
         return {}
 
-
     def _getInfoFromName(self, name):
         """
         Compute index, format and size from an attribute name.
@@ -102,9 +101,8 @@ class SwiftStruct(struct.Struct):
 
         if n > 1:
             i = slice(i, i+n)
-        
+
         return i, form, n
-        
 
     @property
     def struct_size_format(self):
@@ -116,7 +114,7 @@ class SwiftStruct(struct.Struct):
 
         out_nber: list
             number of element for each attribute
-        
+
         out_form: list
             format for each attribute
         """
@@ -139,7 +137,7 @@ class SwiftStruct(struct.Struct):
 
             if v == "s":
                 count = 1
-                
+
             count = int(count)
             out_nber.append(count)
             out_form.append(v)
@@ -147,7 +145,6 @@ class SwiftStruct(struct.Struct):
 
         return out_nber, out_form
 
-
     def __str__(self, tab=""):
         txt = tab + "%s:\n" % type(self)
         for name in self.struct_name:
@@ -163,7 +160,7 @@ class SwiftStruct(struct.Struct):
         # case where the attribute is not in the structure
         if name not in self.struct_name:
             return object.__getattribute__(self, name)
-        
+
         # case where the attribute is in the structure
         else:
             i, form, n = self._getInfoFromName(name)
@@ -187,7 +184,7 @@ class SwiftStruct(struct.Struct):
                 # other case => array
                 else:
                     return data[i]
-                
+
             else:
                 # transform scalar -> vector
                 nform = str(n) + form
@@ -198,7 +195,6 @@ class SwiftStruct(struct.Struct):
                 data = struct.pack(nform, *data)
                 return ArrayStruct(nform, data, self, name)
 
-        
     def __setattr__(self, name, value):
         # case where the attribute is not in the structure
         if name not in self.struct_name:
@@ -225,7 +221,7 @@ class ArrayStruct(SwiftStruct):
     _name = [
         "array_data"
     ]
-        
+
     def __init__(self, struct_format, data, parent, name):
         super().__init__(struct_format, data, parent)
         self._format = struct_format
@@ -235,7 +231,7 @@ class ArrayStruct(SwiftStruct):
         data = self.unpack(self.data)
         data = self._clean(data)
         return data[ii]
-    
+
     def __setitem__(self, ii, value):
         data = list(self.unpack(self.data))
         data[ii] = value
@@ -254,8 +250,7 @@ class ArrayStruct(SwiftStruct):
         data = self.unpack(self.data)
         data = self._clean(data)
         return tab + str(data) + "\n"
-        
-        
+
     def getArray(self):
         return self.unpack(self.data)
 
@@ -263,7 +258,7 @@ class ArrayStruct(SwiftStruct):
     def struct_format(self):
         return self._format
 
-    
+
 ######################################################################
 #                                                                    #
 #                         UnitSystem                                 #
@@ -297,6 +292,7 @@ class ChemistryPartData(SwiftStruct):
     def __init__(self, data, parent=None):
         super().__init__(self.struct_format, data, parent)
 
+
 ######################################################################
 #                                                                    #
 #                          Part                                      #
@@ -326,8 +322,8 @@ class Part(SwiftStruct):
     def __init__(self, data, parent=None):
         super().__init__(self.struct_format, data, parent)
         print("ERROR, need to fix density/time_bin")
-        
-    
+
+
 ######################################################################
 #                                                                    #
 #                       Parameter                                    #
@@ -343,10 +339,10 @@ class Parameter(SwiftStruct):
             "value"
         ]
 
-
     def __init__(self, data, parent=None):
         super().__init__(self.struct_format, data, parent)
-    
+
+
 ######################################################################
 #                                                                    #
 #                       Section                                      #
@@ -363,6 +359,7 @@ class Section(SwiftStruct):
     def __init__(self, data, parent=None):
         super().__init__(self.struct_format, data, parent)
 
+
 ######################################################################
 #                                                                    #
 #                       SwiftParams                                  #
@@ -402,6 +399,7 @@ class SwiftParams(SwiftStruct):
             "data_params": param
         }
 
+
 ######################################################################
 #                                                                    #
 #                        PhysConst                                   #
@@ -428,13 +426,15 @@ class PhysConst(SwiftStruct):
         "const_earth_mass",
     ]
 
-
     def __init__(self, data, parent=None):
         super().__init__(self.struct_format, data, parent)
 
 
 class GrackleCodeUnits(SwiftStruct):
+    cooling_type = wrapper.configGetCooling()
     _format = "idddddd"
+    if cooling_type == "grackle_float":
+        _format = _format.replace("d", "f")
     _name = [
         "comoving_coordinates",
         "density_units",
@@ -447,9 +447,13 @@ class GrackleCodeUnits(SwiftStruct):
 
     def __init__(self, data, parent=None):
         super().__init__(self.struct_format, data, parent)
-    
+
+
 class GrackleChemistryData(SwiftStruct):
+    cooling_type = wrapper.configGetCooling()
     _format = "iiiiiPidiidiiiiiiidddiiddiddiiddddddiiiiii"
+    if cooling_type == "grackle_float":
+        _format = _format.replace("d", "f")
     _name = [
         'use_grackle',
         'with_radiative_cooling',
@@ -498,6 +502,7 @@ class GrackleChemistryData(SwiftStruct):
     def __init__(self, data, parent=None):
         super().__init__(self.struct_format, data, parent)
 
+
 class CoolingFunctionData(SwiftStruct):
     cooling_type = wrapper.configGetCooling()
     if cooling_type == "const_lambda":
@@ -509,7 +514,7 @@ class CoolingFunctionData(SwiftStruct):
             "min_energy",
             "cooling_tstep_mult"
         ]
-    elif cooling_type == "grackle":
+    elif "grackle" in cooling_type:
         _format = "200cidd{code_units}s{chemistry}s".format(
             code_units=struct.calcsize(GrackleCodeUnits._format),
             chemistry=struct.calcsize(GrackleChemistryData._format)
@@ -543,6 +548,5 @@ class CoolingFunctionData(SwiftStruct):
         raise ValueError(
             "Cooling Type %s not implemented" % cooling_type)
 
-
     def __init__(self, data, parent=None):
         super().__init__(self.struct_format, data, parent)
diff --git a/setup.py b/setup.py
index b87b13a808514f6cf666837b5570337f8a721e28..6f8712d727e90d1af93e599ed618d67ad68da860 100644
--- a/setup.py
+++ b/setup.py
@@ -1,28 +1,36 @@
 #!/usr/bin/env python3
 
-descr = """
-Wrapper around the SPH cosmological simulation code SWIFT
-"""
-
 from setuptools import setup, find_packages, Extension
 import sys
 import os
 from glob import glob
 import numpy
 
+descr = """
+Wrapper around the SPH cosmological simulation code SWIFT
+"""
+
+with_omp = True
 os.environ["CC"] = "mpicc"
 
-cflags = ["-Werror",
-          "-Wall",
-          "-Wextra",
-          # disables some warnings due to python
-          "-Wno-unused-parameter",
-          "-Wno-strict-prototypes",
-          "-Wno-unused-function",
-          "-Wno-incompatible-pointer-types",
-          "-Wno-missing-field-initializers",
+cflags = [
+    "-Werror",
+    "-Wall",
+    "-Wextra",
+    # disables some warnings due to python
+    "-Wno-unused-parameter",
+    "-Wno-strict-prototypes",
+    "-Wno-unused-function",
+    "-Wno-incompatible-pointer-types",
+    "-Wno-missing-field-initializers",
+    "-fopenmp"
 ]
 
+lflags = [
+    "-fopenmp"
+    ]
+
+
 # deal with arguments
 def parseCmdLine(arg, store=False):
     ret = False
@@ -37,11 +45,12 @@ def parseCmdLine(arg, store=False):
         sys.argv.remove(arg)
 
     return ret
-    
+
+
 swift_path = parseCmdLine("--with-swift", store=True)
 
 # python lib dependency
-install_requires=["numpy"]
+install_requires = ["numpy"]
 
 
 def getValueFromMakefile(swift_root, value):
@@ -51,7 +60,7 @@ def getValueFromMakefile(swift_root, value):
     with open(makefile, "r") as f:
         for line in f.readlines():
             if value == line[:N]:
-                return line[N:-1] # remove \n
+                return line[N:-1]  # remove \n
 
     raise ValueError("Value %s not found in Makefile" % value)
 
@@ -80,9 +89,10 @@ if swift_path:
     include.append(grackle_inc)
 
 # C libraries
-lib = ["m",
-       "swiftsim",
-       "hdf5",
+lib = [
+    "m",
+    "swiftsim",
+    "hdf5",
 ]
 
 lib_dir = []
@@ -90,7 +100,7 @@ lib_dir = []
 if swift_path:
     lib_dir.append(swift_path + "/src/.libs")
     lib_dir.append(hdf5_root + "/lib")
-       
+
 #  src files
 c_src = []
 
@@ -99,7 +109,7 @@ data_files = []
 
 
 ##############
-## C ext
+#  C ext
 ##############
 
 c_src = glob("src/*.c")
@@ -108,45 +118,41 @@ ext_modules = Extension("pyswiftsim.wrapper",
                         include_dirs=include,
                         libraries=lib,
                         library_dirs=lib_dir,
-                        extra_compile_args=cflags)
+                        extra_compile_args=cflags,
+                        extra_link_args=lflags)
 
 ext_modules = [ext_modules]
-    
+
 ##############
-## data
+#  data
 ##############
 
 data_files = []
 
 ##############
-## scripts
+#  scripts
 ##############
 
 list_scripts = []
 
 ##############
-## Setup
+#  Setup
 ##############
 
 setup(
-    name         = "pyswiftsim",
-    version      = "0.1",
-    author       = "Hausammann Loic",
-    author_email = "loic.hausammann@epfl.ch",
-    description  = descr,
-    license      = "GPLv3",
-    keywords     = "nbody sph simulation hpc",
-    url          = "",
-
-    packages         = find_packages(),
-
-    data_files       = data_files,
-
-    scripts          = list_scripts,
-
-    install_requires = install_requires,
-
-    dependency_links = dependency_links,
-    
-    ext_modules      = ext_modules,
+    name="pyswiftsim",
+    version="0.1",
+    author="Hausammann Loic",
+    author_email="loic.hausammann@epfl.ch",
+    description=descr,
+    license="GPLv3",
+    keywords="nbody sph simulation hpc",
+    url="",
+
+    packages=find_packages(),
+    data_files=data_files,
+    scripts=list_scripts,
+    install_requires=install_requires,
+    dependency_links=dependency_links,
+    ext_modules=ext_modules,
 )
diff --git a/src/config_wrapper.h b/src/config_wrapper.h
index 6b9b2a8e752bba7826b699d76050e946c57c4f1c..5a17aeb3c933dc9b36a7884bc11125c7a2b5c9ef 100644
--- a/src/config_wrapper.h
+++ b/src/config_wrapper.h
@@ -10,11 +10,25 @@
  */
 PyObject* config_get_cooling() {
   char *cooling_name;
+  /* lambda */
 #ifdef COOLING_CONST_LAMBDA
   cooling_name = "const_lambda";
+
+  /* grackle */
 #elif defined(COOLING_GRACKLE)
+#if COOLING_GRACKLE_MODE == 0
   cooling_name = "grackle";
-#endif
+#elif COOLING_GRACKLE_MODE == 1
+  cooling_name = "grackle1";
+#elif COOLING_GRACKLE_MODE == 2
+  cooling_name = "grackle2";  
+#elif COOLING_GRACKLE_MODE == 3
+  cooling_name = "grackle3";
+#else
+  error("Grackle mode unknown");
+#endif // COOLING_GRACKLE_MODE
+#endif // COOLING_GRACKLE
+  
   return PyUnicode_FromString(cooling_name);
 };
 
diff --git a/src/cooling_wrapper.c b/src/cooling_wrapper.c
index cb23170b544e455b181f8007277b7d55ace4ffa8..11f5bdb97a6d694e8f3c396553034d31bc723015 100644
--- a/src/cooling_wrapper.c
+++ b/src/cooling_wrapper.c
@@ -1,7 +1,18 @@
 #include "pyswiftsim_tools.h"
 #include "cooling_wrapper.h"
-
-
+#include <omp.h>
+
+
+/**
+ * @brief Initialize the cooling
+ *
+ * args is expecting pyswiftsim classes in the following order:
+ * SwiftParams, UnitSystem and PhysConst. 
+ *
+ * @param self calling object
+ * @param args arguments
+ * @return CoolingFunctionData
+ */
 PyObject* pycooling_init(PyObject* self, PyObject* args) {
   PyObject *pyparams;
   PyObject *pyus;
@@ -35,6 +46,52 @@ PyObject* pycooling_init(PyObject* self, PyObject* args) {
   return pycooling;
 }
 
+/**
+ * @brief Set the cooling element fractions
+ *
+ * @param xp The #xpart to set
+ * @param frac The numpy array containing the fractions (id, element)
+ * @param idx The id (in frac) of the particle to set
+ */
+void pycooling_set_fractions(struct xpart *xp, PyArrayObject* frac, const int idx) {
+  struct cooling_xpart_data *data = &xp->cooling_data;
+  data->metal_frac = *(float*)PyArray_GETPTR2(frac, idx, 0);
+
+#ifdef COOLING_GRACKLE
+#if COOLING_GRACKLE_MODE > 0
+  data->HI_frac = *(float*)PyArray_GETPTR2(frac, idx, 1);
+  data->HII_frac = *(float*)PyArray_GETPTR2(frac, idx, 2);
+  data->HeI_frac = *(float*)PyArray_GETPTR2(frac, idx, 3);
+  data->HeII_frac = *(float*)PyArray_GETPTR2(frac, idx, 4);
+  data->HeIII_frac = *(float*)PyArray_GETPTR2(frac, idx, 5);
+  data->e_frac = *(float*)PyArray_GETPTR2(frac, idx, 6);
+#endif // COOLING_GRACKLE_MODE
+#if COOLING_GRACKLE_MODE > 1
+  data->HM_frac = *(float*)PyArray_GETPTR2(frac, idx, 7);
+  data->H2I_frac = *(float*)PyArray_GETPTR2(frac, idx, 8);
+  data->H2II_frac = *(float*)PyArray_GETPTR2(frac, idx, 9);
+#endif // COOLING_GRACKLE_MODE
+#if COOLING_GRACKLE_MODE > 2
+  data->DI_frac = *(float*)PyArray_GETPTR2(frac, idx, 10);
+  data->DII_frac = *(float*)PyArray_GETPTR2(frac, idx, 11);
+  data->HDI_frac = *(float*)PyArray_GETPTR2(frac, idx, 12);
+#endif // COOLING_GRACKLE_MODE
+#endif // COOLING_GRACKLE
+
+}
+  
+/**
+ * @brief Compute cooling rate
+ *
+ * args is expecting pyswiftsim classes in the following order: 
+ * PhysConst, UnitSystem and CoolingFunctionData.
+ * Then two numpy arrays (density and specific energy) and an optional
+ * float for the time step
+ *
+ * @param self calling object
+ * @param args arguments
+ * @return cooling rate
+ */
 PyArrayObject* pycooling_rate(PyObject* self, PyObject* args) {
   import_array();
   
@@ -44,35 +101,40 @@ PyArrayObject* pycooling_rate(PyObject* self, PyObject* args) {
 
   PyArrayObject *rho;
   PyArrayObject *energy;
+  PyArrayObject *fractions = NULL;
 
   float dt = 1e-3;
 
   /* parse argument */
   if (!PyArg_ParseTuple(args,
-			"OOOOO|f",
+			"OOOOO|fOO",
 			&pypconst,
 			&pyus,
 			&pycooling,
 			&rho,
 			&energy,
-			&dt))
+			&dt,
+			&fractions
+			))
     return NULL;
 
   /* check numpy array */
   if (pytools_check_array(energy, 1, NPY_FLOAT) != SUCCESS)
-    {
-      return NULL;
-    }
+    return NULL;
 
   if (pytools_check_array(rho, 1, NPY_FLOAT) != SUCCESS)
-    {
-      return NULL;
-    }
+    return NULL;
+
+  if (fractions != NULL &&
+      pytools_check_array(fractions, 2, NPY_FLOAT) != SUCCESS)
+    return NULL;
 
   if (PyArray_DIM(energy, 0) != PyArray_DIM(rho, 0))
-    {
-      pyerror("Density and energy should have the same dimension");
-    }
+    pyerror("Density and energy should have the same dimension");
+
+  if (fractions != NULL &&
+      PyArray_DIM(fractions, 0) != PyArray_DIM(rho,0))
+    pyerror("Fractions should have the same first dimension than the others");
 
   size_t N = PyArray_DIM(energy, 0);
 
@@ -99,9 +161,13 @@ PyArrayObject* pycooling_rate(PyObject* self, PyObject* args) {
 #endif
 
   /* return object */
-  PyArrayObject *rate = PyArray_NewLikeArray(energy, NPY_ANYORDER, NULL, 1);
+  PyArrayObject *rate = PyArray_SimpleNew(PyArray_NDIM(energy), PyArray_DIMS(energy), NPY_FLOAT);
 
+  /* Release GIL */
+  Py_BEGIN_ALLOW_THREADS;
+  
   /* loop over all particles */
+#pragma omp for
   for(size_t i = 0; i < N; i++)
     {
       /* set particle data */
@@ -109,8 +175,11 @@ PyArrayObject* pycooling_rate(PyObject* self, PyObject* args) {
       float u = *(float*) PyArray_GETPTR1(energy, i);
       p.entropy = gas_entropy_from_internal_energy(p.rho, u);
 
-      cooling_first_init_part(&p, &xp, cooling);
-
+      if (fractions != NULL)
+	pycooling_set_fractions(&xp, fractions, i);
+      else
+	cooling_first_init_part(&p, &xp, cooling);
+	
       /* compute cooling rate */
       float *tmp = PyArray_GETPTR1(rate, i);
 #ifdef COOLING_GRACKLE
@@ -120,6 +189,9 @@ PyArrayObject* pycooling_rate(PyObject* self, PyObject* args) {
 #endif
     }
 
+  /* Acquire GIL */
+  Py_END_ALLOW_THREADS;
+
   return rate;
   
 }
diff --git a/src/cooling_wrapper.h b/src/cooling_wrapper.h
index 4fbdd8abbc746d5d42aa4e32baa707c7770aa602..95717e8ac5d2867b25ddcebe33548fe86ecfa6cc 100644
--- a/src/cooling_wrapper.h
+++ b/src/cooling_wrapper.h
@@ -3,30 +3,8 @@
 
 #include "pyswiftsim_tools.h"
 
-/**
- * @brief Initialize the cooling
- *
- * args is expecting pyswiftsim classes in the following order:
- * SwiftParams, UnitSystem and PhysConst. 
- *
- * @param self calling object
- * @param args arguments
- * @return CoolingFunctionData
- */
 PyObject* pycooling_init(PyObject* self, PyObject* args);
 
-/**
- * @brief Compute cooling rate
- *
- * args is expecting pyswiftsim classes in the following order: 
- * PhysConst, UnitSystem and CoolingFunctionData.
- * Then two numpy arrays (density and specific energy) and an optional
- * float for the time step
- *
- * @param self calling object
- * @param args arguments
- * @return cooling rate
- */
 PyArrayObject* pycooling_rate(PyObject* self, PyObject* args);
 
 #endif // __PYSWIFTSIM_COOLING_H__
diff --git a/src/wrapper.c b/src/wrapper.c
index 12e7ef2b12653c9648e91bdf1c727cddb6993395..46076c5cb627eaed51654bb02107fcaeefcbbfb6 100644
--- a/src/wrapper.c
+++ b/src/wrapper.c
@@ -30,7 +30,21 @@ static PyMethodDef wrapper_methods[] = {
    "Initialize cooling."},
 
   {"coolingRate", pycooling_rate, METH_VARARGS,
-   "Compute the cooling rate."},
+   "Compute the cooling rate.\n\n"
+   "Parameters\n"
+   "----------\n\n"
+   "pyconst: swift physical constant\n"
+   "pyus: swift unit system\n"
+   "cooling: swift cooling structure\n"
+   "rho: np.array\n"
+   "\t Mass density in pyus units\n"
+   "energy: np.array\n"
+   "\t Internal energy in pyus units\n"
+   "dt: float, optional\n"
+   "\t Time step in pyus units\n"
+   "fractions: np.array, optional\n"
+   "\t Fraction of each cooling element (including metals)\n"
+  },
 
   {"configGetCooling", config_get_cooling, METH_VARARGS,
    "Get the cooling type."},
diff --git a/test/test_cooling.py b/test/test_cooling.py
deleted file mode 100644
index f8bb2aeca7e0c7a2e9e71cf209681711c4aff66a..0000000000000000000000000000000000000000
--- a/test/test_cooling.py
+++ /dev/null
@@ -1,141 +0,0 @@
-#!/usr/bin/env python3
-
-from pyswiftsim import wrapper
-from pyswiftsim import structure
-
-from copy import deepcopy
-import numpy as np
-import matplotlib.pyplot as plt
-from astropy import units
-
-#
-# parameters
-#
-
-# adiabatic index
-gamma = 5. / 3.
-
-# swift params filename
-filename = "test/cooling.yml"
-
-# number of points
-N_rho = 1
-N_T = 100
-
-# density in atom / cm3
-if N_rho == 1:
-    rho = np.array([1.])
-else:
-    #rho = np.array([1.])
-    rho = np.logspace(-6, 4, N_rho)
-
-# temperature in K
-T = np.logspace(1, 9, N_T)
-
-# time step
-dt = units.Myr * 1e-5
-dt = dt.to("s") / units.s
-
-
-
-# generate grid
-rho, T = np.meshgrid(rho, T)
-shape = rho.shape
-rho = deepcopy(rho.reshape(-1))
-T = T.reshape(-1)
-
-#
-# swift init
-#
-
-# parse swift params
-print("Reading parameters")
-params = wrapper.parserReadFile(filename)
-# init units
-print("Initialization of the unit system")
-us, pconst = wrapper.unitSystemInit(params)
-# init cooling
-print("Initialization of the cooling")
-cooling = wrapper.coolingInit(params, us, pconst)
-
-
-#
-# Deal with units
-#
-
-# change units of rho and T
-# rho
-rho *= us.UnitLength_in_cgs**3 * pconst.const_proton_mass
-# specific energy
-energy = pconst.const_boltzmann_k * T / us.UnitTemperature_in_cgs
-energy /= (gamma - 1.) * pconst.const_proton_mass
-# time step
-dt /= us.UnitTime_in_cgs
-
-#
-# compute rate
-#
-
-# du / dt
-print("Computing cooling...")
-rate = wrapper.coolingRate(pconst, us, cooling,
-                           rho.astype(np.float32),
-                           energy.astype(np.float32),
-                           dt)
-print("Computing done")
-
-#
-# plot
-#
-
-
-# change units => cgs
-energy *= us.UnitLength_in_cgs**2 / us.UnitTime_in_cgs**2
-
-rate *= us.UnitLength_in_cgs**2 / us.UnitTime_in_cgs**3
-    
-if N_rho == 1 or N_T == 1:
-    plt.figure()
-    n = rho / (pconst.const_proton_mass * us.UnitLength_in_cgs**3)
-    proton_mass_in_cgs = pconst.const_proton_mass * us.UnitMass_in_cgs
-    lambda_ = rate * proton_mass_in_cgs / (n)
-    
-    if N_rho == 1:
-        plt.loglog(T, np.abs(lambda_))
-        plt.xlabel("Temperature [K]")
-    else:
-        plt.loglog(rho, np.abs(lambda_))
-        plt.xlabel("Density [atom / cm3]")
-    plt.ylabel("Rate")
-    
-else:
-    cooling_time = energy / rate
-    cooling_length = np.sqrt(gamma * (gamma-1.) * energy) * cooling_time
-
-    cooling_length = np.log10(np.abs(cooling_length) / units.kpc.to('cm'))
-    
-    # reshape
-    rho = rho.reshape(shape)
-    T = T.reshape(shape)
-    energy = energy.reshape(shape)
-    cooling_length = cooling_length.reshape(shape)
-
-    _min = -7
-    _max = 7
-    N_levels = 100
-    levels = np.linspace(_min, _max, N_levels)
-    plt.figure()
-    plt.contourf(rho, T, cooling_length, levels)
-    plt.xlabel("Density [atom/cm3]")
-    plt.ylabel("Temperature [K]")
-
-    ax = plt.gca()
-    ax.set_xscale("log")
-    ax.set_yscale("log")
-    
-    cbar = plt.colorbar()
-    tc = np.arange(_min, _max, 1.5)
-    cbar.set_ticks(tc)
-    cbar.set_ticklabels(tc)
-
-plt.show()