#!/usr/bin/env python3
# Usage text for this script; parse_args() passes it to
# argparse.ArgumentParser as the description shown by --help.
description = """
Plot the number of tasks for each depth level and each type of task.

Usage:
  ./plot_task_level.py task_level_0.txt
  or
  ./plot_task_level.py task_level_*_0.txt

"""


import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import sys
import argparse
from os import path


def parse_args():
    """
    Builds the command line interface, parses it and checks the input files.

    Returns
    -------

    args: Namespace
        Namespace returned by argparse.ArgumentParser.parse_args(), holding
        every parsed option.

    files:
        List of the file names given as positional arguments (the same
        object as ``args.file``).

    Raises
    ------

    FileNotFoundError
        If any of the file names given on the command line does not exist;
        the first missing one is reported.
    """

    # ``description`` is the module-level usage string defined at the top of
    # this file; argparse shows it in --help.
    parser = argparse.ArgumentParser(description=description)

    # On/off switches: (short flag, long flag, destination, help text).
    switches = [
        (
            "-c",
            "--count",
            "count_levels",
            "count on how many different levels the tasks can be found"
            " and add it to the label",
        ),
        (
            "-d",
            "--displace",
            "displace",
            "attempt to displace overlapping point on the plot a bit"
            " and try to make them visible",
        ),
    ]
    for short_flag, long_flag, destination, helptext in switches:
        parser.add_argument(
            short_flag,
            long_flag,
            dest=destination,
            help=helptext,
            action="store_true",
        )

    # One or more input files are mandatory.
    parser.add_argument(
        "file",
        nargs="+",
        type=str,
        help="Required file name(s) of .txt file(s) of the task levels "
        "generated by swift.",
    )

    args = parser.parse_args()

    missing = [name for name in args.file if not path.exists(name)]
    if missing:
        raise FileNotFoundError("File not found:'" + missing[0] + "'")

    return args, args.file


def read_data(files):
    """
    Reads in the task level data from one or more .txt files.

    Each file is read with a single space as column separator; text after a
    '#' is treated as a comment. The four columns are named
    ``type``, ``subtype``, ``depth`` and ``count``.

    If a single file is given, its data is returned as read. If several
    files are given, rows with identical (type, subtype, depth) are merged
    into one row whose ``count`` is the sum over all files.

    Parameters
    ----------

    files: list
        list of filenames to be read from


    Returns
    -------

    data: pandas dataframe
        dataframe containing read in (and, for several files, merged) data


    Raises
    ------

    ValueError
        If no file names are given.
    """

    # Column names
    names = ["type", "subtype", "depth", "count"]
    keys = ["type", "subtype", "depth"]

    frames = [pd.read_csv(f, sep=" ", comment="#", names=names) for f in files]

    if not frames:
        raise ValueError("No files given to read the task level data from.")

    if len(frames) == 1:
        return frames[0]

    # Summing per key over all files at once gives the same result as
    # merging the files one after the other, but groups only once.
    merged = pd.concat(frames)
    return merged.groupby(keys, as_index=False).sum()


def get_discrete_cmap(nentries):
    """
    Returns a discrete colormap.

    The returned list holds ``nentries + 1`` colors, so it can be indexed
    directly with any integer from 0 up to and including ``nentries``
    (e.g. with a depth level when ``nentries`` is the maximal depth).

    Parameters
    ----------

    nentries: int
        highest index that will be used to look up a color; the colormap
        contains nentries + 1 entries


    Returns
    -------

    cmap: list
        list of nentries + 1 color names


    Raises
    ------

    IndexError:
        When more colors are requested than are available.
        Current maximum is 21 colors, i.e. nentries <= 20.
    """

    fullcolorlist = [
        "red",
        "green",
        "blue",
        "gold",
        "magenta",
        "cyan",
        "lime",
        "saddlebrown",
        "darkolivegreen",
        "cornflowerblue",
        "orange",
        "dimgrey",
        "navajowhite",
        "darkslategray",
        "mediumpurple",
        "lightpink",
        "mediumseagreen",
        "maroon",
        "midnightblue",
        "silver",
        "black",
    ]

    # nentries + 1 colors are handed out, so the list must hold that many.
    if nentries + 1 > len(fullcolorlist):
        raise IndexError(
            "I can't handle more than {0:d} different colors. "
            "Add more manually in get_discrete_cmap() function".format(
                len(fullcolorlist)
            )
        )

    return fullcolorlist[: nentries + 1]


def add_levelcounts(data):
    """
    Adds a column to the data with the counts on how many levels a given task
    is executed on.

    For every row, the new column ``nlevels`` holds the number of rows in
    ``data`` that share the row's (type, subtype) pair, stored as a string.
    The dataframe is modified in place and also returned.

    Parameters
    ----------

    data: pandas dataframe
        The dataframe to use; needs the columns ``type``, ``subtype`` and
        ``depth``

    Returns
    -------

    data: pandas dataframe
        the modified dataframe
    """

    # Size of each (type, subtype) group, broadcast back onto every row of
    # the group. The result is aligned on the index, so this is correct for
    # any index, not only a 0..n-1 RangeIndex, and runs in linear time.
    group_sizes = data.groupby(["type", "subtype"])["depth"].transform("size")

    # Strings, because the main script concatenates this column into the
    # x-axis labels.
    data["nlevels"] = group_sizes.astype(str)

    return data


def add_displacement(data):
    """
    Add small displacements to the task number counts and try
    to make them better visible.

    Rows that share the same (type, subtype) and the same ``count`` would be
    drawn on top of each other. Within each such set of rows, the o-th row
    (in row order, starting with o = 0) is shifted by

        0.05 * max(int(log10(count) + 0.5), 1) * o * (-1)**o

    i.e. alternately down and up with growing amplitude, while the first
    row stays in place. Counts below 1 use a scale factor of 1, which avoids
    taking the logarithm of 0.

    Two columns are added in place:
      - ``yvals``: the (possibly displaced) count as float
      - ``yval_modified``: True for every processed row (all rows)

    Parameters
    ----------

    data: pandas dataframe
        the data to be modified; needs the columns ``type``, ``subtype``
        and ``count``


    Returns
    -------

    data: pandas dataframe
        the modified data
    """
    counts = data["count"].to_numpy()

    # Position of each row among the rows with identical
    # (type, subtype, count), i.e. among the points that would overlap.
    occurrence = data.groupby(["type", "subtype", "count"]).cumcount().to_numpy()

    # Rough order of magnitude of the count, at least 1. Counts are clipped
    # to >= 1 so that log10 stays finite.
    magnitude = np.maximum(np.floor(np.log10(np.maximum(counts, 1)) + 0.5), 1)

    # (-1)**o: keep even occurrences above, move odd ones below.
    sign = np.where(occurrence % 2 == 0, 1.0, -1.0)
    step = 0.05 * magnitude * occurrence * sign

    # Plain arrays are assigned positionally, so any index works.
    data["yvals"] = counts * 1.0 + step
    data["yval_modified"] = True

    return data


if __name__ == "__main__":

    args, files = parse_args()

    data = read_data(files)
    # One colour per depth level, indexed directly by depth (0 .. max depth).
    cmap = get_discrete_cmap(data["depth"].max())

    # Optionally annotate each task with the number of levels it appears on.
    if args.count_levels:
        data = add_levelcounts(data)

    # Optionally nudge overlapping markers apart along the y axis.
    if args.displace:
        data = add_displacement(data)

    # Column holding the y values to draw.
    ycolumn = "yvals" if args.displace else "count"

    # One series of markers per depth level.
    for depth in range(data["depth"].max() + 1):
        subset = data[data["depth"] == depth]

        xvals = subset["type"] + "_" + subset["subtype"]
        if args.count_levels:
            xvals = xvals + "[" + subset["nlevels"] + "]"

        plt.plot(
            xvals,
            subset[ycolumn],
            "o",
            label="depth = {0:d}".format(depth),
            color=cmap[depth],
            alpha=0.7,
        )

    # Figure cosmetics, then display.
    axes = plt.gca()
    axes.set_yscale("log")
    plt.xticks(rotation=90)
    plt.ylabel("Number of Tasks")
    figure = plt.gcf()
    figure.subplots_adjust(bottom=0.225)
    plt.legend()
    plt.grid()
    plt.show()