import numpy as np
import matplotlib.pyplot as plt
from analysis import get_fit_info, get_behaviour_values, calculate_percent_errors
from ModelFit import get_best_fit
from Baseline import BaselineModel, BaselineCellData
import Figure_constants as consts


def main():
    """Generate the currently selected figures from the fits stored under results/final_2/."""
    results_dir = "results/final_2/"
    # fits_info is only consumed by the figures that are currently disabled below.
    fits_info = get_fit_info(results_dir)

    # Disabled figures -- uncomment to regenerate them:
    # cell_behaviour, model_behaviour = get_behaviour_values(fits_info)
    # behaviour_overview_pairs(cell_behaviour, model_behaviour)
    # errors = calculate_percent_errors(fits_info)
    # create_boxplots(errors)

    example_good_hist_fits(results_dir)


def example_good_hist_fits(dir_path):
    """Plot cell vs. model ISI histograms for three example cells and save the figure.

    Panels are ordered non-bursty, bursty, strongly bursty. For each cell:
    - the best fit in ``dir_path + cell + "/"`` is loaded;
    - its model is simulated via ``BaselineModel(..., trials=5)``;
    - cell and model ISIs are drawn as overlaid density histograms in units of
      EOD periods (ISI * EOD frequency).

    The figure is written to ``consts.SAVE_FOLDER + "example_good_isi_hist_fits.png"``
    and then closed.

    :param dir_path: fit directory ending in "/"; it is concatenated directly
        with the cell names.
    """
    example_cells = (
        "2012-12-21-am-invivo-1",  # non-bursty
        "2014-03-19-ad-invivo-1",  # bursty
        "2018-05-08-ac-invivo-1",  # strongly bursty
    )

    fig, axes = plt.subplots(1, 3, sharex="all", figsize=consts.FIG_SIZE_MEDIUM_WIDE)

    for ax, cell in zip(axes, example_cells):
        fit = get_best_fit(dir_path + cell + "/")

        cell_data = fit.get_cell_data()
        eodf = cell_data.get_eod_frequency()

        baseline_model = BaselineModel(fit.get_model(), eodf, trials=5)

        # Convert both ISI sets to arrays before scaling: multiplying a plain
        # list by a float raises TypeError, so don't depend on either getter's
        # return type.
        model_isi = np.array(baseline_model.get_interspike_intervals()) * eodf
        cell_isi = np.array(BaselineCellData(cell_data).get_interspike_intervals()) * eodf

        # 0-25 ms in 0.1 ms steps (assuming ISIs are in seconds), rescaled to EOD periods.
        bins = np.arange(0, 0.025, 0.0001) * eodf
        ax.hist(cell_isi, bins=bins, density=True, alpha=0.5, color=consts.COLOR_DATA)
        ax.hist(model_isi, bins=bins, density=True, alpha=0.5, color=consts.COLOR_MODEL)
        ax.set_xlabel("ISI in EOD periods")
    axes[0].set_ylabel("Density")
    plt.tight_layout()
    consts.set_figure_labels(xoffset=-2.5)
    # NOTE(review): label_axes is not a standard matplotlib Figure method;
    # presumably installed by consts.set_figure_labels -- verify.
    fig.label_axes()

    plt.savefig(consts.SAVE_FOLDER + "example_good_isi_hist_fits.png", transparent=True)
    plt.close()


def create_boxplots(errors):
    """Print the median percent error per group and show one box plot per group.

    :param errors: mapping from a group name to a sequence of percent errors.
        Boxes are ordered by sorted key and labelled ``"<key>_n:<count>"``.
    """
    keys = sorted(errors)

    tick_labels = ["{}_n:{}".format(key, len(errors[key])) for key in keys]
    for key in keys:
        print("{}: median %-error: {:.2f}".format(key, np.median(errors[key])))
    box_data = [errors[key] for key in keys]

    plt.boxplot(box_data)
    # boxplot places box k at x = k (1-based).
    plt.xticks(np.arange(1, len(box_data) + 1, 1), tick_labels, rotation=45)
    plt.tight_layout()
    plt.show()
    plt.close()


def behaviour_overview_pairs(cell_behaviour, model_behaviour):
    """Draw one overview figure per group of related behaviours.

    Keys known to be used here: "baseline_frequency", "vector_strength",
    "serial_correlation", "Burstiness", "coefficient_of_variation",
    "f_inf_slope", "f_zero_slope".

    :param cell_behaviour: mapping from behaviour name to cell values.
    :param model_behaviour: mapping from behaviour name to model values,
        aligned with *cell_behaviour*.
    """
    # Despite the function name, a group may hold more than two behaviours.
    groups = [("baseline_frequency", "vector_strength", "serial_correlation"),
              ("Burstiness", "coefficient_of_variation"),
              ("f_inf_slope", "f_zero_slope")]

    for group in groups:
        looked_up = [(cell_behaviour[key], model_behaviour[key]) for key in group]
        cell_values = [cell_vals for cell_vals, _ in looked_up]
        model_values = [model_vals for _, model_vals in looked_up]
        overview_pair(cell_values, model_values, group)


def overview_pair(cell, model, titles):
    """Show scatter-plus-histogram panels comparing cell and model values.

    There is one column per behaviour. For column ``i``:
    - the lower axes scatter ``model[i]`` against ``cell[i]``;
    - the upper axes overlay their histograms and are titled ``titles[i]``.

    For the "f_zero_slope" column, index pairs where either value is not below
    30000 are dropped before plotting, and the number removed is printed.
    The input lists are not modified.

    :param cell: list of per-behaviour value sequences from the cells.
    :param model: list of per-behaviour value sequences from the models,
        aligned element-wise with *cell*.
    :param titles: behaviour names, one per column.
    """
    f_zero_limit = 30000

    fig = plt.figure(figsize=(8, 6))

    columns = len(cell)

    # Two rows (marginal histogram on top, scatter below) with a 3:7 height
    # ratio, and one equally wide column per behaviour.
    gs = fig.add_gridspec(2, columns, width_ratios=[5] * columns, height_ratios=[3, 7],
                          left=0.1, right=0.9, bottom=0.1, top=0.9,
                          wspace=0.2, hspace=0.05)

    for i in range(columns):
        cell_values = cell[i]
        model_values = model[i]
        if titles[i] == "f_zero_slope":
            # Filter into new local arrays instead of writing back into the
            # caller's lists; one combined mask keeps the pairs aligned.
            cell_values = np.array(cell_values)
            model_values = np.array(model_values)
            length_before = len(cell_values)
            keep = (cell_values < f_zero_limit) & (model_values < f_zero_limit)
            cell_values = cell_values[keep]
            model_values = model_values[keep]
            print("removed {} values from f_zero_slope plot.".format(length_before - len(cell_values)))
        ax = fig.add_subplot(gs[1, i])
        ax_histx = fig.add_subplot(gs[0, i], sharex=ax)
        scatter_hist(cell_values, model_values, ax, ax_histx, titles[i])

    plt.show()


def grouped_error_overview_behaviour_dist(cell_behaviours, model_behaviours):
    """Show scatter-plus-histogram panels for every behaviour in one figure.

    Behaviours are placed in sorted order, column-major, four per column. The
    figure has at least two columns, and more are added when there are more
    than eight behaviours. Each panel is drawn by ``scatter_hist``.

    :param cell_behaviours: mapping from behaviour name to cell values.
    :param model_behaviours: mapping from behaviour name to model values;
        must contain every key of *cell_behaviours*.
    """
    behaviours = sorted(cell_behaviours.keys())

    rows = 4
    # A fixed 4x2 grid raises IndexError for a ninth behaviour; grow the column
    # count instead (the layout is unchanged for up to eight behaviours).
    columns = max(2, int(np.ceil(len(behaviours) / rows)))

    fig = plt.figure(figsize=(12, 12))
    # Each behaviour gets a pair of gridspec rows: marginal histogram (3) above
    # the scatter plot (7).
    gs = fig.add_gridspec(rows * 2, columns, width_ratios=[5] * columns, height_ratios=[3, 7] * rows,
                          left=0.1, right=0.9, bottom=0.1, top=0.9,
                          wspace=0.2, hspace=0.5)

    for i, behaviour in enumerate(behaviours):
        col, row = divmod(i, rows)
        ax = fig.add_subplot(gs[row * 2 + 1, col])
        ax_histx = fig.add_subplot(gs[row * 2, col])
        scatter_hist(cell_behaviours[behaviour], model_behaviours[behaviour], ax, ax_histx, behaviour)

    plt.tight_layout()
    plt.show()


def scatter_hist(cell_values, model_values, ax, ax_histx, behaviour, ax_histy=None):
    """Scatter model values against cell values, with both distributions above.

    Adapted from the matplotlib "Scatter plot with histograms" gallery example.

    :param cell_values: values from the cells (x axis of the scatter).
    :param model_values: corresponding model values (y axis), aligned with
        *cell_values*.
    :param ax: axes for the scatter plot. A grey identity line spanning the
        combined data range marks perfect agreement.
    :param ax_histx: axes above *ax*. Cell (blue) and model (orange)
        histograms are overlaid here, its x ticks and limits are aligned with
        *ax*, and its tick labels are hidden.
    :param behaviour: title for *ax_histx*.
    :param ax_histy: unused; kept for interface compatibility.
    """
    # Hide the x tick labels of the marginal histogram (valid axis values are
    # "x", "y" or "both").
    ax_histx.tick_params(axis="x", labelbottom=False)

    ax.scatter(cell_values, model_values)

    # Identity line over the combined range of both data sets.
    minimum = min(min(cell_values), min(model_values))
    maximum = max(max(cell_values), max(model_values))
    ax.plot((minimum, maximum), (minimum, maximum), color="grey")

    ax.set_xlabel("cell")
    ax.set_ylabel("model")

    ax_histx.hist(cell_values, color="blue", alpha=0.5)
    ax_histx.hist(model_values, color="orange", alpha=0.5)

    # Align the histogram's x axis with the scatter plot. The tick labels stay
    # under the default formatter: replacing them via set_xticklabels would
    # freeze them, and on an x axis shared with ax would blank ax's labels too.
    ax_histx.set_xticks(ax.get_xticks())
    ax_histx.set_xlim(ax.get_xlim())
    ax_histx.set_title(behaviour)


if __name__ == '__main__':
    # Only generate figures when run as a script, not when imported.
    main()