import numpy as np
from swiftsimio import load
import unyt
from unyt import g, cm, s
import h5py
from swiftsimio.visualisation.slice import slice_gas

##########################################################
#    Read in the simulation output
##########################################################

# Unit symbols needed below; imported once here instead of on every loop
# iteration (g and cm are already imported at the top of the file).
from unyt import K, mh

# Per-snapshot time series, one entry appended per snapshot by the loop below.
t = []  # simulation time, converted to Myr
Pgas = []  # summed momentum received by the gas, divided by StarMass[0]
Pstar = []  # summed momentum received by the stars, divided by StarMass[0]
Mstar = []  # initial mass of the first star particle, in g
Zstar = []  # metal mass fraction of the first star particle
nHgas_star = []  # log10 of the number density (cm^-3) in the central map pixel
Tgas_star = []  # log10 of the temperature (K) in the central map pixel

nbins = 128

# Number of snapshots to read: output_0000.hdf5 ... output_0259.hdf5.
n_snapshots = 260

# Index of the central pixel of the (nbins x nbins) slice maps.
centre = nbins // 2

# Factor applied to the mass density before dividing by the hydrogen mass.
# NOTE(review): presumably the hydrogen mass fraction X_H -- confirm that it
# matches the value used by the simulation.
hydrogen_mass_fraction = 0.72


for i in range(n_snapshots):
    data = load("output_%4.4i.hdf5" % (i))

    GasMomenta = data.gas.stellar_momenta_received.astype("float64")
    GasMomenta.convert_to_units("cm*g/s")

    StarMomenta = data.stars.stellar_momenta_received.astype("float64")
    StarMomenta.convert_to_units("cm*g/s")

    Time = data.metadata.time.to("Myr")

    StarMass = data.stars.initial_masses.astype("float64")
    StarMass.convert_to_units("g")

    # get density and temperature close to the star position (center of the box)
    # The temperature map is mass weighted: slice(m * T) / slice(m).
    data.gas.mass_weighted_temps = data.gas.masses * data.gas.temperatures
    mass_map = slice_gas(
        data, slice=0.5, resolution=nbins, project="masses", parallel=True
    )
    mass_weighted_temp_map = slice_gas(
        data, slice=0.5, resolution=nbins, project="mass_weighted_temps", parallel=True
    )

    temp_map = mass_weighted_temp_map / mass_map
    temp_map.convert_to_units(K)
    Tgas_star.append(np.log10(temp_map[centre, centre].value))

    # Turn the mass-density slice into a number density in cm^-3.
    mass_map.convert_to_units(g / cm ** 3)
    mass_map = hydrogen_mass_fraction * mass_map / mh
    mass_map.convert_to_units(1.0 / cm ** 3)
    nHgas_star.append(np.log10(mass_map[centre, centre].value))

    # Momenta are normalised by the initial mass of the first star particle.
    t.append(Time)
    Pgas.append(np.sum(GasMomenta) / StarMass[0])
    Pstar.append(np.sum(StarMomenta) / StarMass[0])
    Mstar.append(StarMass[0])
    Zstar.append(data.stars.metal_mass_fractions[0])

##########################################################
#    Read in the tables
##########################################################

# hardcoded in the simulation for testing (see README file for testing)

from scipy.interpolate import RegularGridInterpolator

# Pull every table and bin array into memory while the file is open; the
# trailing [()] reads each dataset in full.
with h5py.File("Early_stellar_feedback.hdf5", "r") as table_file:
    header = table_file["Header"]
    DensityBins = header["DensityBins"][()]
    MetallicityBins = header["MetallicityBins"][()]
    StellarAgeBins = header["StellarAgeBins"][()]
    TemperatureBins = header["TemperatureBins"][()]
    Zsol = header["Constants/SolarMetallicity"][()]

    # axes: [temperature, metallicity, density, stellar age]
    RadPres = table_file["RadiationPressure/Absorption/logPcumulative_Total"][()]
    # axes: [metallicity, stellar age]
    StWind = table_file["StellarWinds/CumulativeMomentum"][()]

# log10 of the stellar metallicity relative to the tabulated solar value,
# one entry per snapshot.
logZZsol = np.log10(Zstar / Zsol)

# Interpolators over the tables; the axis order of each grid follows the
# axis order of the corresponding table.
stellar_wind_grid = (MetallicityBins, StellarAgeBins)
func_st_wind = RegularGridInterpolator(stellar_wind_grid, StWind)

radiation_pressure_grid = (
    TemperatureBins,
    MetallicityBins,
    DensityBins,
    StellarAgeBins,
)
func_rad_pres = RegularGridInterpolator(radiation_pressure_grid, RadPres)

# Table-based cumulative momenta (log10), one slot per snapshot. -50 acts as
# the floor, i.e. "no momentum injected yet".
n_times = len(t)
log10Pcum_radpres = np.full(n_times, -50.0)
log10Pcum_stwind = np.full(n_times, -50.0)

def _radpres_cumulative(step, age):
    """Return the tabulated cumulative radiation-pressure momentum at `age`.

    The table is looked up at the gas temperature, metallicity and density
    recorded for snapshot `step`; the stored log10 value is converted to a
    linear momentum.
    """
    point = [Tgas_star[step], logZZsol[step], nHgas_star[step], age]
    return np.power(10.0, func_rad_pres(point)[0])


# Integrate the radiation-pressure momentum snapshot by snapshot: each step
# adds the tabulated momentum gained between the previous and current time,
# evaluated at the gas state of the current snapshot.
for i, (previous_time, current_time) in enumerate(zip(t[:-1], t[1:]), start=1):
    age_prev = float(previous_time.value)
    age_now = float(current_time.value)

    if age_prev >= StellarAgeBins[0]:
        # Both ends of the interval are covered by the table.
        momentum_prev = _radpres_cumulative(i, age_prev)
        momentum_now = _radpres_cumulative(i, age_now)
        deltaP = momentum_now - momentum_prev
    elif age_now >= StellarAgeBins[0]:
        # The interval starts before the first tabulated age, so everything
        # accumulated up to age_now counts as new.
        deltaP = _radpres_cumulative(i, age_now)
    else:
        # Entire interval lies before the first tabulated age.
        deltaP = 0.0

    print("%.4f Myr: delta P = %.4e" % (current_time, deltaP))

    # Add in linear space, then store log10 with a 1e-50 floor.
    running_total = np.power(10.0, log10Pcum_radpres[i - 1]) + deltaP
    log10Pcum_radpres[i] = np.log10(max(running_total, 1.0e-50))


# The stellar-wind table already holds cumulative momenta, so a direct lookup
# at each snapshot time suffices. Index 0, and any time before the first
# tabulated stellar age, keeps its initial value.
for i, snapshot_time in enumerate(t[1:], start=1):
    t_loc = float(snapshot_time.value)
    if t_loc < StellarAgeBins[0]:
        continue
    log10Pcum_stwind[i] = func_st_wind([logZZsol[i], t_loc])[0]


# Final values: simulated gas momentum, stellar-wind table, radiation-pressure
# table (all log10).
final_log10_momenta = (
    np.log10(Pgas[-1]),
    log10Pcum_stwind[-1],
    log10Pcum_radpres[-1],
)
print("%.4f\t%.4f\t%.4f" % final_log10_momenta)


import matplotlib
import matplotlib.pyplot as plt
from matplotlib import gridspec


# Single-panel figure comparing the simulated cumulative momentum with the
# table-based predictions.
fig = plt.figure()
gs = gridspec.GridSpec(1, 1)
ax = fig.add_subplot(gs[0])

ax.set(
    ylim=(5.5, 7.75),
    title="Based on approximate gas density and temperature",
    xlabel="t [Myr]",
    ylabel="log P(<t) [g cm s$^{-1}$ per g star mass]",
)

# The table total is the sum of both channels, added in linear space.
log10Pcum_total = np.log10(
    np.power(10.0, log10Pcum_stwind) + np.power(10.0, log10Pcum_radpres)
)

# (y values, line style) in drawing order; the order also fixes the legend.
curves = [
    (np.log10(Pgas), dict(label="Simulation", linewidth=5, color="lightgrey")),
    (log10Pcum_total, dict(label="Total (table)", color="black")),
    (log10Pcum_stwind, dict(label="Stellar Wind (table)", color="#CC0066")),
    (log10Pcum_radpres, dict(label="Radiation pressure (table)", color="#00CCCC")),
]
for y_values, line_style in curves:
    ax.plot(t, y_values, **line_style)

# Vertical reference line at x = 250 on the time axis.
ax.axvline(250.0)

ax.legend()
fig.savefig("momenta.png", dpi=150)
