import numpy as np
import matplotlib.pylab as plt
from matplotlib.ticker import AutoMinorLocator
import h5py as h5
import sys
import glob

# Physical constants used by the unit conversions below, all in cgs units.
constants = dict(
    NEWTON_GRAVITY_CGS=6.67408e-8,  # G [cm^3 g^-1 s^-2]
    SOLAR_MASS_IN_CGS=1.98848e33,  # solar mass [g]
    PARSEC_IN_CGS=3.08567758e18,  # parsec [cm]
    PROTON_MASS_IN_CGS=1.672621898e-24,  # proton mass [g]
    BOLTZMANN_IN_CGS=1.38064852e-16,  # Boltzmann constant [erg K^-1]
    YEAR_IN_CGS=3.15569252e7,  # one year [s]
)

gas_energy_from_AGN = []

BH_energy_reservoirs = []
BH_N_Of_AGN_events = []
BH_N_Of_heating_events = []
BH_energy = []
BH_AccretionLimitedTimestep = []
BH_Timestep = []

snp_times = []


def _log_tick_label(position):
    """Return the LaTeX tick label for one position on a log-scaled axis.

    Powers of ten become ``$10^{k}$``, except 0.1, 1 and 10, which are spelled
    out. Any other value (including non-positive ones) falls back to
    ``$<value>$`` in ``{:g}`` format instead of being mislabelled.
    """
    if position > 0:
        # Round rather than truncate: a floating-point tick just above 10^-k
        # has log10 slightly above -k, which int() would truncate toward zero.
        exponent = int(np.round(np.log10(position)))
        if abs(position / 10.0**exponent - 1.0) < 1e-6:
            spelled_out = {-1: "$0.1$", 0: "$1$", 1: "$10$"}
            return spelled_out.get(exponent, "$10^{" + str(exponent) + "}$")
    return "${:g}$".format(position)


def fixlogax(ax, a="x"):
    """Relabel the major ticks of a log-scaled axis as LaTeX powers of ten.

    Labels are built from the current tick positions (``get_xticks`` /
    ``get_yticks``) and applied with ``set_xticklabels`` / ``set_yticklabels``.
    They are fixed strings, so call this once the axis scale and limits are
    final.

    :param ax: matplotlib Axes whose axis should be relabelled
    :param a: which axis to relabel, ``"x"`` or ``"y"``
    :raises ValueError: if ``a`` is neither ``"x"`` nor ``"y"``
    """
    if a == "x":
        positions, set_labels = ax.get_xticks(), ax.set_xticklabels
    elif a == "y":
        positions, set_labels = ax.get_yticks(), ax.set_yticklabels
    else:
        raise ValueError("a must be 'x' or 'y', got {!r}".format(a))
    set_labels([_log_tick_label(position) for position in positions])


def load_data(number_of_snp: int = 200) -> None:
    """
    Read per-snapshot AGN-feedback diagnostics into the module-level lists.

    Snapshots ``output_0000.hdf5`` up to ``output_{number_of_snp - 1:04d}.hdf5``
    are opened in order, and one entry per snapshot is appended to
    ``snp_times``, ``gas_energy_from_AGN`` and each ``BH_*`` list. Times and
    time-steps are stored in Myr and energies in units of 1e51 erg. Repeated
    calls keep appending to the same lists.

    :param number_of_snp: Total number of snapshots to load starting from zero
    :return: None
    :raises FileNotFoundError: if no ``timesteps_*.txt`` file is found
    """

    # Connect time-bins to time-steps: the first data row of the time-step log
    # pairs one time bin with its time-step [internal units]. Any other bin's
    # time-step then follows by a power of two (see BH_Timestep below).
    timestep_logs = sorted(glob.glob("timesteps_*.txt"))
    if not timestep_logs:
        raise FileNotFoundError(
            "No 'timesteps_*.txt' file found in the current directory"
        )
    # ndmin=2 keeps [row, column] indexing valid if only one data row remains
    timesteps = np.loadtxt(timestep_logs[0], skiprows=20, ndmin=2)
    time_bin = timesteps[0, 5]
    time_step = timesteps[0, 4]
    print(
        "Time bin {:.0f} corresponds to time-step {:.3e} [internal units]".format(
            time_bin, time_step
        )
    )

    myr_in_cgs = constants["YEAR_IN_CGS"] * 1e6

    for snapshot in range(number_of_snp):

        with h5.File("output_{:04d}.hdf5".format(snapshot), "r") as f:

            # Units: conversion factors from internal units to cgs
            unit_length_in_cgs = f["/Units"].attrs["Unit length in cgs (U_L)"]
            unit_mass_in_cgs = f["/Units"].attrs["Unit mass in cgs (U_M)"]
            unit_time_in_cgs = f["/Units"].attrs["Unit time in cgs (U_t)"]
            unit_energy_in_cgs = (
                unit_mass_in_cgs * (unit_length_in_cgs / unit_time_in_cgs) ** 2
            )
            # Internal units -> Myr and internal units -> 1e51 erg
            to_myr = unit_time_in_cgs / myr_in_cgs
            to_1e51_erg = unit_energy_in_cgs / 1e51

            snp_times.append(f["/Header"].attrs["Time"][0] * to_myr)

            # Energy received from AGN feedback, summed over all gas particles
            gas_energy_from_AGN.append(
                np.sum(
                    f["/PartType0/EnergiesReceivedFromAGNFeedback"][:] * to_1e51_erg
                )
            )

            BH_energy_reservoirs.append(
                f["/PartType5/EnergyReservoirs"][:] * to_1e51_erg
            )
            BH_N_Of_AGN_events.append(f["/PartType5/NumberOfAGNEvents"][:])
            BH_N_Of_heating_events.append(f["/PartType5/NumberOfHeatingEvents"][:])
            BH_energy.append(
                f["/PartType5/AGNTotalInjectedEnergies"][:] * to_1e51_erg
            )

            # Accretion-limited time-step of each black hole, converted to Myr
            BH_AccretionLimitedTimestep.append(
                f["/PartType5/AccretionLimitedTimeSteps"][:] * to_myr
            )

            # Actual time-step from the time bin: each bin above (below) the
            # reference bin doubles (halves) the reference time-step.
            # NOTE(review): only the first black hole ([0]) is used here.
            BH_Timestep.append(
                time_step
                * np.power(2, f["/PartType5/TimeBins"][0] - time_bin)
                * to_myr
            )


def _decode_attr(value):
    """Return an HDF5 string attribute as ``str``.

    Bytes-like values are decoded as UTF-8, dropping undecodable bytes. Values
    h5py already returns as ``str`` are passed through instead of raising.
    """
    try:
        return str(value, encoding="utf-8", errors="ignore")
    except TypeError:
        # str(x, encoding=...) rejects non-bytes input such as an existing str
        return str(value)


def _style_axis(axis, xlabel, ylabel):
    """Apply the tick styling and axis labels shared by every panel."""
    axis.xaxis.set_tick_params(labelsize=17)
    axis.yaxis.set_tick_params(labelsize=17)
    axis.xaxis.set_minor_locator(AutoMinorLocator(5))
    axis.yaxis.set_minor_locator(AutoMinorLocator(5))
    axis.tick_params(which="both", width=1.7)
    axis.tick_params(which="major", length=9)
    axis.tick_params(which="minor", length=5)
    axis.set_xlabel(xlabel, fontsize=18)
    axis.set_ylabel(ylabel, fontsize=18)


def plot():
    """
    Draw the four-panel AGN-feedback summary from the data gathered by
    load_data() and save it as ``AGN_feedback_statistics.pdf``.

    The AGN model parameters and the mean gas-particle mass are read from
    ``output_0000.hdf5``. The mean mass sets the heating energy
    E_heat = k_B * dT / (gamma - 1) * M_gas / (mu * m_p).

    Panels:
      [0, 0] energy received by the gas vs. energy ejected by the BH;
      [0, 1] BH energy reservoir, with E_heat * N_heat marked;
      [1, 0] accretion-limited vs. actual BH time-step (log scale);
      [1, 1] numbers of heating and AGN events and their ratio (log scale).
    """

    with h5.File("output_0000.hdf5", "r") as f:
        parameters = f["/Parameters"].attrs
        feedback_scheme = _decode_attr(parameters["COLIBREAGN:AGN_feedback_model"])
        deterministic = _decode_attr(
            parameters["COLIBREAGN:AGN_use_deterministic_feedback"]
        )
        N_heat = float(parameters["COLIBREAGN:AGN_num_ngb_to_heat"])
        dT = float(parameters["COLIBREAGN:AGN_delta_T_K"])
        eta = float(parameters["COLIBREAGN:coupling_efficiency"])

        # Compute heating energy: thermal energy needed to raise one
        # mean-mass gas particle by dT
        unit_mass_in_cgs = f["/Units"].attrs["Unit mass in cgs (U_M)"]
        M_gas = np.mean(f["/PartType0/Masses"][:]) * unit_mass_in_cgs[0]
        gamma = 5.0 / 3.0
        mu = 0.6  # fully ionised gas
        E_heat = (
            constants["BOLTZMANN_IN_CGS"]
            * dT
            / (gamma - 1.0)
            * M_gas
            / mu
            / constants["PROTON_MASS_IN_CGS"]
        )

        print("Feedback model: ", feedback_scheme)
        print("Deterministic?: ", deterministic)
        print("N_heat: {:.1f}".format(N_heat))
        print("T_heat: {:.3e} K".format(dT))
        print("E_heat: {:.3e} erg".format(E_heat))

    # Plotting
    plt.rcParams["ytick.direction"] = "in"
    plt.rcParams["xtick.direction"] = "in"
    plt.rc("text", usetex=True)
    plt.rcParams["axes.linewidth"] = 2
    plt.rc("font", family="serif")
    plt.rcParams["figure.figsize"] = 10, 10
    fig, ax = plt.subplots(2, 2)

    plt.suptitle(
        "AGN feedback model: "
        + feedback_scheme
        + "; Deterministic: "
        + deterministic
        + "; $N_{\\rm heat} =$ "
        + "{:.1f}".format(N_heat)
        + ";\n $T_{\\rm heat} =$ "
        + "{:.3e} K".format(dT)
        + "; coupling efficiency = "
        + "{:.1f}".format(eta),
        fontsize=20,
    )

    # plot 1: energy budget
    ax[0, 0].plot(
        snp_times, gas_energy_from_AGN, lw=3, color="skyblue", label="Received by gas"
    )
    ax[0, 0].plot(
        snp_times, BH_energy, lw=2, color="k", dashes=(5, 5), label="Ejected by the BH"
    )
    _style_axis(ax[0, 0], "Time [Myr]", "Energy [$10^{51}$ erg]")
    ax[0, 0].legend(loc="upper left", fontsize=15, frameon=False)

    # plot 2: energy reservoir vs. the energy needed for one heating event
    ax[0, 1].plot(snp_times, BH_energy_reservoirs, lw=2, color="skyblue")
    ax[0, 1].axhline(
        E_heat * N_heat / 1e51,
        color="k",
        dashes=(3, 3),
        lw=1.5,
        label="$E_{\\rm heat} \\times N_{\\rm heat}$",
    )
    _style_axis(ax[0, 1], "Time [Myr]", "Energy reservoir [$10^{51}$ erg]")
    ax[0, 1].legend(loc="lower right", fontsize=15)

    # plot 3: time-steps (log scale is set after styling, as it replaces the
    # minor locator installed by _style_axis)
    ax[1, 0].plot(
        snp_times,
        BH_AccretionLimitedTimestep,
        lw=3,
        color="skyblue",
        label="Limiter $\\propto$ accretion rate",
    )
    ax[1, 0].plot(
        snp_times,
        BH_Timestep,
        lw=1.5,
        color="k",
        dashes=(3, 3),
        label="BH actual time-step",
    )
    _style_axis(ax[1, 0], "Time [Myr]", "Time-step [Myr]")
    ax[1, 0].legend(loc="upper right", fontsize=13)
    ax[1, 0].set_yscale("log")
    fixlogax(ax[1, 0], "y")

    # plot 4: event counts
    ax[1, 1].axhline(y=N_heat, color="grey", alpha=0.25, lw=1.5)
    ax[1, 1].plot(
        snp_times, BH_N_Of_heating_events, lw=3, color="skyblue", label="Heating events"
    )
    ax[1, 1].plot(
        snp_times,
        BH_N_Of_AGN_events,
        lw=2,
        color="k",
        dashes=(5, 5),
        label="AGN events",
    )

    # Heating-to-AGN event ratio, undefined before the first AGN event: those
    # entries stay NaN, which matplotlib leaves out of the line. The
    # element-wise division keeps the event arrays' shape (one row per
    # snapshot, presumably one column per black hole), so it works for any
    # number of black holes.
    n_heating = np.asarray(BH_N_Of_heating_events, dtype=float)
    n_agn = np.asarray(BH_N_Of_AGN_events, dtype=float)
    heating_per_agn = np.full(n_agn.shape, np.nan)
    np.divide(n_heating, n_agn, out=heating_per_agn, where=n_agn > 0)
    ax[1, 1].plot(
        snp_times,
        heating_per_agn,
        lw=3,
        color="orange",
        dashes=(2, 2),
        label="Heating / AGN",
    )

    _style_axis(ax[1, 1], "Time $t$ [Myr]", "Number of events ($>t$)")
    ax[1, 1].legend(loc="lower right", fontsize=14, frameon=False)
    ax[1, 1].set_yscale("log")
    fixlogax(ax[1, 1], "y")

    plt.savefig("AGN_feedback_statistics.pdf", bbox_inches="tight", pad_inches=0.1)
    plt.close()


if __name__ == "__main__":

    # Expect exactly one command-line argument: the number of snapshots to read
    if len(sys.argv) < 2:
        sys.exit("Usage: python {} <number_of_snapshots>".format(sys.argv[0]))
    N_snapshot = int(sys.argv[1])
    print("Number of snapshots to read {}".format(N_snapshot))
    load_data(N_snapshot)
    plot()
