#!/usr/bin/env python
#
# Usage:
#  python3 plot_scaling_results_detailed.py input-file-1 input-file-2
#
# Description:
# Plots the speed-up, parallel efficiency and time to solution of the timed
# functions, given the output files generated by SWIFT (one file per run).
# You need to run SWIFT with -v 1.
#
# Example:
# python3 plot_scaling_results_detailed.py output_1.log output_2.log

import sys
import re
import numpy as np
import matplotlib.pyplot as plt

from timed_functions import labels

# Threshold on the fraction of the per-file total time, used when deciding
# which labels are plotted individually
min_fraction = 0.02

legendTitle = " "

# Full palette of hex colours
hexcols = [
    "#332288", "#88CCEE", "#44AA99", "#117733", "#999933",
    "#DDCC77", "#CC6677", "#882255", "#AA4499", "#661100",
    "#6699CC", "#AA4466", "#4477AA",
]  # fmt: skip

# Subset of the palette, in the order in which the curves pick their colour
colors = [hexcols[k] for k in (0, 1, 3, 5, 6, 8, 2, 4, 7)]

# Every command-line argument is an input file (one data series per file)
filenames = sys.argv[1:]
if not filenames:
    # A missing argument is a usage error: report it on stderr and exit with
    # a non-zero status so that calling scripts can detect the failure.
    print("Please specify an input file in the arguments.", file=sys.stderr)
    sys.exit(1)

# Parse the files and return the thread counts and the total time taken per label
def parse_files():
    """Read every input file and accumulate the time reported for each label.

    Uses the module-level ``filenames`` (files to read) and ``labels`` (the
    first element of each entry is used as a regular expression).

    Returns:
        threads: array with one entry per file, set to the product of the 6th
            and 7th numbers found on the "threads / rank and" line (read as
            ranks and threads per rank). Stays 0 if the file has no such line.
        total_time: array of shape (number of labels, number of files). Entry
            [j, i] is the sum, over the lines of file i matching
            "<labels[j][0]> took", of the last number on the line.

    Raises:
        Exception: if a "threads / rank and" line does not contain exactly
            12 numbers.
    """
    # Patterns are loop invariant: build them once instead of once per line
    number_re = re.compile(r"[+-]?((\d+\.?\d*)|(\.\d+))")
    took_res = [re.compile("%s took" % label[0]) for label in labels]

    # Allocate the arrays
    n_labels = len(labels)
    n_files = len(filenames)

    total_time = np.zeros((n_labels, n_files))
    threads = np.zeros(n_files)

    for i, filename in enumerate(filenames):  # Loop over each file
        print("Files read %.1f%%\r" % (100 * i / n_files), end="")
        with open(filename, "r") as f:

            for line in f:

                # Extract the number of threads
                if "threads / rank and" in line:
                    all_numbers = number_re.findall(line)
                    if len(all_numbers) != 12:
                        raise Exception("Failed to read the following line", line)
                    # findall returns one tuple per match; element 0 is the
                    # number itself (without its sign)
                    rank = int(all_numbers[5][0])
                    thread = int(all_numbers[6][0])
                    threads[i] = rank * thread

                # Loop over the possible labels
                for j, took_re in enumerate(took_res):

                    # The timing is the last number on a matching line
                    if took_re.search(line):
                        total_time[j, i] += float(number_re.findall(line)[-1][0])

    return threads, total_time


def cleanup_data(threads, total_time):
    """Filter, group and sort the parsed timings, then derive the scaling.

    The module-level ``labels`` list is modified in place so that it keeps
    matching the rows of the returned arrays.

    Steps:
      1. Rows (and labels) whose time summed over all files is zero are
         dropped.
      2. Rows that never exceed ``min_fraction`` of a file's total time are
         summed together and removed; the sum is appended as an "Others" row
         if it exceeds half of ``min_fraction`` of the total in some file.
      3. Columns are reordered by increasing thread count.
      4. Zero timings are replaced by 1e-6 before dividing.

    Args:
        threads: array with one thread count per file.
        total_time: array of shape (number of labels, number of files).

    Returns:
        Tuple ``(threads, total_time, speed_up, parallel_eff)`` where
        ``speed_up`` is the first column of ``total_time`` divided by
        ``total_time`` and ``parallel_eff`` is ``speed_up / threads``.
    """
    file_count = len(filenames)

    # Drop the labels that were never found in any file
    never_seen = np.sum(total_time, axis=1) == 0.0
    total_time = np.delete(total_time, never_seen, axis=0)
    for row in reversed(range(len(labels))):
        if never_seen[row]:
            del labels[row]

    # A label is significant if it exceeds min_fraction in at least one file
    per_file_total = np.sum(total_time, axis=0)
    significant = (total_time / per_file_total > min_fraction).any(axis=1)
    print("Grouping: ", np.array(labels)[~significant, 0])

    # Sum the insignificant rows together and forget their labels
    remaining = np.zeros((1, file_count))
    for row in reversed(range(len(labels))):
        if significant[row]:
            continue
        remaining += total_time[row, :]
        del labels[row]

    # Fractions are relative to the totals before the rows are removed
    remaining_frac = remaining / per_file_total
    total_time = np.delete(total_time, ~significant, axis=0)

    # Keep the grouped row only if it is large enough in some file
    if (remaining_frac > 0.5 * min_fraction).any():
        labels.append(("Others", -1))
        total_time = np.append(total_time, remaining, axis=0)

    # Reorder the files by increasing number of threads
    order = np.argsort(threads)
    threads = threads[order]
    total_time = total_time[:, order]

    print("\nNumber of threads found", threads)

    # Guard the divisions below against exact zeros
    total_time[total_time == 0.0] = 1e-6

    # Speed-up relative to the first (fewest threads) column
    baseline = total_time[:, 0].reshape(-1, 1)
    speed_up = baseline / total_time
    parallel_eff = speed_up / threads

    return threads, total_time, speed_up, parallel_eff


def plot_results(threads, total_time, speed_up, parallel_eff):
    """Draw the speed-up, parallel efficiency and time-to-solution panels.

    A 2x2 figure is created: speed-up (top left), parallel efficiency (top
    right) and time to solution (bottom left). The bottom-right axes are
    switched off and the legend of the time panel is anchored beside that
    panel. One curve per entry of the module-level ``labels`` is drawn in each
    panel, cycling through ``colors``.

    Args:
        threads: array of thread counts, one per file.
        total_time: array indexed as [label, file] with the timings.
        speed_up: array indexed as [label, file] with the speed-up.
        parallel_eff: array indexed as [label, file] with the efficiency.
    """
    fig, axarr = plt.subplots(2, 2, figsize=(10, 10), frameon=True)
    (ax_speed, ax_eff), (ax_time, ax_blank) = axarr

    max_threads = threads.max()
    # Upper x limit shared by the two log-scaled panels
    log_x_upper = 10 ** (np.floor(np.log10(max_threads)) + 0.5)
    # End points of the perfect-scaling guide lines in the time panel
    pts = np.array([1, 10 ** np.floor(np.log10(max_threads) + 1)])

    # Ideal speed-up reference
    ax_speed.plot(threads, threads, linestyle="--", lw=1.5, color="0.2")

    # One curve per label in every panel
    for row, entry in enumerate(labels):
        shade = colors[row % len(colors)]
        ax_speed.plot(threads, speed_up[row, :], c=shade)
        ax_eff.plot(threads, parallel_eff[row, :], c=shade)
        # Data
        ax_time.loglog(threads, total_time[row, :], c=shade, label=entry[0])
        # Perfect scaling
        ax_time.loglog(pts, total_time[row, 0] / pts, "--", c=shade, lw=1.0)

    # Speed-up panel
    ax_speed.set_ylabel("Speed up", labelpad=0.0)
    ax_speed.set_xlabel("Threads", labelpad=0.0)
    ax_speed.set_xlim([0.7, max_threads + 1])
    ax_speed.set_ylim([0.7, max_threads + 1])

    # Parallel efficiency panel
    ax_eff.set_xscale("log")
    ax_eff.set_ylabel("Parallel efficiency", labelpad=0.0)
    ax_eff.set_xlabel("Threads", labelpad=0.0)
    ax_eff.set_ylim([0, 1.1])
    ax_eff.set_xlim([0.9, log_x_upper])

    # Time to solution panel; y limits are rounded to powers of ten
    y_min = 10 ** np.floor(np.log10(total_time.min() * 0.6))
    y_max = 1.0 * 10 ** np.floor(np.log10(total_time.max() * 1.5) + 1)
    ax_time.set_xscale("log")
    ax_time.set_xlabel("Threads", labelpad=0.0)
    ax_time.set_ylabel("Time to solution", labelpad=0.0)
    ax_time.set_xlim([0.9, log_x_upper])
    ax_time.set_ylim(y_min, y_max)

    ax_time.legend(
        bbox_to_anchor=(1.21, 0.97),
        loc=2,
        ncol=2,
        borderaxespad=0.0,
        prop={"size": 12},
        frameon=False,
    )
    ax_blank.axis("off")

    fig.suptitle("Speed up, parallel efficiency and time to solution", fontsize=16)


# Parse the files, then filter/sort the timings and derive the scaling
threads, total_time, speed_up, parallel_eff = cleanup_data(*parse_files())

print("Functions found: ", [entry[0] for entry in labels])

# Build the figure and display it
plot_results(threads, total_time, speed_up, parallel_eff)
plt.show()
