import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from analysis import get_filtered_fit_info, get_behaviour_values, get_parameter_values, behaviour_correlations, parameter_correlations
from ModelFit import get_best_fit
from Baseline import BaselineModel, BaselineCellData
import Figure_constants as consts


# Display labels (matplotlib mathtext) for the model parameters, keyed by the
# parameter names used when looking titles up in this file.
parameter_titles = {
    "input_scaling": r"$\alpha$",
    "delta_a": r"$\Delta_A$",
    "mem_tau": r"$\tau_m$",
    "noise_strength": r"$\sqrt{2D}$",
    "refractory_period": r"$t_{ref}$",
    "tau_a": r"$\tau_A$",
    "v_offset": r"$I_{Bias}$",
    "dend_tau": r"$\tau_{dend}$",
}

# Short display labels for the behaviour measures, keyed by the behaviour
# names used when looking titles up in this file.
behaviour_titles = {
    "baseline_frequency": "base rate",
    "Burstiness": "Burst",
    "coefficient_of_variation": "CV",
    "serial_correlation": "SC",
    "vector_strength": "VS",
    "f_inf_slope": r"$f_{\infty}$ Slope",
    "f_zero_slope": r"$f_0$ Slope",
    "f_zero_middle": r"$f_0$ middle",
}


def main():
    """Generate the fit-evaluation figures.

    Loads the filtered fit information from ``results/final_2/``, prints how
    many cells remain, draws the three cell-vs-model comparison figures and
    the behaviour correlation figure, and finally saves the parameter
    correlation matrix as ``parameter_correlations.png`` in
    ``consts.SAVE_FOLDER``.
    """
    dir_path = "results/final_2/"
    fits_info = get_filtered_fit_info(dir_path)
    print("Cells left:", len(fits_info))

    # Cell-vs-model comparison figures, one per group of behaviour measures.
    cell_behaviour, model_behaviour = get_behaviour_values(fits_info)
    comparison_plots = (plot_cell_model_comp_baseline,
                        plot_cell_model_comp_adaption,
                        plot_cell_model_comp_burstiness)
    for draw_comparison in comparison_plots:
        draw_comparison(cell_behaviour, model_behaviour)

    behaviour_correlations_plot(fits_info)

    # Correlation matrix of the fitted model parameters.
    names, corr_values, corrected_p_values = parameter_correlations(fits_info)
    par_labels = [parameter_titles[name] for name in names]
    _, ax = plt.subplots(1, 1)
    create_correlation_plot(ax, par_labels, corr_values, corrected_p_values, "")
    plt.savefig(consts.SAVE_FOLDER + "parameter_correlations.png")
    plt.close()

    # Further analyses that are currently switched off:
    # create_parameter_distributions(get_parameter_values(fits_info))

    # errors = calculate_percent_errors(fits_info)
    # create_boxplots(errors)

    # example_good_hist_fits(dir_path)


def create_parameter_distributions(par_values):
    """Plot one histogram per model parameter and save the figure.

    The histograms are arranged in a 4 x 2 grid and written to
    ``parameter_distributions.png`` in ``consts.SAVE_FOLDER``.

    :param par_values: mapping from parameter name to a sequence of values.
        It must contain the eight names listed in ``labels`` below; a
        message is printed when it does not hold exactly eight entries.
    """
    labels = ["input_scaling", "v_offset", "mem_tau", "noise_strength",
              "tau_a", "delta_a", "dend_tau", "refractory_period"]
    num_of_bins = 20

    if len(par_values) != 8:
        print("not eight parameters")

    fig, axes = plt.subplots(4, 2)

    for ax, label in zip(axes.flatten(), labels):
        values = par_values[label]
        lowest = min(values)
        highest = max(values)

        # Pad the range outwards by 5 % of the magnitude of each extreme.
        # Scaling the extremes directly by 0.95 / 1.05 only works for
        # positive numbers: for negative values it moves the edges inwards
        # and the extreme values drop out of the histogram.
        min_v = lowest - 0.05 * abs(lowest)
        max_v = highest + 0.05 * abs(highest)
        if max_v == min_v:
            # Only reachable when every value is 0: give the bins a width.
            min_v -= 0.5
            max_v += 0.5

        # linspace gives exactly num_of_bins bins that end on max_v; a
        # float-step arange has an unreliable number of edges.
        bins = np.linspace(min_v, max_v, num_of_bins + 1)
        ax.hist(values, bins=bins, color=consts.COLOR_MODEL, alpha=0.75)
        ax.set_title(parameter_titles[label])

    plt.tight_layout()
    plt.savefig(consts.SAVE_FOLDER + "parameter_distributions.png")
    plt.close()


def behaviour_correlations_plot(fits_info):
    """Plot the behaviour correlation matrices of data and model side by side.

    The top row holds one correlation matrix for the data values and one for
    the model values; the bottom row holds a horizontal colour strip running
    from -1 to 1 that serves as the colour scale. The figure is saved as
    ``behaviour_correlations.png`` in ``consts.SAVE_FOLDER``.

    :param fits_info: fit information passed on to ``behaviour_correlations``.
    """
    fig = plt.figure(tight_layout=True, figsize=consts.FIG_SIZE_MEDIUM_WIDE)
    grid = gridspec.GridSpec(2, 2, width_ratios=(1, 1), height_ratios=(5, 1), hspace=0.025, wspace=0.05)

    # (grid cell, use model values?, panel title, draw y tick labels?)
    panels = (
        (grid[0, 0], False, "Data", True),
        (grid[0, 1], True, "Model", False),
    )
    for grid_cell, use_model, panel_title, with_y_labels in panels:
        keys, corr_values, corrected_p_values = behaviour_correlations(fits_info, model_values=use_model)
        tick_names = [behaviour_titles[key] for key in keys]
        create_correlation_plot(fig.add_subplot(grid_cell), tick_names, corr_values,
                                corrected_p_values, panel_title, y_label=with_y_labels)

    # Colour scale: a 10-row image whose columns run from -1 to 1 in steps of
    # 0.01, so column index 0 corresponds to -1 and index 200 to 1.
    ax_col = fig.add_subplot(grid[1, :])
    colour_strip = [np.arange(-1, 1.001, 0.01)] * 10
    ax_col.set_xticks(list(range(0, 201, 25)))
    ax_col.set_xticklabels([-1, -0.75, -0.5, -0.25, 0, 0.25, 0.5, 0.75, 1])
    ax_col.set_yticks([])
    ax_col.imshow(colour_strip)
    ax_col.set_xlabel("Correlation Coefficients")

    plt.tight_layout()
    plt.savefig(consts.SAVE_FOLDER + "behaviour_correlations.png")
    plt.close()


def create_correlation_plot(ax, labels, correlations, p_values, title, y_label=True):
    """Draw a correlation matrix as an image with significance markers.

    Entries whose absolute p-value is not below 0.05 are shown as NaN (left
    blank by ``imshow``). Every entry with a p-value below 0.05 is annotated
    with one to three stars.

    :param ax: matplotlib axes to draw into.
    :param labels: tick labels; one per row/column of the matrix.
    :param correlations: 2D array of correlation coefficients.
    :param p_values: 2D array of p-values with the same shape as
        ``correlations``.
    :param title: title of the axes.
    :param y_label: when False the y tick labels are removed.
    :return: the image object returned by ``ax.imshow``.
    """
    correlations = np.asarray(correlations, dtype=float)
    p_values = np.asarray(p_values, dtype=float)

    # Blank out non-significant correlations. np.nan is used because the
    # np.NAN alias no longer exists in NumPy 2.0.
    cleaned_cors = np.where(np.abs(p_values) < 0.05, correlations, np.nan)
    im = ax.imshow(cleaned_cors, vmin=-1, vmax=1)

    # Show one x tick per label.
    tick_positions = np.arange(len(labels))
    ax.set_xticks(tick_positions)
    ax.set_xticklabels(labels)

    if y_label:
        ax.set_yticks(tick_positions)
        ax.set_yticklabels(labels)
    else:
        ax.set_yticklabels([])
    ax.set_title(title)

    # Rotate the x tick labels and anchor them at their right end.
    plt.setp(ax.get_xticklabels(), rotation=45, ha="right",
             rotation_mode="anchor")

    # Significance markers, strictest threshold first so that the first
    # match decides the number of stars.
    significance_marks = ((0.0001, "***"), (0.001, "**"), (0.05, "*"))
    for i in range(len(labels)):
        for j in range(len(labels)):
            for threshold, mark in significance_marks:
                if p_values[i][j] < threshold:
                    ax.text(j, i, mark, ha="center", va="center", color="b")
                    break

    return im

def example_good_hist_fits(dir_path):
    """Plot ISI histograms of cell and model for three example cells.

    For each example cell the best fit is loaded from ``dir_path + cell + "/"``
    and the interspike intervals of the recorded cell and of the fitted model
    are multiplied by the EOD frequency, i.e. expressed in EOD periods. Both
    are drawn as overlaid density histograms. The figure is saved as
    ``example_good_isi_hist_fits.png`` in ``consts.SAVE_FOLDER``.

    :param dir_path: directory that contains one sub-directory per cell.
    """
    # Panel order: the cells the original code names non-bursty, bursty and
    # strongly bursty.
    example_cells = (
        "2012-12-21-am-invivo-1",  # non_bursty_cell
        "2014-03-19-ad-invivo-1",  # bursty_cell
        "2018-05-08-ac-invivo-1",  # strong_bursty_cell
    )

    fig, axes = plt.subplots(1, 3, sharex="all", figsize=consts.FIG_SIZE_MEDIUM_WIDE)

    for panel, cell_name in zip(axes, example_cells):
        fit = get_best_fit(dir_path + cell_name + "/")

        cell_data = fit.get_cell_data()
        eodf = cell_data.get_eod_frequency()

        baseline_model = BaselineModel(fit.get_model(), eodf, trials=5)

        model_isi = np.array(baseline_model.get_interspike_intervals()) * eodf
        cell_isi = BaselineCellData(cell_data).get_interspike_intervals() * eodf

        # Bin edges from 0 to 0.025 in steps of 0.0001, scaled like the ISIs.
        bins = np.arange(0, 0.025, 0.0001) * eodf
        for isi_values, colour in ((cell_isi, consts.COLOR_DATA), (model_isi, consts.COLOR_MODEL)):
            panel.hist(isi_values, bins=bins, density=True, alpha=0.5, color=colour)
        panel.set_xlabel("ISI in EOD periods")

    axes[0].set_ylabel("Density")
    plt.tight_layout()
    # NOTE(review): label_axes() is not a standard matplotlib Figure method;
    # presumably set_figure_labels() makes it available - confirm.
    consts.set_figure_labels(xoffset=-2.5)
    fig.label_axes()

    plt.savefig(consts.SAVE_FOLDER + "example_good_isi_hist_fits.png", transparent=True)
    plt.close()


def create_boxplots(errors):
    """Show one box plot per error category and print each median.

    :param errors: mapping from category name to a sequence of percent
        errors. Categories are processed in sorted key order; each tick
        label also states the number of values in the category.
    """
    labels = []
    y_values = []
    for key in sorted(errors.keys()):
        category_errors = errors[key]
        labels.append("{}_n:{}".format(key, len(category_errors)))
        print("{}: median %-error: {:.2f}".format(key, np.median(category_errors)))
        y_values.append(category_errors)

    plt.boxplot(y_values)
    # Box plots are drawn at positions 1..n, so the ticks start at 1.
    tick_positions = np.arange(1, len(y_values)+1, 1)
    plt.xticks(tick_positions, labels, rotation=45)
    plt.tight_layout()
    plt.show()
    plt.close()



def plot_cell_model_comp_baseline(cell_behavior, model_behaviour):
    """Compare cell and model values of the baseline behaviour measures.

    Draws one column each for baseline frequency, vector strength and serial
    correlation: a scatter plot of cell against model values with a marginal
    histogram of both on top. The figure is saved as
    ``fit_baseline_comparison.png`` in ``consts.SAVE_FOLDER``.

    :param cell_behavior: mapping from behaviour name to the values measured
        in the cells.
    :param model_behaviour: mapping from behaviour name to the values of the
        fitted models; each sequence is scattered against the matching
        sequence of ``cell_behavior``, so the two must have equal length.
    """
    keys = ("baseline_frequency", "vector_strength", "serial_correlation")
    num_of_bins = 20

    fig = plt.figure(figsize=(12, 6))
    # Two rows and one column per behaviour: the marginal histograms (top)
    # and the scatter plots (bottom) share the height in a 3 to 7 ratio.
    gs = fig.add_gridspec(2, len(keys), width_ratios=[5] * len(keys), height_ratios=[3, 7],
                          left=0.1, right=0.9, bottom=0.1, top=0.9,
                          wspace=0.25, hspace=0.05)

    for column, key in enumerate(keys):
        cell = cell_behavior[key]
        model = model_behaviour[key]

        # Common bin edges over the joint range of cell and model values.
        # linspace gives exactly num_of_bins bins that end on the maximum,
        # which a float-step arange does not guarantee.
        minimum = min(min(cell), min(model))
        maximum = max(max(cell), max(model))
        bins = np.linspace(minimum, maximum, num_of_bins + 1)

        ax = fig.add_subplot(gs[1, column])
        ax_histx = fig.add_subplot(gs[0, column], sharex=ax)
        scatter_hist(cell, model, ax, ax_histx, behaviour_titles[key], bins)

    plt.tight_layout()
    plt.savefig(consts.SAVE_FOLDER + "fit_baseline_comparison.png", transparent=True)
    plt.close()


def plot_cell_model_comp_adaption(cell_behavior, model_behaviour):
    """Compare cell and model values of the two f-I curve slopes.

    Draws one column each for ``f_inf_slope`` and ``f_zero_slope``: a scatter
    plot of cell against model values with a marginal histogram of both on
    top. For ``f_zero_slope`` every pair in which the cell or the model value
    is not below 25000 is dropped, and the number of dropped pairs is
    printed. The figure is saved as ``fit_adaption_comparison.png`` in
    ``consts.SAVE_FOLDER``.

    :param cell_behavior: mapping from behaviour name to the values measured
        in the cells.
    :param model_behaviour: mapping from behaviour name to the values of the
        fitted models; each sequence is scattered against the matching
        sequence of ``cell_behavior``, so the two must have equal length.
    """
    keys = ("f_inf_slope", "f_zero_slope")
    num_of_bins = 20
    f_zero_slope_limit = 25000

    fig = plt.figure(figsize=(8, 6))
    # Two rows and one column per behaviour: the marginal histograms (top)
    # and the scatter plots (bottom) share the height in a 3 to 7 ratio.
    gs = fig.add_gridspec(2, len(keys), width_ratios=[5] * len(keys), height_ratios=[3, 7],
                          left=0.1, right=0.9, bottom=0.1, top=0.9,
                          wspace=0.25, hspace=0.05)

    for column, key in enumerate(keys):
        cell = cell_behavior[key]
        model = model_behaviour[key]

        if key == "f_zero_slope":
            # Keep only the pairs in which both values are below the limit.
            length_before = len(cell)
            cell = np.array(cell)
            model = np.array(model)
            keep = (cell < f_zero_slope_limit) & (model < f_zero_slope_limit)
            cell = cell[keep]
            model = model[keep]
            print("removed {} values from f_zero_slope plot.".format(length_before - len(cell)))

        # Common bin edges over the joint range of cell and model values.
        # linspace gives exactly num_of_bins bins that end on the maximum,
        # which a float-step arange does not guarantee.
        minimum = min(min(cell), min(model))
        maximum = max(max(cell), max(model))
        bins = np.linspace(minimum, maximum, num_of_bins + 1)

        ax = fig.add_subplot(gs[1, column])
        ax_histx = fig.add_subplot(gs[0, column], sharex=ax)
        scatter_hist(cell, model, ax, ax_histx, behaviour_titles[key], bins)

    plt.tight_layout()
    plt.savefig(consts.SAVE_FOLDER + "fit_adaption_comparison.png", transparent=True)
    plt.close()


def plot_cell_model_comp_burstiness(cell_behavior, model_behaviour):
    """Compare cell and model values of burstiness and coefficient of variation.

    Draws one column per behaviour: a scatter plot of cell against model
    values with a marginal histogram of both on top. The figure is saved as
    ``fit_burstiness_comparison.png`` in ``consts.SAVE_FOLDER``.

    :param cell_behavior: mapping from behaviour name to the values measured
        in the cells.
    :param model_behaviour: mapping from behaviour name to the values of the
        fitted models; each sequence is scattered against the matching
        sequence of ``cell_behavior``, so the two must have equal length.
    """
    keys = ("Burstiness", "coefficient_of_variation")
    num_of_bins = 20

    fig = plt.figure(figsize=(8, 6))
    # Two rows and one column per behaviour: the marginal histograms (top)
    # and the scatter plots (bottom) share the height in a 3 to 7 ratio.
    gs = fig.add_gridspec(2, len(keys), width_ratios=[5] * len(keys), height_ratios=[3, 7],
                          left=0.1, right=0.9, bottom=0.1, top=0.9,
                          wspace=0.25, hspace=0.05)

    for column, key in enumerate(keys):
        cell = cell_behavior[key]
        model = model_behaviour[key]

        # Common bin edges over the joint range of cell and model values.
        # linspace gives exactly num_of_bins bins that end on the maximum,
        # which a float-step arange does not guarantee.
        minimum = min(min(cell), min(model))
        maximum = max(max(cell), max(model))
        bins = np.linspace(minimum, maximum, num_of_bins + 1)

        ax = fig.add_subplot(gs[1, column])
        ax_histx = fig.add_subplot(gs[0, column], sharex=ax)
        scatter_hist(cell, model, ax, ax_histx, behaviour_titles[key], bins)

    plt.tight_layout()
    plt.savefig(consts.SAVE_FOLDER + "fit_burstiness_comparison.png", transparent=True)
    plt.close()


# def behaviour_overview_pairs(cell_behaviour, model_behaviour):
#     # behaviour_keys = ["Burstiness", "coefficient_of_variation", "serial_correlation",
#     #                   "vector_strength", "f_inf_slope", "f_zero_slope", "baseline_frequency"]
#
#     pairs = [("baseline_frequency", "vector_strength", "serial_correlation"),
#              ("Burstiness", "coefficient_of_variation"),
#              ("f_inf_slope", "f_zero_slope")]
#
#     for pair in pairs:
#         cell = []
#         model = []
#         for behaviour in pair:
#             cell.append(cell_behaviour[behaviour])
#             model.append(model_behaviour[behaviour])
#         overview_pair(cell, model, pair)
#
#
# def overview_pair(cell, model, titles):
#     fig = plt.figure(figsize=(8, 6))
#
#     columns = len(cell)
#
#     # Add a gridspec with two rows and two columns and a ratio of 2 to 7 between
#     # the size of the marginal axes and the main axes in both directions.
#     # Also adjust the subplot parameters for a square plot.
#     gs = fig.add_gridspec(2, columns, width_ratios=[5] * columns, height_ratios=[3, 7],
#                           left=0.1, right=0.9, bottom=0.1, top=0.9,
#                           wspace=0.2, hspace=0.05)
#
#     for i in range(len(cell)):
#         if titles[i] == "f_zero_slope":
#             length_before = len(cell[i])
#             idx = np.array(cell[i]) < 30000
#             cell[i] = np.array(cell[i])[idx]
#             model[i] = np.array(model[i])[idx]
#
#             idx = np.array(model[i]) < 30000
#             cell[i] = np.array(cell[i])[idx]
#             model[i] = np.array(model[i])[idx]
#             print("removed {} values from f_zero_slope plot.".format(length_before - len(cell[i])))
#         ax = fig.add_subplot(gs[1, i])
#         ax_histx = fig.add_subplot(gs[0, i], sharex=ax)
#         scatter_hist(cell[i], model[i], ax, ax_histx, titles[i])
#
#     # plt.tight_layout()
#     plt.show()
#
#
# def grouped_error_overview_behaviour_dist(cell_behaviours, model_behaviours):
#     # start with a square Figure
#     fig = plt.figure(figsize=(12, 12))
#
#     rows = 4
#     columns = 2
#     # Add a gridspec with two rows and two columns and a ratio of 2 to 7 between
#     # the size of the marginal axes and the main axes in both directions.
#     # Also adjust the subplot parameters for a square plot.
#     gs = fig.add_gridspec(rows*2, columns, width_ratios=[5]*columns, height_ratios=[3, 7] * rows,
#                           left=0.1, right=0.9, bottom=0.1, top=0.9,
#                           wspace=0.2, hspace=0.5)
#
#     for i, behaviour in enumerate(sorted(cell_behaviours.keys())):
#         col = int(np.floor(i / rows))
#         row = i - rows*col
#         ax = fig.add_subplot(gs[row*2 + 1, col])
#         ax_histx = fig.add_subplot(gs[row*2, col])
#
#         # use the previously defined function
#         scatter_hist(cell_behaviours[behaviour], model_behaviours[behaviour], ax, ax_histx, behaviour)
#
#     plt.tight_layout()
#     plt.show()
#

def scatter_hist(cell_values, model_values, ax, ax_histx, behaviour, bins, ax_histy=None):
    """Draw a cell-vs-model scatter plot with a marginal histogram above it.

    Adapted from the matplotlib "scatter plot with histograms" example.

    :param cell_values: values measured in the cells (x axis of the scatter).
    :param model_values: values of the models (y axis of the scatter).
    :param ax: axes for the scatter plot.
    :param ax_histx: axes for the marginal histogram; the callers in this
        file create it with ``sharex=ax``.
    :param behaviour: title placed above the histogram.
    :param bins: bin edges used for both histograms.
    :param ax_histy: unused; kept for interface compatibility.
    """
    # Hide the x tick labels of the marginal histogram. The axis argument
    # must be "x", "y" or "both"; the previous value "cell" was not valid.
    ax_histx.tick_params(axis="x", labelbottom=False)

    # Scatter plot with the identity line (model == cell) as reference.
    minimum = min(min(cell_values), min(model_values))
    maximum = max(max(cell_values), max(model_values))
    ax.plot((minimum, maximum), (minimum, maximum), color="grey")
    ax.scatter(cell_values, model_values, color="black")

    ax.set_xlabel("cell")
    ax.set_ylabel("model")

    # Overlaid histograms of cell and model values on the same bins.
    ax_histx.hist(cell_values, bins=bins, color=consts.COLOR_DATA, alpha=0.75)
    ax_histx.hist(model_values, bins=bins, color=consts.COLOR_MODEL, alpha=0.75)

    # Clear the histogram's tick labels, then put the scatter plot's labels
    # back. NOTE(review): presumably needed because the shared x axis also
    # shares its tick labels - confirm before simplifying.
    ax_labels = ax.get_xticklabels()
    ax_histx.set_xticklabels([])
    ax.set_xticklabels(ax_labels)

    # Align ticks and limits of the histogram with the scatter plot.
    ax_histx.set_xticks(ax.get_xticks())
    ax_histx.set_xlim(ax.get_xlim())
    ax_histx.set_title(behaviour)

    # ax_histy.hist(y, bins=bins, orientation='horizontal')


# Generate the figures only when this file is run as a script, not when it
# is imported as a module.
if __name__ == '__main__':
    main()