import numpy as np

import matplotlib.pylab as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable
from matplotlib import cm as cmm

import h5py as h5
import sphviewer as sph

import sys

# Physical constants, all expressed in CGS units
constants = dict(
    NEWTON_GRAVITY_CGS=6.67408e-8,  # cm^3 g^-1 s^-2
    SOLAR_MASS_IN_CGS=1.98848e33,  # g
    PARSEC_IN_CGS=3.08567758e18,  # cm
    PROTON_MASS_IN_CGS=1.672621898e-24,  # g
    BOLTZMANN_IN_CGS=1.38064852e-16,  # erg K^-1
    YEAR_IN_CGS=3.15569252e7,  # s
)


def _label_box(ax, x, y, text, fontsize, **kwargs):
    """Draw ``text`` at ``(x, y)`` on ``ax`` inside a rounded orange box.

    Extra keyword arguments (e.g. ``ha``, ``va``, ``transform``) are forwarded
    to ``ax.text``.
    """
    ax.text(
        x,
        y,
        text,
        color="black",
        fontsize=fontsize,
        bbox=dict(facecolor="orange", edgecolor="black", boxstyle="round", alpha=0.8),
        **kwargs
    )


def _mark_bh(ax, bh_x, bh_y, radius):
    """Mark the BH with a black dot and an unfilled black circle of ``radius``."""
    ax.scatter(bh_x, bh_y, s=80, color="k")
    ax.add_artist(
        plt.Circle((bh_x, bh_y), radius, facecolor="none", lw=1.5, edgecolor="black")
    )


def _add_colorbar(ax, image, side, ticks, ticklabels, label, labelpad):
    """Attach a horizontal colour bar for ``image`` on the ``side`` of ``ax``.

    ``side`` ("top" or "bottom") sets both where the bar is placed and where its
    ticks are drawn. The tick marks of ``ax`` itself are removed.
    """
    divider = make_axes_locatable(ax)
    cax = divider.append_axes(side, size="5%", pad=0.05)
    ax.set_xticks([])
    ax.set_yticks([])
    plt.colorbar(
        image, cax=cax, ticks=ticks, orientation="horizontal", use_gridspec=True
    )
    cax.xaxis.set_ticks_position(side)
    cax.xaxis.set_ticklabels(ticklabels, fontsize=18)
    cax.set_xlabel(label, fontsize=22, labelpad=labelpad)


def _sph_render(camera, positions, weights, smoothing_lengths):
    """Render SPH-smoothed ``weights`` through ``camera`` with sphviewer.

    Returns the ``(Render, Scene)`` pair so the caller can query the image,
    its extent and the camera parameters.
    """
    scene = sph.Scene(sph.Particles(positions, weights, smoothing_lengths), camera)
    return sph.Render(scene), scene


def plot(
    snp: int = 0,
    z_aver_width: float = 0.3,
    gamma: float = 1.936492,
    X_H: float = 0.73738788833,
    feedback_delay: float = 1.0,
) -> None:
    """
    Plot a 2x2 AGN-feedback diagnostic figure centred on the first BH.

    Reads ``./output_{snp:04d}.hdf5`` and writes ``feedback_{snp:04d}.png``.
    Panels: hydrogen number density, temperature and P/k_B maps of a gas
    slice through the BH, plus a zoom-in scatter of the BH's gas neighbours.

    :param snp: Snapshot number in the output
    :param z_aver_width: the width of the slice, in the snapshot's internal
        length units (documented as kpc)
    :param gamma: the ratio between the kernel size and kernel smoothing length
    :param X_H: the hydrogen mass fraction
    :param feedback_delay: Highlight particles in the plot that were heated by
        the BH within the last feedback_delay Myr
    :return: None
    """

    # Read everything up front so the file is not held open while rendering.
    with h5.File("./output_{:04d}.hdf5".format(snp), "r") as f:

        # Box & simulation properties
        boxsize = f["/Header"].attrs["BoxSize"]
        unit_length_in_cgs = f["/Units"].attrs["Unit length in cgs (U_L)"]
        unit_mass_in_cgs = f["/Units"].attrs["Unit mass in cgs (U_M)"]
        unit_time_in_cgs = f["/Units"].attrs["Unit time in cgs (U_t)"]
        time_internal = f["/Header"].attrs["Time"]
        scheme_raw = f["/Parameters"].attrs["COLIBREAGN:AGN_feedback_model"]

        # Gas properties
        gas_pos = f["/PartType0/Coordinates"][:, :]
        gas_T = f["/PartType0/Temperatures"][:]
        gas_pressure_internal = f["/PartType0/Pressures"][:]
        gas_mass = f["/PartType0/Masses"][:]
        gas_hsml = f["/PartType0/SmoothingLengths"][:]
        last_agn_internal = f["/PartType0/LastAGNFeedbackTimes"][:]

        # BH properties
        BH_pos = f["/PartType5/Coordinates"][:, :]
        BH_hsml = f["/PartType5/SmoothingLengths"][:]
        BH_energy_internal = f["/PartType5/EnergyReservoirs"][:]
        N_AGN = f["/PartType5/NumberOfAGNEvents"][0]
        N_heating = f["/PartType5/NumberOfHeatingEvents"][0]

    # h5py may hand back the string attribute as bytes or as str.
    if isinstance(scheme_raw, bytes):
        feedback_scheme = scheme_raw.decode("utf-8", errors="ignore")
    else:
        feedback_scheme = str(scheme_raw)

    # Unit conversions
    unit_energy_in_cgs = (
        unit_mass_in_cgs * (unit_length_in_cgs / unit_time_in_cgs) ** 2
    )
    time = time_internal * unit_time_in_cgs / constants["YEAR_IN_CGS"] / 1e6  # Myr
    gas_pressure = (
        gas_pressure_internal
        * unit_mass_in_cgs
        / unit_time_in_cgs ** 2
        / unit_length_in_cgs
    )
    # A raw value of -1 is the flag this script treats as "never heated by the
    # AGN". Detect it BEFORE converting to Myr: the conversion rescales the flag
    # too, so a post-conversion `== -1.0` test only works if the internal time
    # unit happens to be exactly 1 Myr.
    never_heated = last_agn_internal == -1.0
    gas_LastAGNtimes = (
        last_agn_internal * unit_time_in_cgs / constants["YEAR_IN_CGS"] / 1e6
    )
    BH_kernel = BH_hsml * gamma
    BH_energy = BH_energy_internal * unit_energy_in_cgs / 1e51  # 10^51 erg

    # Shift everything with respect to the box centre
    centre = boxsize / 2.0
    gas_pos -= centre
    BH_pos -= centre

    # Only the first BH is analysed
    bh_x, bh_y, bh_z = BH_pos[0]
    bh_radius = BH_kernel[0]

    # BH's gas neighbours; computed on a temporary so gas_pos is not perturbed
    offsets = gas_pos - BH_pos[0]
    separations = np.sqrt(np.sum(offsets * offsets, axis=1))
    ngb_idx = np.flatnonzero(separations < bh_radius)

    # Particles heated within the last feedback_delay Myr (anywhere in the box)
    recently_heated = (
        np.abs(gas_LastAGNtimes - time) <= feedback_delay
    ) & ~never_heated
    hot_idx = np.flatnonzero(recently_heated)
    cold_idx = ngb_idx[~recently_heated[ngb_idx]]

    # Slice of thickness z_aver_width centred on the BH
    in_slice = np.abs(gas_pos[:, 2] - bh_z) < z_aver_width / 2.0

    # Standard way to look at the box
    camera = sph.Camera(
        r="infinity",
        t=0,
        p=0,
        roll=0,
        xsize=500,
        ysize=500,
        x=0.0,
        y=0.0,
        z=0.0,
        extent=[
            -boxsize[0] / 2.0,
            boxsize[0] / 2.0,
            -boxsize[1] / 2.0,
            boxsize[1] / 2.0,
        ],
    )

    # Scene 1: projected mass
    render, _ = _sph_render(
        camera, gas_pos[in_slice], gas_mass[in_slice], gas_hsml[in_slice] * gamma
    )
    extent = render.get_extent()
    density = render.get_image()
    # Converts projected mass / slice width into hydrogen number density [cm^-3]
    const = (
        X_H
        / z_aver_width
        * unit_mass_in_cgs
        / unit_length_in_cgs ** 3
        / constants["PROTON_MASS_IN_CGS"]
    )

    # Scene 2: mass-weighted temperature
    render, _ = _sph_render(
        camera,
        gas_pos[in_slice],
        gas_mass[in_slice] * gas_T[in_slice],
        gas_hsml[in_slice] * gamma,
    )
    temperature = render.get_image() / density

    # Scene 3: mass-weighted pressure, expressed as P / k_B
    render, scene = _sph_render(
        camera,
        gas_pos[in_slice],
        gas_mass[in_slice] * gas_pressure[in_slice],
        gas_hsml[in_slice] * gamma,
    )
    pressure = render.get_image() / (density * constants["BOLTZMANN_IN_CGS"])

    camera_params = scene.Camera.get_params()
    image_extent = [
        extent[0] + camera_params["x"],
        extent[1] + camera_params["x"],
        extent[2] + camera_params["y"],
        extent[3] + camera_params["y"],
    ]

    # Anchor positions (data coordinates) shared by the text boxes
    text_left = -0.5 * boxsize[0] * 0.91
    text_top = 0.5 * boxsize[0] * 0.80
    text_bottom = -0.5 * boxsize[0] * 0.90

    # Plotting
    plt.rcParams["ytick.direction"] = "in"
    plt.rcParams["xtick.direction"] = "in"
    plt.rc("text", usetex=True)
    plt.rcParams["axes.linewidth"] = 2
    plt.rc("font", family="serif")
    plt.rcParams["figure.figsize"] = 10, 10
    _, ax = plt.subplots(2, 2)

    # Image 1: hydrogen number density
    im = ax[0, 0].imshow(
        np.log10(density * const),
        extent=image_extent,
        origin="lower",
        cmap=cmm.coolwarm,
        vmin=-2,
        vmax=1,
    )
    _add_colorbar(
        ax[0, 0],
        im,
        "top",
        [-2, -1, 0, 1],
        ["$-2.0$", "$-1.0$", "$0.0$", "$1.0$"],
        "$\\mathrm{log}\\, n_{\\rm H}$ [cm$^{-3}$]",
        -55,
    )
    _mark_bh(ax[0, 0], bh_x, bh_y, bh_radius)
    _label_box(
        ax[0, 0],
        text_left,
        text_top,
        "$\\rm E_{reservoir}$\\,"
        + "$ = {:.2f} \\times $".format(BH_energy[0])
        + "$10^{51} \\rm \\, erg$",
        18,
    )
    _label_box(ax[0, 0], text_left, text_bottom, "time = {:.2f} Myr".format(time[0]), 20)

    # Image 2: temperature
    im = ax[0, 1].imshow(
        np.log10(temperature),
        extent=image_extent,
        origin="lower",
        cmap=cmm.coolwarm,
        vmin=4,
        vmax=7,
    )
    _add_colorbar(
        ax[0, 1],
        im,
        "top",
        [4, 5, 6, 7],
        ["$4.0$", "$5.0$", "$6.0$", "$7.0$"],
        "$\\mathrm{log}\\, T_{\\rm K}$ [K]",
        -55,
    )
    _mark_bh(ax[0, 1], bh_x, bh_y, bh_radius)
    _label_box(
        ax[0, 1], text_left, text_top, "N of AGN events: {:d}".format(N_AGN), 20
    )
    _label_box(
        ax[0, 1],
        text_left,
        text_bottom,
        "N of heating events: {:d}".format(N_heating),
        20,
    )

    # Image 3: pressure
    im = ax[1, 0].imshow(
        np.log10(pressure),
        extent=image_extent,
        origin="lower",
        cmap=cmm.inferno,
        vmin=3,
        vmax=6,
    )
    _add_colorbar(
        ax[1, 0],
        im,
        "bottom",
        [3, 4, 5, 6],
        ["$3$", "$4$", "$5$", "$6$"],
        "$\\mathrm{log}\\, P/ \\mathrm{k_B}$ [K cm$^{-3}$]",
        10,
    )

    # How big are 200 pc in the plot? Convert using the snapshot's length unit
    # rather than assuming kpc; np.squeeze handles length-1 HDF5 attributes.
    scale_bar = float(
        np.squeeze(200.0 * constants["PARSEC_IN_CGS"] / unit_length_in_cgs)
    )
    bar_y = -0.5 * boxsize[0] * 0.95
    ax[1, 0].plot(
        [text_left, text_left + scale_bar], [bar_y, bar_y], lw=3, color="white"
    )
    ax[1, 0].text(text_left, text_bottom, "$200$ pc", fontsize=20, color="white")

    _mark_bh(ax[1, 0], bh_x, bh_y, bh_radius)
    _label_box(ax[1, 0], text_left, text_top, "AGN model: " + feedback_scheme, 18)

    # Image 4: zoom-in on the BH neighbours
    zoom = ax[1, 1]
    zoom.set_xlim(bh_x - bh_radius * 1.5, bh_x + bh_radius * 1.5)
    zoom.set_ylim(bh_y - bh_radius * 1.5, bh_y + bh_radius * 1.5)
    _label_box(
        zoom,
        0.95,
        0.95,
        "Zoom-in",
        20,
        ha="right",
        va="top",
        transform=zoom.transAxes,
    )

    # Neighbours not heated recently
    im = zoom.scatter(
        gas_pos[cold_idx, 0],
        gas_pos[cold_idx, 1],
        c=np.log10(gas_T[cold_idx]),
        s=20,
        cmap=cmm.viridis,
        vmin=4,
        vmax=6,
    )

    # Recently heated particles, each joined to the BH by a dashed "ray"
    zoom.scatter(
        gas_pos[hot_idx, 0],
        gas_pos[hot_idx, 1],
        c=np.log10(gas_T[hot_idx]),
        s=100,
        cmap=cmm.viridis,
        vmin=4,
        vmax=6,
        edgecolor="k",
        lw=1,
        zorder=5,
    )
    for hot_x, hot_y in gas_pos[hot_idx, :2]:
        zoom.plot(
            np.array([bh_x, hot_x]),
            np.array([bh_y, hot_y]),
            color="black",
            dashes=(3, 3),
            zorder=2,
        )

    # Legend: an off-screen dummy marker carries the label
    zoom.scatter(
        1e6,
        1e6,
        s=100,
        color="yellow",
        edgecolor="k",
        lw=1,
        zorder=5,
        label="heated within the last {:.1f} Myr".format(feedback_delay),
    )
    zoom.legend(loc="lower left", fontsize=16, frameon=False, scatterpoints=2)

    _add_colorbar(
        zoom,
        im,
        "bottom",
        [4, 5, 6],
        ["$4$", "$5$", "$6$"],
        "log $T_{\\rm BH \\, Ngbs}$",
        10,
    )

    # Plot BH and its (filled) kernel
    zoom.scatter(bh_x, bh_y, s=150, color="k", zorder=3)
    zoom.add_artist(plt.Circle((bh_x, bh_y), bh_radius, color="blue", alpha=0.2))

    # This subplot must be of the same size as the other three
    x_lo, x_hi = ax[1, 0].get_xlim()
    y_lo, y_hi = ax[1, 0].get_ylim()
    zoom.set_aspect((x_hi - x_lo) / (y_hi - y_lo))

    plt.tight_layout()

    plt.savefig("feedback_{:04d}.png".format(snp), bbox_inches="tight", pad_inches=0.1)
    plt.close()


if __name__ == "__main__":

    # Expect exactly one positional argument: the snapshot number.
    if len(sys.argv) < 2:
        sys.exit("Usage: python {} <snapshot number>".format(sys.argv[0]))

    snapshot = int(sys.argv[1])
    print("Current snapshot number is {}".format(snapshot))
    plot(snapshot)
