###############################################################################
# This file is part of SWIFT.
# Copyright (c) 2020 Evgenii Chaikin (chaikin@strw.leidenuniv.nl)
#
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License as published
# by the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
##############################################################################

import h5py
import numpy as np
import sys

# Generates a SWIFT IC file with a constant density and pressure and one BH particle in the middle of the box
#
# Usage: python <this script> <gas particle mass in M_sun>
# Reads the glass file "glassCube_<res>.hdf5" from the working directory.

if len(sys.argv) < 2:
    # Fail with a readable message instead of a bare IndexError.
    sys.exit("Usage: {} <SPH particle mass in M_sun>".format(sys.argv[0]))

mass = float(sys.argv[1])  # Get SPH mass resolution in units of solar mass
res = 32  # Glass resolution: selects glassCube_<res>.hdf5 (res^3 particles)
dens = 1.0  # Gas number density [number of hydrogen particles per cm^3]

# Read positions and smoothing lengths from the glass. The context manager
# guarantees the input file is closed once the arrays are in memory.
with h5py.File("glassCube_{:d}.hdf5".format(res), "r") as glass:
    pos = glass["/PartType0/Coordinates"][:, :]  # Gas particle coordinates
    h = glass["/PartType0/SmoothingLength"][:]  # Gas smoothing lengths

# Remove ill-positioned particles: keep only those whose every coordinate lies
# in [0, 1] (the glass is later rescaled by the box size).
inside = np.all((pos >= 0.0) & (pos <= 1.0), axis=1)
pos = pos[inside]
h = h[inside]
numPart = int(np.count_nonzero(inside))  # Number of gas particles


# Physical parameters of the gas
T = 8000.0  # Initial temperature [K]
gamma = 5.0 / 3.0  # Adiabatic index of the gas
h_frac = 0.737  # Hydrogen mass fraction
mu = 0.8  # Mean molecular weight

# Gas mass density in code units (M_sun / kpc^3 with the units used in this
# script): 3.45e7 per hydrogen atom per cm^3.
# NOTE(review): m_H / h_frac converted with this file's constants gives
# ~3.35e7, about 3% below 3.45e7 — confirm the intended value.
rho = 3.45e7 * dens

periodic = 1  # 1 For periodic box
fileName = "example_{:d}_{:.1e}.hdf5".format(res, mass)  # Output file name

# Side of the cubic box in which numPart particles of mass `mass` have mean
# density rho. Rounding to 3 decimals means the realised density differs
# slightly from rho.
boxSize = np.round((mass / rho * numPart) ** (1.0 / 3.0), 3)

# The glass occupies the unit cube, so every length scales with boxSize
for length_array in (pos, h):
    length_array *= boxSize

# Defining units (cgs values of the physical constants and code units)
m_h_cgs = 1.67e-24  # proton mass [g]
k_b_cgs = 1.38e-16  # Boltzmann constant [erg/K]
unit_length = 3.08567758e21  # kpc [cm]
unit_mass = 1.98848e33  # solar mass [g]
unit_time = 3.0857e16  # ~ Gyr [s]


# Gas properties
v = np.zeros((numPart, 3))  # Velocity (gas starts at rest)
ids = np.arange(1, numPart + 1) + 100000  # Integer IDs 100001..100000+numPart

# Specific internal energy of an ideal gas at temperature T:
#     u = k_B T / ((gamma - 1) * mu * m_H)
# The mean molecular weight divides: more mass per free particle means less
# thermal energy per unit mass. (Multiplying by mu would set the gas to T*mu^2.)
internalEnergy = k_b_cgs * T / ((gamma - 1.0) * mu * m_h_cgs)  # [erg/g]
# Convert (cm/s)^2 to code units (unit_length / unit_time)^2
internalEnergy *= (unit_time / unit_length) ** 2
u = np.full(numPart, internalEnergy)  # Internal energy

# All gas particles share the same mass
m = np.full(numPart, mass)  # Mass


# Black-hole particles: N_BHs of them, all at the box centre, at rest, each
# mult_fact times heavier than a gas particle. Per-particle scalars are kept
# as (N_BHs, 1) column arrays.
N_BHs = 1

# Every BH sits at the geometric centre of the box
pos_BH = np.full((N_BHs, 3), boxSize / 2.0)

# BHs start with zero velocity
vel_BH = np.zeros((N_BHs, 3))

# BH mass as a multiple of the gas-particle mass
mult_fact = 1e1  # BH is 10 times more massive than gas particles
m_BH = np.full((N_BHs, 1), mass * mult_fact)

# BH particle IDs 1..N_BHs
ids_BH = np.arange(1, N_BHs + 1).reshape(N_BHs, 1)


# Write the initial-conditions file. The context manager closes the file even
# if one of the writes below raises, so no HDF5 handle is left open.
with h5py.File(fileName, "w") as f:
    # Header
    grp = f.create_group("/Header")
    grp.attrs["BoxSize"] = [boxSize, boxSize, boxSize]
    grp.attrs["NumPart_Total"] = [numPart, 0, 0, 0, 0, N_BHs]
    grp.attrs["NumPart_Total_HighWord"] = [0, 0, 0, 0, 0, 0]
    grp.attrs["NumPart_ThisFile"] = [numPart, 0, 0, 0, 0, N_BHs]
    grp.attrs["Time"] = [0.0]
    grp.attrs["NumFilesPerSnapshot"] = [1]
    grp.attrs["MassTable"] = [0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
    grp.attrs["Flag_Entropy_ICs"] = [0]
    grp.attrs["Dimension"] = 3

    # Runtime parameters
    grp = f.create_group("/RuntimePars")
    grp.attrs["PeriodicBoundariesOn"] = periodic

    # Units
    grp = f.create_group("/Units")
    grp.attrs["Unit length in cgs (U_L)"] = unit_length
    grp.attrs["Unit mass in cgs (U_M)"] = unit_mass
    grp.attrs["Unit time in cgs (U_t)"] = unit_time
    grp.attrs["Unit current in cgs (U_I)"] = 1.0
    grp.attrs["Unit temperature in cgs (U_T)"] = 1.0

    # Gas Particle group
    grp = f.create_group("/PartType0")
    grp.create_dataset("Coordinates", data=pos, dtype="d")
    grp.create_dataset("Velocities", data=v, dtype="f")
    grp.create_dataset("Masses", data=m, dtype="f")
    grp.create_dataset("SmoothingLength", data=h, dtype="f")
    grp.create_dataset("InternalEnergy", data=u, dtype="f")
    grp.create_dataset("ParticleIDs", data=ids, dtype="L")

    # Black-hole particle group (PartType5). Per-particle scalars are written
    # as 1D arrays of length N_BHs, the same layout as the gas group above
    # (np.reshape flattens any (N_BHs, 1) column input).
    grp = f.create_group("/PartType5")
    grp.create_dataset("Velocities", data=vel_BH, dtype="f")
    grp.create_dataset("Masses", data=np.reshape(m_BH, N_BHs), dtype="f")
    grp.create_dataset("ParticleIDs", data=np.reshape(ids_BH, N_BHs), dtype="L")
    grp.create_dataset("Coordinates", data=pos_BH, dtype="d")
    # BH smoothing length initialised to the median gas smoothing length
    grp.create_dataset(
        "SmoothingLength", data=np.full(N_BHs, np.median(h)), dtype="f"
    )

# Some output: summary of the generated initial conditions.
# "\\odot" prints a literal "\odot"; the unescaped "\o" is an invalid escape
# sequence (SyntaxWarning on Python >= 3.12).
print("Initial condition have been generated! \n")
print("-------------------------------------------------")
print("Box size                    : {:.3e} kpc".format(boxSize))
print("Number of gas particles     : {:d}".format(numPart))
print("Gas number density          : {:.3e} cm^-3".format(dens))
print("Average gas-particle mass   : {:.3e} M_\\odot".format(np.mean(m)))
print("Min gas-particle mass       : {:.3e} M_\\odot".format(np.min(m)))
print("Max gas-particle mass        : {:.3e} M_\\odot".format(np.max(m)))
print("Number of black-hole particles : {:d}".format(N_BHs))
print("BH-particle mass       : {:.3e} M_\\odot".format(mass * mult_fact))
