#!/usr/bin/env python3
import numpy as np
import h5py as h5
import matplotlib

matplotlib.use("Agg")
import matplotlib.pyplot as plt
import scipy.interpolate as sci
from tqdm import tqdm
import sys
import time
from shutil import copyfile
from unyt import unyt_array

np.random.seed(420)  # fixed seed so the random particle sampling below is reproducible between runs
# import swiftsimio.visualisation as swvis
from swiftsimio.visualisation.smoothing_length_generation import (
    generate_smoothing_lengths,
)

# definition of constants (cgs unless noted)
# G is used as G * M[Msun] / r[kpc] and the result is combined with cgs
# quantities (mu * proton_mass / kB), i.e. 4.3e-6 kpc (km/s)^2 / Msun scaled
# by 1e10 to give cm^2 s^-2.
G = 4.3e4  # kpc cm^2 s^-2 Msun^-1
mu = 0.62  # mean molecular weight
proton_mass = 1.6726e-24  # gram
kB = 1.38e-16  # erg / K
gamma = 5.0 / 3.0  # adiabatic index


def vcirc_dmhalo(r, a, Mtotal):
    """Return the *squared* circular velocity G M(<r) / r of a Hernquist halo.

    With M(<r) = Mtotal r^2 / (r + a)^2 this is G Mtotal r / (r + a)^2.
    Units follow the module-level ``G`` (cm^2 s^-2 for r, a in kpc and
    Mtotal in solar masses).
    """
    numerator = G * Mtotal * r
    return numerator / (r + a) ** 2


def Mp_hern(r, a, Mtotal):
    """Return the Hernquist enclosed mass Mtotal * r^2 / (r + a)^2.

    Works element-wise for array ``r``.
    """
    numerator = Mtotal * r ** 2
    denominator = (r + a) ** 2
    return numerator / denominator


def M_stellar_bulge(coordinates, masses, radius_array):
    """Return the particle mass enclosed strictly inside each radius.

    Parameters
    ----------
    coordinates : array of shape (N, 3)
        Particle positions relative to the centre.
    masses : array of shape (N,)
        Particle masses.
    radius_array : array-like of shape (M,)
        Radii at which to evaluate the enclosed mass.

    Returns
    -------
    numpy.ndarray of shape (M,)
        Sum of ``masses`` for particles with r < radius, for each radius.
    """
    coordinates = np.asarray(coordinates, dtype=float)
    masses = np.asarray(masses, dtype=float)
    radius_array = np.asarray(radius_array, dtype=float)

    r = np.sqrt(np.sum(coordinates ** 2, axis=1))

    # Sort once and use a cumulative mass profile: O((N + M) log N) instead
    # of one full masked sum over all particles per radius.
    order = np.argsort(r)
    r_sorted = r[order]
    cumulative_mass = np.concatenate(([0.0], np.cumsum(masses[order])))

    # side="left" counts particles with r strictly below each radius,
    # matching the r < radius selection.
    n_inside = np.searchsorted(r_sorted, radius_array, side="left")
    return cumulative_mass[n_inside]


def gas_fraction_rho500(m500, alpha, mt, omegab=0.0463, omegam=0.2793):
    """Return the gas fraction inside R500 as a function of M500.

    A tanh transition in log10(m500 / mt) of width ``alpha`` between 0 and
    the cosmic baryon fraction ``omegab / omegam``.
    """
    half_baryon_fraction = 0.5 * omegab / omegam
    transition = np.tanh(np.log10(m500 / mt) / alpha)
    return half_baryon_fraction * (1 + transition)


def mean_density(
    r_array, rho_avg, rho_500, rho_200, rho_c, filename="Mean_density_guess.png"
):
    """Plot mean enclosed density against radius with reference density levels.

    Draws ``rho_avg`` versus ``r_array`` on log-log axes, adds dashed
    horizontal lines at ``rho_500``, ``rho_200`` and ``rho_c``, saves the
    figure to ``filename`` and closes it.

    Parameters
    ----------
    r_array : array
        Radii (kpc).
    rho_avg : array
        Mean enclosed density at each radius.
    rho_500, rho_200, rho_c : float
        Reference density levels, drawn as constant lines.
    filename : str, optional
        Output image path (defaults to the previously hard-coded name).
    """
    reference = np.ones(len(r_array))
    plt.plot(r_array, rho_avg)
    for level in (rho_500, rho_200, rho_c):
        plt.plot(r_array, reference * level, linestyle="--")
    plt.xlabel(r"Radius [$\rm kpc$]")
    # raw string: "\o" in a normal literal is an invalid escape sequence
    plt.ylabel(r"Average density [$\rm M_\odot \rm kpc^{-3}$]")
    plt.xscale("log")
    plt.yscale("log")
    plt.savefig(filename)
    plt.close()


def find_gas_mass(Mdm_encl, Mstar_encl, r_array):
    """Estimate the gas mass inside R500 from the dark-matter + stellar profile.

    R500 is taken as the grid radius where the mean enclosed density of dark
    matter plus stars is closest to 500 * rho_c. The gas mass is the gas
    fraction from ``gas_fraction_rho500`` times the enclosed mass there.
    Also writes the diagnostic plot produced by ``mean_density``.

    Parameters
    ----------
    Mdm_encl, Mstar_encl : array
        Enclosed dark-matter and stellar masses on ``r_array``.
    r_array : array
        Radii (kpc).

    Returns
    -------
    float
        Estimated gas mass within R500, in the mass units of the inputs.

    Notes
    -----
    Uses the module-level gas-fraction parameters ``alpha`` and ``mt``.
    """
    # Computed from the arguments instead of reading the module-level
    # ``Mdm_stars_encl``, which only worked if that global existed and
    # happened to match the arguments.
    Mtot_encl = Mdm_encl + Mstar_encl
    rho_avg = Mtot_encl / (4 * np.pi / 3 * r_array ** 3)
    rho_c = 130  # NOTE(review): presumably ~ critical density in Msun / kpc^3
    rho_500 = rho_c * 500  # solar mass / kpc^3
    rho_200 = rho_c * 200
    mean_density(r_array, rho_avg, rho_500, rho_200, rho_c)

    # grid point whose mean enclosed density is closest to 500 rho_c
    idx_500 = np.argmin(np.abs(rho_avg - rho_500))
    m500_guess = Mtot_encl[idx_500]
    m500_gas_fraction = gas_fraction_rho500(m500_guess, alpha, mt)
    return m500_gas_fraction * m500_guess


def get_gas_T(Tmin, circ_v_tot):
    """Return Tmin plus the temperature equivalent of the squared circular velocity.

    T = Tmin + mu * m_p / (kB * gamma) * circ_v_tot, where ``circ_v_tot`` is
    a squared velocity in cm^2 s^-2.
    """
    kelvin_per_v2 = mu * proton_mass / kB / gamma
    return Tmin + kelvin_per_v2 * circ_v_tot


def get_gas_T_center(Tmin, rmin, r, a):
    """Return log10 of a central temperature core that falls off smoothly.

    T(r) = Tmin / (1 + exp((r - a * rmin) / rmin)): roughly Tmin well inside
    a * rmin, Tmin / 2 at r = a * rmin, and exponentially suppressed outside.
    """
    exponent = (r - a * rmin) / rmin
    suppression = 1 + np.exp(exponent)
    return np.log10(Tmin / suppression)


def get_gas_T_center2(Tmin, rmin, r):
    """Return log10 of a flat-then-linearly-declining core temperature.

    T = Tmin for r < rmin, then decreases linearly to zero at r = 6 * rmin
    and is clipped at zero beyond that. Zero temperatures map to -inf.

    Parameters
    ----------
    Tmin : float
        Core temperature.
    rmin : float
        Core radius.
    r : array-like
        Radii at which to evaluate the profile.

    Returns
    -------
    numpy.ndarray
        log10(T) at each radius.
    """
    r = np.asarray(r, dtype=float)
    temperature = np.where(r < rmin, Tmin, Tmin * (1 - (r - rmin) / (5 * rmin)))
    temperature = np.maximum(temperature, 0)
    # log10(0) = -inf is the intended value outside the core; silence the
    # divide-by-zero RuntimeWarning. (A leftover debug print was removed.)
    with np.errstate(divide="ignore"):
        return np.log10(temperature)


def particle_mass(r_array, rtyp):
    """Return the relative gas-particle mass at each radius.

    The factor is 1 everywhere for the "default" scheme. Otherwise it is 1
    inside ``rtyp`` and ms_a * (r / rtyp) ** ms_n - ms_b outside, using the
    module-level scheme parameters ``mass_scheme``, ``ms_a``, ``ms_n`` and
    ``ms_b``.
    """
    factors = np.ones(len(r_array))
    if mass_scheme == "default":
        return factors
    outer = r_array > rtyp
    factors[outer] = ms_a * (r_array[outer] / rtyp) ** ms_n - ms_b
    return factors


# Debackere model (gas-fraction parameters passed to gas_fraction_rho500)
alpha = 1.35
mt = 10 ** 13.94
amount_of_sigma_used = 0.0

# Parameters of the model
M200 = 1e13
bulge_fraction = 0.01
# derive the halo mass from bulge_fraction instead of repeating its value
Mdm = M200 * (1 - bulge_fraction)
a_hernquist = 82.3903  # kpc
R200 = 442.73  # presumably kpc, like r_array
V200 = 311.682  # presumably km/s
# parameters of the specific-angular-momentum profile used further below
mu_angular = 1.25
lambda_angular = 0.05
s_angular = 1.0
mBH = 10 ** 8.5  # black-hole mass (solar masses, written as mBH / 1e10)

# set initial value in model
Tmin_model = 1e6  # K
Tscale_length = 15.0  # kpc

# get the arguments of the code
if len(sys.argv) < 4:
    sys.exit(f"Usage: {sys.argv[0]} <input IC file> <output IC file> <mass scheme>")
file_name = str(sys.argv[1])
file_name_save = str(sys.argv[2])
mass_scheme = str(sys.argv[3])

# Radial gas-particle mass schemes: name -> (ms_n, ms_a, message).
# Outside the scale radius particle_mass() uses ms_a * (r / rtyp) ** ms_n - ms_b
# with ms_b = ms_a - 1, so the factor equals 1 at rtyp. "default" keeps all gas
# particles at the same mass.
_MASS_SCHEMES = {
    "default": (1, 1, "Generate ICs with constant mass"),
    "linear1": (1, 1, "Generate ICs with linear1"),
    "linear5": (1, 5, "Generate ICs with linear5"),
    "linear10": (1, 10, "Generate ICs with linear10"),
    "quadratic": (2, 1, "Generate ICs with quadratic"),
}

if mass_scheme not in _MASS_SCHEMES:
    # a bad value (not a bad type) -> ValueError
    raise ValueError(
        f"Invalid choice of mass scheme {mass_scheme!r}; "
        f"choose one of: {', '.join(_MASS_SCHEMES)}"
    )
ms_n, ms_a, _scheme_message = _MASS_SCHEMES[mass_scheme]
print(_scheme_message)

ms_b = ms_a - 1

copyfile(file_name, file_name_save)

# load the data: the copy is opened read/write and stays open for the rest of
# the script, which appends the gas and black-hole groups to it
f = h5.File(file_name_save, "r+")

stars = f["/PartType4"]
Coordinates = stars["Coordinates"][:, :]
Masses = stars["Masses"][:] * 1e10
Velocities = stars["Velocities"][:, :]

# shift every star ID up by one in the file (presumably to leave ID 1 free
# for the black hole written later)
ids_data = stars["ParticleIDs"]
ids = ids_data[:]
ids_update = ids + 1
ids_data[...] = ids_update

# centre the stellar coordinates on the middle of the box
boxsize = f["/Header"].attrs["BoxSize"]
half_box = boxsize / 2.0
for axis in range(3):
    Coordinates[:, axis] -= half_box[axis]

# outer edge of the radial grid: three times R200, capped at 1000 kpc
rmax = min(3 * R200, 1000)

r_array = np.linspace(1.0, rmax, 10_000)  # in kpc


# calculate the assumed dark matter and stellar mass profiles
# (analytic Hernquist halo + the enclosed mass of the loaded star particles)
Mdm_encl = Mp_hern(r_array, a_hernquist, Mdm)
Mstar_encl = M_stellar_bulge(Coordinates, Masses, r_array)
# Keep this assignment before the find_gas_mass call below in case that
# function looks the name up at module level.
Mdm_stars_encl = Mdm_encl + Mstar_encl

# find the gas mass
# NOTE(review): mgas_guess is not used later in the script; the R500 gas mass
# is recomputed further down. This call still writes Mean_density_guess.png.
mgas_guess = find_gas_mass(Mdm_encl, Mstar_encl, r_array)

# Calculate the circular velocity assuming gas has no contribution
# Despite the names, these are squared circular velocities G M(<r) / r
# (cm^2 s^-2 given the units of G).
circular_velocity = G * Mstar_encl / r_array
circular_vel_dm = vcirc_dmhalo(r_array, a_hernquist, Mdm)
circ_v_tot = circular_velocity + circular_vel_dm

# get the temperature profile: a term set by the (squared) circular velocity
# plus a central core of Tmin_model that is suppressed beyond 2.5 * Tscale_length
log_T_profile = np.log10(get_gas_T(0, circ_v_tot))
log_T_center_65 = get_gas_T_center(Tmin_model, Tscale_length, r_array, 2.5)
T4 = 10 ** log_T_profile + 10 ** log_T_center_65

log_radius_grid = np.log10(r_array)
temperature_label = "Tmin = $10^{6.0}$K"

plt.plot(log_radius_grid, np.log10(T4), label=temperature_label)
plt.xlabel("log radius [kpc]")
plt.ylabel("log temperature [K]")
plt.legend()
plt.savefig("Updated_temperature_profiles.png")
plt.close()

# ratio of v_c^2 to the adiabatic sound speed squared gamma kB T / (mu m_p)
plt.plot(
    log_radius_grid,
    circ_v_tot / (kB * gamma * (T4)) * mu * proton_mass,
    label=temperature_label,
)
plt.xlabel("log radius [kpc]")
plt.ylabel("circular velocity / sound speed")
plt.ylim(0, 1.01)
plt.legend()
plt.savefig("Ratio_circ_velocity_sound_speed.png")
plt.close()

# calculate the new pressure profile by integrating hydrostatic equilibrium
# in log space,
#     d log10(P) / d log10(r) = -circ_v_tot * mu * m_p / (kB * T),
# with circ_v_tot = G M(<r) / r evaluated at the logarithmic midpoint of each
# grid interval, starting from an arbitrary normalisation P0 (the density is
# rescaled to the target gas mass further down).
f_circ_v = sci.interp1d(r_array, circ_v_tot)
f_T = sci.interp1d(r_array, T4)

rho0 = 10
P0 = rho0 * kB * T4[0]

log_r_grid = np.log10(r_array)
r_mean = 10 ** ((log_r_grid[1:] + log_r_grid[:-1]) / 2.0)
delta_logr = log_r_grid[1:] - log_r_grid[:-1]
# one vectorised interpolation instead of 2 * (len(r_array) - 1) scalar calls
delta_log_pressure = (
    f_circ_v(r_mean) * mu * proton_mass / (kB * f_T(r_mean)) * delta_logr
)
# subtract.accumulate performs the same sequential update
#     log_P[i] = log_P[i - 1] - delta_log_pressure[i - 1]
log_pressure_array = np.subtract.accumulate(
    np.concatenate(([np.log10(P0)], delta_log_pressure))
)

plt.plot(np.log10(r_array), log_pressure_array)
plt.xlabel("log radius [kpc]")
plt.ylabel("log pressure [Barye]")
plt.savefig("Pressure_profile_guess.png")
plt.close()

# n = P / (kB T)
plt.plot(np.log10(r_array), log_pressure_array - np.log10(kB) - np.log10(T4))
plt.xlabel("log radius [kpc]")
plt.ylabel("log density [$\\rm cm^{-3}$]")

plt.savefig("density_profile_guess.png")
plt.close()


# calculate the gas fraction
# reference densities (rho_c presumably ~ critical density in Msun / kpc^3)
rho_c = 130
rho_500 = rho_c * 500  # solar mass / kpc^3
rho_200 = rho_c * 200
rho_avg = (Mdm_encl + Mstar_encl) / (4 * np.pi / 3 * r_array ** 3)
mean_density(r_array, rho_avg, rho_500, rho_200, rho_c)

# R500 guess: grid point whose mean enclosed (DM + stars) density is closest to 500 rho_c
idx_min_mass_500 = np.argmin(np.abs(rho_avg - rho_500))
r500_guess = r_array[idx_min_mass_500]
m500_guess = Mdm_stars_encl[idx_min_mass_500]

m500_gas_fraction = gas_fraction_rho500(m500_guess, alpha, mt)
# NOTE(review): this "1 sigma" term is a ratio (0.02 / fraction) added to a
# fraction; it has no effect while amount_of_sigma_used == 0 — confirm intent.
m500_gas_fraction_1std = 0.02 / m500_gas_fraction
hydrostatic_gas_fraction = (
    m500_gas_fraction + amount_of_sigma_used * m500_gas_fraction_1std
)
mgas_mass_guess = hydrostatic_gas_fraction * m500_guess

print(f"The hydrostatic gas fraction at R500: {hydrostatic_gas_fraction:1.4f}")

# calculate the encl mass using the assumed normalization
log_rho = log_pressure_array - np.log10(kB) - np.log10(T4)
rho = 10 ** log_rho

KPC_IN_CM = 3.0857e21
MSUN_IN_G = 1.9891e33

# shell masses in grams: n * mu * m_p * 4 pi r^2 dr, with r and dr in kpc
dr = r_array[1] - r_array[0]
mass_gas = rho * mu * proton_mass * 4 * np.pi * r_array ** 2 * dr * KPC_IN_CM ** 3
mass_gas_encl = np.cumsum(mass_gas) / MSUN_IN_G

# rescale the density profile so that the gas mass inside the R500 guess
# equals mgas_mass_guess
current_gas_mass_at_r500 = mass_gas_encl[idx_min_mass_500]
rho = rho / current_gas_mass_at_r500 * mgas_mass_guess
mass_gas = mass_gas / current_gas_mass_at_r500 * mgas_mass_guess
mass_gas_encl = np.cumsum(mass_gas) / MSUN_IN_G

# recompute R500 including the gas mass
rho_avg_new = (Mdm_encl + Mstar_encl + mass_gas_encl) / (4 * np.pi / 3 * r_array ** 3)
idx_min_mass_500_new_guess = np.argmin(np.abs(rho_avg_new - rho_500))
r500_new_guess = r_array[idx_min_mass_500_new_guess]

# check if we are not too much deviated from the actual virial mass
relative_shift = np.abs(r500_new_guess / r500_guess - 1.0)
if relative_shift > 0.05:
    print("Warning!! iteration might be better!")

# Normalisation of the specific angular-momentum profile, converted to cgs.
# 3.086e26 = 3.086e21 cm/kpc * 1e5 (cm/s)/(km/s), presumably for R200 in kpc
# and V200 in km/s.
_j_shape_denominator = -mu_angular * np.log(1.0 - 1.0 / mu_angular) - 1.0
j0_angular = (
    2 ** 0.5 * V200 * R200 * lambda_angular / _j_shape_denominator * 3.086e26
)

jmax_angular = j0_angular / (mu_angular - 1.0)

# calculate the relative particle mass at different radii (rtyp = 100 kpc).
# Stored under its own name so the particle_mass() function is not shadowed
# by its result.
particle_mass_factor = particle_mass(r_array, 100)


# calculate the amount of gas particles
# NOTE(review): assumes all star particles share the mass of the first one
Mres = Masses[0]
N_gas_particles_per_shell = mass_gas / 1.9891e33 / (Mres * particle_mass_factor)

Nparticles = int(np.cumsum(N_gas_particles_per_shell)[-1])

# shell mass in units of the base particle mass
mass_N = mass_gas / 1.9891e33 / Mres
# rejection-sampling envelope for the particle-number profile
max_N = np.max(N_gas_particles_per_shell)

f_mass = sci.interp1d(r_array, mass_N, kind="linear")
f_T = sci.interp1d(r_array, T4, kind="linear")
f_circ = sci.interp1d(r_array, circ_v_tot, kind="linear")
f_Mencl = sci.interp1d(r_array, Mdm_stars_encl + mass_gas_encl, kind="linear")

f_particle_mass = sci.interp1d(r_array, Mres * particle_mass_factor, kind="linear")
f_N_particles_per_shell = sci.interp1d(
    r_array, N_gas_particles_per_shell, kind="linear"
)

# specific angular momentum as a function of enclosed-mass ratio M(<r) / M200
M_ratio = f_Mencl(r_array) / M200
j_result = jmax_angular * M_ratio ** s_angular
j_result2 = jmax_angular * M_ratio ** 1.2
j_result3 = jmax_angular * M_ratio ** 0.7
j_result_update = j0_angular * M_ratio * (1.0 / (mu_angular - M_ratio))

# log10 rotation velocity j / r (r converted from kpc to cm); beyond R200,
# j is held at its last value inside R200, so v falls off as 1 / r
v_found = np.log10(j_result_update / (r_array * 3.086e21))
last_v_value = v_found[r_array < R200][-1]
outside_r200 = r_array > R200
v_found[outside_r200] = last_v_value + np.log10(R200 / r_array[outside_r200])

j_result_update2 = j0_angular * M_ratio * (1.0 / (1.05 - M_ratio))

# the "- 5" offsets convert log10(cm/s) to log10(km/s)
log_radius_kpc = np.log10(r_array)
rotation_curves = [
    (v_found - 5, "angular momentum used"),
    (0.5 * np.log10(circ_v_tot) - 5.0, "circular velocity"),
    (np.log10(0.05 * circ_v_tot ** 0.5) - 5.0, "Simple model"),
    (np.log10(np.sqrt(2) * V200 * R200 / r_array * 0.05), "Constant j"),
]
for curve_values, curve_label in rotation_curves:
    plt.plot(log_radius_kpc, curve_values, label=curve_label)
plt.ylim(0, 3)
plt.xlabel("log radius [kpc]")
plt.ylabel("log velocity [km/s]")
plt.legend()
plt.savefig("Rotation_velocities.png")
plt.close()

f_rotation_v = sci.interp1d(r_array, 10 ** v_found, kind="cubic")

# per-particle rotation speeds; never filled below, so the gas starts at rest
vcirc_sampled = np.zeros(Nparticles)

# isotropic random directions
phi_random = np.random.uniform(0, 2 * np.pi, size=Nparticles)
theta_random = np.arccos(np.random.uniform(-1, 1, size=Nparticles))

# try length
array_length = Nparticles * 10


def _accepted_radii(n_tries):
    """Rejection-sample radii from the particle-number-per-shell profile.

    Candidates are uniform in [r_array[0], r_array[-1]). They are accepted
    against f_N_particles_per_shell, whose maximum is the envelope max_N. The
    mass profile f_mass is not used here, because with a radially varying
    particle mass the particle count differs from the mass profile.
    """
    candidates = np.random.uniform(r_array[0], r_array[-1], n_tries)
    y_value = np.random.uniform(0, max_N, n_tries)
    return candidates[f_N_particles_per_shell(candidates) >= y_value]


r_values_using = _accepted_radii(array_length)[:Nparticles]
# with acceptance below 1/10 the first batch is too short; draw more so all
# per-particle arrays have length Nparticles
while len(r_values_using) < Nparticles:
    r_values_using = np.concatenate(
        (r_values_using, _accepted_radii(array_length))
    )[:Nparticles]

# get the particle mass + temperature
particle_mass = f_particle_mass(r_values_using)
T_particles = f_T(r_values_using)

# find the cylindrical radius and the rotation velocity
R_using = r_values_using * np.sin(theta_random)
v_rot = f_rotation_v(np.maximum(r_values_using, r_array[0]))

# get the internal energy per unit mass (cgs); use mu, not a repeated literal
internal_energy = kB * T_particles / (mu * proton_mass * (gamma - 1))

# get the position + velocity: spherical -> Cartesian about the origin
sin_theta = np.sin(theta_random)
x_gas = r_values_using * sin_theta * np.cos(phi_random)
y_gas = r_values_using * sin_theta * np.sin(phi_random)
z_gas = r_values_using * np.cos(theta_random)

# move the halo centre to the middle of the box
x_gas += boxsize[0] / 2.0
y_gas += boxsize[1] / 2.0
z_gas += boxsize[2] / 2.0


# NOTE(review): vcirc_sampled is all zeros (it is never filled), so these
# velocities are zero and the v_rot / R_using arrays are not used. Confirm
# whether the gas is meant to start at rest.
velocity_x = -vcirc_sampled * np.sin(phi_random)
velocity_y = vcirc_sampled * np.cos(phi_random)
velocity_z = 0 * vcirc_sampled

# one Cartesian component per column (x, y, z) for the neighbour search
coords = np.column_stack((x_gas, y_gas, z_gas))

coords = unyt_array(coords, "kpc")
boxsize = unyt_array(boxsize, "kpc")

number_of_gas_particles = len(x_gas)
print(f"Number of gas particles = {number_of_gas_particles:d}")
print("Calculate the smoothing lengths:")
smoothing_lengths = generate_smoothing_lengths(
    coordinates=coords, boxsize=boxsize, kernel_gamma=1.936492, neighbours=57
)
print("Finished calculating the smoothing lengths")

# store everything in the IC file:
N = len(x_gas)

grp = f.create_group("/PartType0")

coords = np.column_stack((x_gas, y_gas, z_gas))
ds = grp.create_dataset("Coordinates", (N, 3), "d")
ds[()] = coords

# velocities scaled by 1e-5 (presumably cm/s -> km/s)
v = np.column_stack((velocity_x, velocity_y, velocity_z)) * 1e-5
ds = grp.create_dataset("Velocities", (N, 3), "f")
ds[()] = v

# inverse of the 1e10 factor applied when reading the star masses
ds = grp.create_dataset("Masses", (N,), "f")
ds[()] = particle_mass / 1e10

ds = grp.create_dataset("SmoothingLength", (N,), "f")
ds[()] = smoothing_lengths.value

# cgs (cm^2 s^-2) -> (km/s)^2
ds = grp.create_dataset("InternalEnergy", (N,), "f")
ds[()] = internal_energy * 1e-10

# Gas IDs must start above the largest *shifted* star ID written to the file
# (ids_update = ids + 1). Starting at ids[-1] + 1 reused the last star's ID.
new_ids = ids_update.max() + np.arange(1, N + 1, 1)
ds = grp.create_dataset("ParticleIDs", (N,), "L")
ds[()] = new_ids

# Add the black hole: a single particle at rest at the box centre with ID 1

grp = f.create_group("/PartType5")
ds = grp.create_dataset("Coordinates", (1, 3), "d")
# strip any unyt units and give the (1, 3) shape of the dataset explicitly
ds[()] = np.asarray(boxsize, dtype=float).reshape(1, 3) / 2.0

v = np.zeros((1, 3))
ds = grp.create_dataset("Velocities", (1, 3), "f")
ds[()] = v

ds = grp.create_dataset("SmoothingLength", (1,), "f")
ds[()] = np.ones(1)

ds = grp.create_dataset("ParticleIDs", (1,), "L")
ds[()] = np.array([1])

# both masses come from mBH (previously Masses repeated the literal 10 ** 8.5)
ds = grp.create_dataset("Masses", (1,), "f")
ds[()] = np.full(1, mBH / 1e10)

ds = grp.create_dataset("SubgridMasses", (1,), "f")
ds[()] = np.full(1, mBH / 1e10)

ds = grp.create_dataset("EnergyReservoir", (1,), "f")
ds[()] = np.zeros(1)

initial_spin = 0.2
ds = grp.create_dataset("Spins", (1,), "f")
ds[()] = np.full(1, initial_spin)

# spin axis along +z
ds = grp.create_dataset("AngularMomentumDirections", (1, 3), "f")
ds[()] = np.array([[0.0, 0.0, 1.0]])


# Also update the header: rebuild it with the new gas (type 0) and black-hole
# (type 5) counts, keeping the original star (type 4) counts
header = f["Header"]
ThisFile = header.attrs["NumPart_ThisFile"]
Total = header.attrs["NumPart_Total"]
boxsize = header.attrs["BoxSize"]

# particle counts are integers; write them as such rather than float64
new_values = np.zeros(6, dtype=np.int64)
new_values[0] = N
new_values[4] = ThisFile[4]
new_values[5] = 1

new_values2 = np.zeros(6, dtype=np.int64)
new_values2[0] = N
new_values2[4] = Total[4]
new_values2[5] = 1

del f["Header"]
grp = f.create_group("/Header")
grp.attrs["BoxSize"] = boxsize
grp.attrs["NumPart_Total"] = new_values2
grp.attrs["NumPart_Total_HighWord"] = [0, 0, 0, 0, 0, 0]
grp.attrs["NumPart_ThisFile"] = new_values
grp.attrs["Time"] = 0.0
grp.attrs["NumFilesPerSnapshot"] = 1
grp.attrs["MassTable"] = [0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
grp.attrs["Flag_Entropy_ICs"] = [0, 0, 0, 0, 0, 0]
grp.attrs["Dimension"] = 3

f.close()
