import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import matplotlib as mpl
from scipy.stats import pearsonr
import os

from analysis import get_filtered_fit_info, get_behaviour_values, get_parameter_values, behaviour_correlations, parameter_correlations
from fitting.ModelFit import get_best_fit
from experiments.Baseline import BaselineModel, BaselineCellData
from experiments.FiCurve import FICurveModel, FICurveCellData
from parser.CellData import CellData
from my_util import functions as fu
from my_util import save_load
import Figure_constants as consts


# Titles (LaTeX symbols) for the fitted model parameters, keyed by parameter
# name; used when labelling the parameter distribution / correlation plots.
parameter_titles = dict(
    input_scaling=r"$\alpha$",
    delta_a=r"$\Delta_A$",
    mem_tau=r"$\tau_m$",
    noise_strength=r"$\sqrt{2D}$",
    refractory_period="$t_{ref}$",
    tau_a=r"$\tau_A$",
    v_offset=r"$I_{Bias}$",
    dend_tau=r"$\tau_{dend}$",
)

# Alternative x-axis labels per parameter.
# NOTE(review): not referenced by the functions visible in this chunk.
parameter_xlabels = dict(
    input_scaling="cm",
    delta_a=r"$\Delta_A$",
    mem_tau=r"$\tau_m$",
    noise_strength=r"$\sqrt{2D}$",
    refractory_period="$t_{ref}$",
    tau_a=r"$\tau_A$",
    v_offset=r"$I_{Bias}$",
    dend_tau=r"$\tau_{dend}$",
)

# Short display names for the cell/model behaviour measures, keyed by the
# behaviour names used in the behaviour dictionaries and correlation plots.
behaviour_titles = dict(
    baseline_frequency="Base Rate",
    Burstiness="Burst",
    coefficient_of_variation="CV",
    serial_correlation="SC",
    vector_strength="VS",
    f_inf_slope=r"$f_{\infty}$ Slope",
    f_zero_slope=r"$f_0$ Slope",
    f_zero_middle=r"$f_0$ middle",
    eodf="EODf",
)


def main():
    """Script entry point: (re)generate figures for the fits in ``results/final_2/``.

    Individual figure functions are switched on and off by (un)commenting the
    calls below.

    NOTE(review): as written, the function exits right after loading the
    filtered fit info, so every call after the ``sys.exit()`` below is
    currently unreachable. Remove that line to actually produce the figures.
    """
    import sys

    dir_path = "results/final_2/"

    # run_all_images(dir_path)
    # sys.exit()

    # dend_tau_and_ref_effect()
    # sys.exit()

    fits_info = get_filtered_fit_info(dir_path, filter=True)
    # visualize_tested_correlations(fits_info)
    # ``quit()`` is a helper injected by the ``site`` module for interactive
    # use only (absent under ``python -S``); ``sys.exit()`` is the program-safe
    # equivalent and raises the same SystemExit.
    sys.exit()
    print("Cells left:", len(fits_info))
    cell_behaviour, model_behaviour = get_behaviour_values(fits_info)
    # plot_cell_model_comp_baseline(cell_behaviour, model_behaviour)
    # plot_cell_model_comp_burstiness(cell_behaviour, model_behaviour)
    plot_cell_model_comp_adaption(cell_behaviour, model_behaviour)

    behaviour_correlations_plot(fits_info)
    parameter_correlation_plot(fits_info)

    # create_parameter_distributions(get_parameter_values(fits_info))
    # create_parameter_distributions(get_parameter_values(fits_info, scaled=True, goal_eodf=800), "scaled_to_800_")
    # errors = calculate_percent_errors(fits_info)
    # create_boxplots(errors)

    # example_bad_hist_fits(dir_path)
    # example_good_fi_fits(dir_path)
    # example_bad_fi_fits(dir_path)


def run_all_images(dir_path, filter=True,  pre_analysis_path="", recalculate=False):
    """Generate the summary figures for the fits stored in *dir_path*.

    Parameters
    ----------
    dir_path : str
        Directory with the fit results, passed to ``get_filtered_fit_info``.
    filter : bool
        Forwarded to ``get_filtered_fit_info`` in both the cached and the
        uncached path. (Shadows the builtin; name kept for compatibility.)
    pre_analysis_path : str
        If non-empty, directory used to cache the fit info
        (``figures_res_fit_info.npy``) and the cell/model behaviour values
        (``figures_res_behaviour.npy``) between runs. Created if missing.
    recalculate : bool
        Ignore existing cache files and recompute (and re-save) them.
    """
    if pre_analysis_path:
        os.makedirs(pre_analysis_path, exist_ok=True)
        fit_info_path = os.path.join(pre_analysis_path, "figures_res_fit_info.npy")
        behaviours_path = os.path.join(pre_analysis_path, "figures_res_behaviour.npy")

        # NOTE(review): the cache does not record ``filter``; pass
        # recalculate=True after changing it.
        fits_recomputed = recalculate or not os.path.exists(fit_info_path)
        if fits_recomputed:
            fits_info = get_filtered_fit_info(dir_path, filter=filter)
            save_load.save(fits_info, fit_info_path)
        else:
            fits_info = save_load.load(fit_info_path)

        # The behaviour values are derived from fits_info, so a freshly
        # computed fits_info invalidates any previously cached behaviour file.
        if fits_recomputed or not os.path.exists(behaviours_path):
            cell_behaviour, model_behaviour = get_behaviour_values(fits_info)
            save_load.save([cell_behaviour, model_behaviour], behaviours_path)
        else:
            cell_behaviour, model_behaviour = save_load.load(behaviours_path)

    else:
        # Previously hard-coded to filter=True, silently ignoring the argument.
        fits_info = get_filtered_fit_info(dir_path, filter=filter)
        cell_behaviour, model_behaviour = get_behaviour_values(fits_info)

    plot_cell_model_comp_baseline(cell_behaviour, model_behaviour)
    plot_cell_model_comp_adaption(cell_behaviour, model_behaviour)
    plot_cell_model_comp_burstiness(cell_behaviour, model_behaviour)

    behaviour_correlations_plot(fits_info)
    parameter_correlation_plot(fits_info)

    create_parameter_distributions(get_parameter_values(fits_info))
    create_parameter_distributions(get_parameter_values(fits_info, scaled=True, goal_eodf=800), "scaled_to_800_")

    # Plots using example cells:

    # dend_tau_and_ref_effect()
    # example_good_hist_fits(dir_path)
    # example_bad_hist_fits(dir_path)
    # example_good_fi_fits(dir_path)
    # example_bad_fi_fits(dir_path)


def _plot_significance_heatmap(ax, fractions, values, value_format, tick_labels, title):
    """Draw *fractions* as a 0..1 heatmap on *ax*, annotating each cell with *values* formatted by *value_format*."""
    # Fixed 0..1 range so the heatmap matches the shared colour strip.
    ax.imshow(fractions, vmin=0, vmax=1)
    ticks = np.arange(len(tick_labels))
    ax.set_xticks(ticks)
    ax.set_xticklabels(tick_labels)
    # remove frame:
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    ax.set_yticks(ticks)
    ax.set_yticklabels(tick_labels)
    ax.set_title(title)

    # Rotate the tick labels and set their alignment.
    plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor")

    for i in range(len(tick_labels)):
        for j in range(len(tick_labels)):
            # High fractions get dark text, low fractions light text (assumes a
            # colour map that is bright at the high end, e.g. the default viridis).
            colour = "black" if fractions[i, j] > 0.5 else "white"
            ax.text(j, i, value_format.format(values[i, j]), ha="center", va="center", color=colour, size=6)


def visualize_tested_correlations(fits_info):
    """Plot how robust the cell-behaviour correlations are to leaving cells out.

    For every ``leave_out`` from 1 to 10, ``test_correlations`` counts in how
    many leave-``leave_out``-out subsets each pair of behaviours correlates
    significantly. The left heatmap shows that count as a fraction of all
    subsets and the right one shows the raw count. Both use a 0..1 colour scale,
    which is drawn as a strip below them. One PDF per ``leave_out`` is written
    to ``figures/``.
    """
    os.makedirs("figures", exist_ok=True)
    for leave_out in range(1, 11):
        significance_count, total_count, labels = test_correlations(fits_info, leave_out, model_values=False)
        percentages = significance_count / total_count
        tick_labels = [behaviour_titles[l] for l in labels]

        fig = plt.figure(tight_layout=True, figsize=consts.FIG_SIZE_MEDIUM_WIDE)
        gs = gridspec.GridSpec(2, 2, width_ratios=(1, 1), height_ratios=(5, 0.5), hspace=0.5, wspace=0.4, left=0.2)

        _plot_significance_heatmap(fig.add_subplot(gs[0, 0]), percentages, percentages, "{:.2f}",
                                   tick_labels, "Percent: removed {}".format(leave_out))
        _plot_significance_heatmap(fig.add_subplot(gs[0, 1]), percentages, significance_count, "{:.0f}",
                                   tick_labels, "Counts - removed {}".format(leave_out))

        ax_col = fig.add_subplot(gs[1, :])
        # 101 samples of 0..1: positions 0, 25, ..., 100 correspond to 0, 0.25, ..., 1.
        ax_col.set_xticks([0, 25, 50, 75, 100])
        ax_col.set_xticklabels([0, 0.25, 0.5, 0.75, 1])
        ax_col.set_yticks([])
        ax_col.imshow([np.arange(0, 1.001, 0.01)] * 10, vmin=0, vmax=1)
        # The strip shows fractions of significant subsets, not correlation coefficients.
        ax_col.set_xlabel("Fraction significant")

        plt.tight_layout()
        plt.savefig("figures/consistency_correlations_removed_{}.pdf".format(leave_out))
        # Release each figure; otherwise all of them stay open until the process ends.
        plt.close(fig)


def test_correlations(fits_info, left_out, model_values=False):
    """Count how often each behaviour pair correlates significantly when cells are left out.

    Every subset obtained by removing ``left_out`` cells (as produced by
    ``iall_masks``) is tested. For each pair of behaviours in the lower
    triangle (including the diagonal), a Pearson correlation is counted as
    significant when its Bonferroni-corrected p-value is below 0.05.

    Parameters
    ----------
    fits_info : dict
        Fit information passed to ``get_behaviour_values``.
    left_out : int
        Number of cells removed per subset.
    model_values : bool
        Use the model's behaviour values instead of the cells'.

    Returns
    -------
    significance_counts : np.ndarray, shape (n_labels, n_labels)
        Number of subsets in which each pair was significant (upper triangle stays 0).
    total_count : int
        Number of subsets tested.
    labels : list of str
        Behaviour names, in row/column order.
    """
    bv_cell, bv_model = get_behaviour_values(fits_info)
    behaviour_values = bv_model if model_values else bv_cell

    labels = ["baseline_frequency", "serial_correlation", "vector_strength", "coefficient_of_variation",
              "Burstiness", "f_inf_slope", "f_zero_slope"]  # , "eodf"]
    n_labels = len(labels)
    significance_counts = np.zeros((n_labels, n_labels))
    # Bonferroni factor: number of distinct off-diagonal pairs, n * (n - 1) / 2.
    correction_factor = sum(range(n_labels))
    # Convert each behaviour to an array once instead of once per pair and subset.
    columns = [np.asarray(behaviour_values[label]) for label in labels]
    n_cells = len(behaviour_values["f_inf_slope"])

    total_count = 0
    for mask in iall_masks(n_cells, left_out):
        total_count += 1
        # ``np.bool`` was removed in NumPy 1.24; the builtin bool is the replacement.
        keep = np.ones(n_cells, dtype=bool)
        keep[np.asarray(mask, dtype=int)] = False
        for i in range(n_labels):
            values_i = columns[i][keep]
            for j in range(i + 1):
                values_j = columns[j][keep]
                c, p = pearsonr(values_i, values_j)
                if p * correction_factor < 0.05:
                    significance_counts[i, j] += 1

    return significance_counts, total_count, labels


def iall_masks(values_count: int, left_out: int):
    """Yield every combination of ``left_out`` distinct indices from ``range(values_count)``.

    Combinations come in lexicographic order. For example,
    ``values_count=4, left_out=2`` yields [0 1], [0 2], [0 3], [1 2], [1 3],
    [2 3]. Each one is a *fresh* 1-D integer numpy array, so callers may keep
    references to earlier masks. The previous implementation mutated and
    re-yielded a single array.

    ``left_out == 0`` yields one empty mask. ``left_out > values_count`` yields
    nothing. A negative ``left_out`` raises ValueError.
    """
    import itertools

    for combination in itertools.combinations(range(values_count), left_out):
        yield np.array(combination, dtype=int)


def dend_tau_and_ref_effect():
    """Figure: effect of removing tau_dend or t_ref from the model on the ISI histogram.

    Rows are three example cells (no burster, burster, strong burster). Columns
    are the best fits from three result folders: without tau_dend, without
    t_ref, and with both. Each panel overlays the ISI density of the fitted
    model and of the recorded cell, in units of EOD periods. A few baseline
    statistics are printed per cell. Saved as ``dend_ref_effect.pdf``.
    """
    cells = ["2012-12-21-am-invivo-1", "2014-03-19-ad-invivo-1", "2014-03-25-aa-invivo-1"]
    cell_type = ["no burster", "burster", "strong burster"]
    folders = ["results/ref_and_tau/no_dend_tau/", "results/ref_and_tau/no_ref_period/", "results/final_2/"]
    title = [r"without $\tau_{dend}$", r"without $t_{ref}$", "with both"]

    fig, axes = plt.subplots(len(cells), len(folders), figsize=consts.FIG_SIZE_LARGE, sharey="row", sharex="all")

    for i, cell in enumerate(cells):
        cell_data = CellData("data/final/" + cell)
        cell_baseline = BaselineCellData(cell_data)
        cell_baseline.load_values(cell_data.get_data_path())
        eodf = cell_data.get_eod_frequency()
        print(cell)
        print("EODf:", eodf)
        print("base rate:", cell_baseline.get_baseline_frequency())
        print("bursty:", cell_baseline.get_burstiness())
        print()

        # The recorded ISIs and the bin edges do not depend on the fit variant,
        # so compute them once per cell. ISIs are multiplied by the EOD frequency
        # to express them in EOD periods; the 0..0.025 edges presumably span
        # 0-25 ms (ISIs in seconds) -- TODO confirm.
        cell_isis = cell_baseline.get_interspike_intervals() * eodf
        bins = np.arange(0, 0.025, 0.0001) * eodf

        for j, folder in enumerate(folders):
            fit = get_best_fit(folder + cell)
            model_baseline = BaselineModel(fit.get_model(), eodf)
            model_isis = np.asarray(model_baseline.get_interspike_intervals()) * eodf

            ax = axes[i, j]
            # Only the top-right panel carries the legend.
            with_legend = i == 0 and j == len(folders) - 1
            ax.hist(model_isis, density=True, bins=bins, color=consts.COLOR_MODEL, alpha=0.75,
                    label="model" if with_legend else None)
            ax.hist(cell_isis, density=True, bins=bins, color=consts.COLOR_DATA, alpha=0.5,
                    label="data" if with_legend else None)
            if with_legend:
                ax.legend(loc="upper right", frameon=False)

            if j == 0:
                ax.set_ylabel(cell_type[i])
                ax.set_yticklabels([])
            if i == 0:
                ax.set_title(title[j])

    # x-axes are shared, so limiting the current axes limits all panels.
    plt.xlim(0, 17.5)
    fig.text(0.5, 0.04, 'Time [EOD periods]', ha='center', va='center')  # shared x label
    fig.text(0.06, 0.5, 'ISI Density', ha='center', va='center', rotation='vertical')  # shared y label

    fig.text(0.11, 0.9, 'A', ha='center', va='center', rotation='horizontal', size=16, family='serif')
    fig.text(0.3825, 0.9, 'B', ha='center', va='center', rotation='horizontal', size=16, family='serif')
    fig.text(0.655, 0.9, 'C', ha='center', va='center', rotation='horizontal', size=16, family='serif')

    plt.savefig(consts.SAVE_FOLDER + "dend_ref_effect.pdf", transparent=True)
    plt.close()


def create_parameter_distributions(par_values, prefix=""):
    """Plot histograms of the eight fitted model parameters in a 4x2 grid.

    Parameters
    ----------
    par_values : dict
        Maps parameter name to the sequence of fitted values. It is read only
        and no longer modified in place.
    prefix : str
        Prepended to the output file name ``parameter_distributions.pdf``
        inside ``consts.SAVE_FOLDER``.
    """
    fig, axes = plt.subplots(4, 2, gridspec_kw={"left": 0.1, "hspace": 0.5}, figsize=consts.FIG_SIZE_LARGE_HIGH)

    if len(par_values) != 8:
        print("not eight parameters")

    labels = ["input_scaling", "v_offset", "mem_tau", "noise_strength",
              "tau_a", "delta_a", "dend_tau", "refractory_period"]
    x_labels = ["[cm]", "[mV]", "[ms]", r"[mV$\sqrt{s}$]", "[ms]", "[mVms]", "[ms]", "[ms]"]
    for ax, label, x_label in zip(axes.flatten(), labels, x_labels):
        values = np.asarray(par_values[label])
        # Bins are computed on the unscaled values and scaled together with them.
        # np.asarray + out-of-place multiply: an in-place ``*=`` on a Python
        # list would repeat it instead of scaling it.
        bins = np.asarray(calculate_bins(par_values[label], 20))
        if "ms" in x_label:
            # Units containing "ms" (including "[mVms]") are shown in milli-units;
            # presumably the values are stored in seconds -- TODO confirm.
            values = values * 1000
            bins = bins * 1000
        ax.hist(values, bins=bins, color=consts.COLOR_MODEL, alpha=0.75)
        ax.set_xlabel(parameter_titles[label] + " " + x_label)
    fig.text(0.03, 0.5, 'Count', ha='center', va='center', rotation='vertical', size=12)  # shared y label
    plt.tight_layout()

    consts.set_figure_labels(xoffset=-2.5, yoffset=1.5)

    plt.savefig(consts.SAVE_FOLDER + prefix + "parameter_distributions.pdf")
    plt.close()


def behaviour_correlations_plot(fits_info):
    """Side-by-side heatmaps of the significant behaviour correlations of data and model.

    The left panel uses the cells' behaviour values and the right panel the
    model's. A strip below both panels shows the -1..1 colour scale. Saved as
    ``behaviour_correlations.pdf`` in ``consts.SAVE_FOLDER``.
    """
    fig = plt.figure(tight_layout=True, figsize=consts.FIG_SIZE_MEDIUM_WIDE)
    grid = gridspec.GridSpec(2, 2, width_ratios=(1, 1), height_ratios=(5, 0.5), hspace=0.5, wspace=0.15, left=0.2)

    # (use model values?, panel title, show y tick labels?) per column
    panels = ((False, "Data", True), (True, "Model", False))
    for column, (use_model, panel_title, show_y_labels) in enumerate(panels):
        keys, correlations, p_corrected = behaviour_correlations(fits_info, model_values=use_model)
        names = [behaviour_titles[key] for key in keys]
        panel_ax = fig.add_subplot(grid[0, column])
        create_correlation_plot(panel_ax, names, correlations, p_corrected, panel_title, y_label=show_y_labels)

    # Hand-made colour strip: 201 samples of -1..1, ticks every 0.25.
    strip = fig.add_subplot(grid[1, :])
    strip.set_xticks([0, 25, 50, 75, 100, 125, 150, 175, 200])
    strip.set_xticklabels([-1, -0.75, -0.5, -0.25, 0, 0.25, 0.5, 0.75, 1])
    strip.set_yticks([])
    strip.imshow([np.arange(-1, 1.001, 0.01)] * 10)
    strip.set_xlabel("Correlation Coefficients")

    plt.tight_layout()
    plt.savefig(consts.SAVE_FOLDER + "behaviour_correlations.pdf")
    plt.close()


def parameter_correlation_plot(fits_info):
    """Heatmap of the significant pairwise correlations between fitted model parameters.

    A colour bar is attached. Saved as ``parameter_correlations.pdf`` in
    ``consts.SAVE_FOLDER``.
    """
    names, correlations, p_corrected = parameter_correlations(fits_info)
    display_names = [parameter_titles[name] for name in names]

    fig, heatmap_ax = plt.subplots(1, 1, figsize=consts.FIG_SIZE_MEDIUM)
    image = create_correlation_plot(heatmap_ax, display_names, correlations, p_corrected, "")
    fig.colorbar(image, ax=heatmap_ax)

    plt.savefig(consts.SAVE_FOLDER + "parameter_correlations.pdf")
    plt.close()


def mask_insignificant_correlations(correlations, p_values, alpha=0.05):
    """Return a float copy of *correlations* keeping only significant lower-triangle entries.

    An entry survives when ``|p| < alpha`` and it lies on or below the
    diagonal. Everything else, including NaN p-values, becomes NaN.
    """
    correlations = np.asarray(correlations, dtype=float)
    p_values = np.asarray(p_values, dtype=float)
    rows, cols = np.indices(correlations.shape)
    keep = (np.abs(p_values) < alpha) & (cols <= rows)
    # np.nan: the ``np.NAN`` alias was removed in NumPy 2.0.
    return np.where(keep, correlations, np.nan)


def create_correlation_plot(ax, labels, correlations, p_values, title, y_label=True):
    """Draw the significant lower-triangle correlations as an annotated heatmap on *ax*.

    Parameters
    ----------
    ax : matplotlib Axes
        Target axes.
    labels : list of str
        Tick labels for rows and columns, in matrix order.
    correlations, p_values : 2-D array-like
        Square correlation matrix and the matching (corrected) p-values.
        Entries with p >= 0.05 and entries above the diagonal are left blank.
    title : str
        Axes title.
    y_label : bool
        Show the row tick labels; when False they are hidden.

    Returns
    -------
    The image returned by ``ax.imshow`` (colour range fixed to -1..1), e.g. for a colour bar.
    """
    cleaned_cors = mask_insignificant_correlations(correlations, p_values)
    im = ax.imshow(cleaned_cors, vmin=-1, vmax=1)

    ticks = np.arange(len(labels))
    ax.set_xticks(ticks)
    ax.set_xticklabels(labels)
    # remove frame:
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    if y_label:
        ax.set_yticks(ticks)
        ax.set_yticklabels(labels)
    else:
        # Hide the labels without pinning a FixedFormatter to an automatic
        # locator (which matplotlib warns about).
        ax.tick_params(axis="y", labelleft=False)
    ax.set_title(title)

    # Rotate the tick labels and set their alignment.
    plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor")

    # Annotate every remaining (significant) cell with its coefficient.
    for (i, j), value in np.ndenumerate(cleaned_cors):
        if np.isnan(value):
            continue
        # Dark text on positive, light text on negative values for contrast
        # (assumes a colour map that is bright at +1, e.g. the default viridis).
        colour = "black" if value > 0 else "white"
        ax.text(j, i, "{:.2f}".format(value), ha="center", va="center", color=colour, size=6)

    return im


def example_good_hist_fits(dir_path):
    """Figure: ISI histograms of three well-fitted example cells, model vs. data.

    ISIs are multiplied by the cell's EOD frequency, i.e. shown in EOD periods.
    Saved as ``example_good_isi_hist_fits.pdf`` in ``consts.SAVE_FOLDER``.
    """
    example_cells = (
        "2012-12-21-am-invivo-1",  # non-bursty
        "2014-03-19-ad-invivo-1",  # bursty
        "2018-05-08-ac-invivo-1",  # strongly bursty
    )

    fig, axes = plt.subplots(1, 3, sharex="all", figsize=(8, 4))

    for ax, cell_name in zip(axes, example_cells):
        best_fit = get_best_fit(dir_path + cell_name + "/")
        data = best_fit.get_cell_data()
        eod_freq = data.get_eod_frequency()
        simulated = BaselineModel(best_fit.get_model(), eod_freq, trials=5)

        isi_model = np.array(simulated.get_interspike_intervals()) * eod_freq
        isi_cell = BaselineCellData(data).get_interspike_intervals() * eod_freq
        # Presumably 0-25 ms in 0.1 ms steps (ISIs in seconds), converted to EOD periods.
        edges = np.arange(0, 0.025, 0.0001) * eod_freq

        ax.hist(isi_model, bins=edges, density=True, alpha=0.75, color=consts.COLOR_MODEL)
        ax.hist(isi_cell, bins=edges, density=True, alpha=0.5, color=consts.COLOR_DATA)
        ax.set_xlabel("ISI in EOD periods")

    axes[0].set_ylabel("Density")
    plt.tight_layout()
    consts.set_figure_labels(xoffset=-2.5)
    # NOTE(review): ``label_axes`` is not a stock matplotlib Figure method;
    # presumably provided via consts.set_figure_labels -- confirm.
    fig.label_axes()

    plt.savefig(consts.SAVE_FOLDER + "example_good_isi_hist_fits.pdf", transparent=True)
    plt.close()


def example_bad_hist_fits(dir_path):
    """Figure: ISI histograms of three poorly fitted example cells, model vs. data.

    ISIs are shown in EOD periods, and only the first panel carries a legend.
    A few baseline statistics are printed per cell. Saved as
    ``example_bad_isi_hist_fits.pdf`` in ``consts.SAVE_FOLDER``.
    """
    example_cells = (
        "2014-06-06-ag-invivo-1",  # bursty
        "2018-05-08-ab-invivo-1",  # strongly bursty
        "2014-12-11-ad-invivo-1",  # extra structure
    )

    fig, axes = plt.subplots(1, 3, sharex="all", figsize=consts.FIG_SIZE_SMALL_EXTRA_WIDE)

    for panel, (ax, cell_name) in enumerate(zip(axes, example_cells)):
        best_fit = get_best_fit(dir_path + cell_name + "/")
        data = best_fit.get_cell_data()
        eod_freq = data.get_eod_frequency()
        simulated = BaselineModel(best_fit.get_model(), eod_freq, trials=5)
        recorded = BaselineCellData(data)

        print(cell_name)
        print("EODf:", eod_freq)
        print("base rate:", recorded.get_baseline_frequency())
        print("bursty:", recorded.get_burstiness())
        print()

        isi_model = np.array(simulated.get_interspike_intervals()) * eod_freq
        isi_cell = recorded.get_interspike_intervals() * eod_freq
        edges = np.arange(0, 0.025, 0.0001) * eod_freq

        is_first = panel == 0
        ax.hist(isi_model, bins=edges, density=True, alpha=0.75, color=consts.COLOR_MODEL,
                label="model" if is_first else None)
        ax.hist(isi_cell, bins=edges, density=True, alpha=0.5, color=consts.COLOR_DATA,
                label="data" if is_first else None)
        if is_first:
            ax.legend(loc="upper right", frameon=False)

        ax.set_xlabel("ISI [EOD periods]")

    axes[0].set_ylabel("Density")
    plt.tight_layout()
    consts.set_figure_labels(xoffset=-2.5, yoffset=1.25)
    fig.label_axes()

    plt.savefig(consts.SAVE_FOLDER + "example_bad_isi_hist_fits.pdf", transparent=True)
    plt.close()


def example_good_fi_fits(dir_path):
    """Figure: f-I curves of three well-fitted example cells, data vs. model.

    Each panel shows, against contrast:
    - the data's f_0 responses with their Boltzmann fit and the model's f_0;
    - the data's f_inf responses with their clipped-line fit and the model's f_inf.

    A few baseline statistics are printed per cell. Saved as
    ``example_good_fi_fits.pdf`` in ``consts.SAVE_FOLDER``.
    """
    example_cells = ("2012-12-21-am-invivo-1", "2014-03-19-ae-invivo-1", "2014-03-25-aa-invivo-1")

    fig, axes = plt.subplots(1, 3, figsize=consts.FIG_SIZE_SMALL_EXTRA_WIDE, sharey="all")
    for ax, cell_name in zip(axes, example_cells):
        best_fit = get_best_fit(dir_path + cell_name + "/")
        data = best_fit.get_cell_data()
        eod_freq = data.get_eod_frequency()
        recorded = BaselineCellData(data)

        print(cell_name)
        print("EODf:", eod_freq)
        print("base rate:", recorded.get_baseline_frequency())
        print("bursty:", recorded.get_burstiness())
        print()

        model = best_fit.get_model()
        fi_data = FICurveCellData(data, data.get_fi_contrasts(), save_dir=data.get_data_path())
        stimuli = fi_data.stimulus_values
        fine_x = np.arange(min(stimuli), max(stimuli), 0.001)
        fi_model = FICurveModel(model, stimuli, eod_freq, trials=10)
        f0_params = fi_data.f_zero_fit
        finf_params = fi_data.f_inf_fit

        # onset (f_0) responses
        ax.plot(stimuli, fi_data.get_f_zero_frequencies(), ',',
                marker=consts.f0_marker, alpha=0.75, color=consts.COLOR_DATA_f0, label=r"data $f_0$")
        ax.plot(fine_x, fu.full_boltzmann(fine_x, *f0_params[:4]),
                color=consts.COLOR_DATA_f0, alpha=0.75)
        ax.plot(stimuli, fi_model.get_f_zero_frequencies(), ',',
                marker=consts.f0_marker, alpha=0.75, color=consts.COLOR_MODEL_f0, label=r"model $f_0$")

        # steady-state (f_inf) responses
        ax.plot(stimuli, fi_data.get_f_inf_frequencies(), ',',
                marker=consts.finf_marker, alpha=0.5, color=consts.COLOR_DATA_finf, label=r"data $f_{\infty}$")
        ax.plot(fine_x, fu.clipped_line(fine_x, *finf_params[:2]),
                color=consts.COLOR_DATA_finf, alpha=0.5)
        ax.plot(stimuli, fi_model.get_f_inf_frequencies(), ',',
                marker=consts.finf_marker, alpha=0.75, color=consts.COLOR_MODEL_finf, label=r"model $f_{\infty}$")

        ax.set_xlabel("Contrast")
        ax.set_xlim((-0.22, 0.22))

    axes[0].legend(loc="upper left", frameon=False)
    axes[0].set_ylabel("Frequency [Hz]")
    plt.tight_layout()
    consts.set_figure_labels(xoffset=-2.5)
    fig.label_axes()

    plt.savefig(consts.SAVE_FOLDER + "example_good_fi_fits.pdf", transparent=True)
    plt.close()


def example_bad_fi_fits(dir_path):
    """Figure: f-I curves of two poorly fitted example cells, data vs. model.

    The layout matches ``example_good_fi_fits``, but with two panels that do
    not share the y-axis. A few baseline statistics are printed per cell.
    Saved as ``example_bad_fi_fits.pdf`` in ``consts.SAVE_FOLDER``.
    """
    #  "2013-01-08-aa-invivo-1" candidate cell
    example_cells = ("2012-12-13-ao-invivo-1", "2014-01-23-ab-invivo-1")

    fig, axes = plt.subplots(1, 2, figsize=consts.FIG_SIZE_SMALL_EXTRA_WIDE)
    for ax, cell_name in zip(axes, example_cells):
        best_fit = get_best_fit(dir_path + cell_name + "/")
        data = best_fit.get_cell_data()
        eod_freq = data.get_eod_frequency()
        recorded = BaselineCellData(data)

        print(cell_name)
        print("EODf:", eod_freq)
        print("base rate:", recorded.get_baseline_frequency())
        print("bursty:", recorded.get_burstiness())
        print()

        model = best_fit.get_model()
        fi_data = FICurveCellData(data, data.get_fi_contrasts(), save_dir=data.get_data_path())
        stimuli = fi_data.stimulus_values
        fine_x = np.arange(min(stimuli), max(stimuli), 0.001)
        fi_model = FICurveModel(model, stimuli, eod_freq, trials=10)
        f0_params = fi_data.f_zero_fit
        finf_params = fi_data.f_inf_fit

        # onset (f_0) responses
        ax.plot(stimuli, fi_data.get_f_zero_frequencies(), ',',
                marker=consts.f0_marker, alpha=0.75, color=consts.COLOR_DATA_f0, label=r"data $f_0$")
        ax.plot(fine_x, fu.full_boltzmann(fine_x, *f0_params[:4]),
                color=consts.COLOR_DATA_f0, alpha=0.75)
        ax.plot(stimuli, fi_model.get_f_zero_frequencies(), ',',
                marker=consts.f0_marker, alpha=0.75, color=consts.COLOR_MODEL_f0, label=r"model $f_0$")

        # steady-state (f_inf) responses
        ax.plot(stimuli, fi_data.get_f_inf_frequencies(), ',',
                marker=consts.finf_marker, alpha=0.5, color=consts.COLOR_DATA_finf, label=r"data $f_{\infty}$")
        ax.plot(fine_x, fu.clipped_line(fine_x, *finf_params[:2]),
                color=consts.COLOR_DATA_finf, alpha=0.5)
        ax.plot(stimuli, fi_model.get_f_inf_frequencies(), ',',
                marker=consts.finf_marker, alpha=0.75, color=consts.COLOR_MODEL_finf, label=r"model $f_{\infty}$")

        ax.set_xlabel("Contrast")
        ax.set_xlim((-0.22, 0.2))

    axes[0].set_ylabel("Frequency [Hz]")
    axes[0].legend(loc="upper left", frameon=False)
    plt.tight_layout()
    consts.set_figure_labels(xoffset=-2.5)
    fig.label_axes()

    plt.savefig(consts.SAVE_FOLDER + "example_bad_fi_fits.pdf", transparent=True)
    plt.close()


def create_boxplots(errors):
    """Print the median percent error per key and show a box plot of all error samples.

    *errors* maps a name to a sequence of percent errors. Keys are shown in
    sorted order, labelled with their sample count.
    """
    ordered_keys = sorted(errors)
    tick_labels = ["{}_n:{}".format(key, len(errors[key])) for key in ordered_keys]
    for key in ordered_keys:
        print("{}: median %-error: {:.2f}".format(key, np.median(errors[key])))
    samples = [errors[key] for key in ordered_keys]

    plt.boxplot(samples)
    # boxplot places the boxes at positions 1..n
    plt.xticks(np.arange(1, len(samples) + 1, 1), tick_labels, rotation=45)
    plt.tight_layout()
    plt.show()
    plt.close()


def plot_cell_model_comp_baseline(cell_behavior, model_behaviour):
    """Figure: cell vs. model scatter plots with marginal histograms for baseline measures.

    There are three columns: base rate, vector strength and serial
    correlation. Each column uses 20 equal-width bins spanning the combined
    cell and model range. Saved as ``fit_baseline_comparison.pdf`` in
    ``consts.SAVE_FOLDER``.
    """
    def shared_bins(cell_values, model_values, count=20):
        # equal-width bin edges covering both samples
        low = min(min(cell_values), min(model_values))
        high = max(max(cell_values), max(model_values))
        width = (high - low) / count
        return np.arange(low, high + width, width)

    fig = plt.figure(figsize=(8, 4))
    grid = fig.add_gridspec(2, 3, width_ratios=[5, 5, 5], height_ratios=[3, 7],
                            left=0.1, right=0.95, bottom=0.1, top=0.9,
                            wspace=0.4, hspace=0.2)
    # Burstiness is looked up for optional colour coding in scatter_hist,
    # which is currently disabled for this figure.
    bursting = cell_behavior["Burstiness"]

    # (behaviour key, scatter x label, scatter y label) per column
    panels = (
        ("baseline_frequency", r"Cell [Hz]", r"Model [Hz]"),
        ("vector_strength", r"Cell", r"Model"),
        ("serial_correlation", r"Cell", r"Model"),
    )
    for column, (key, x_label, y_label) in enumerate(panels):
        cells = cell_behavior[key]
        models = model_behaviour[key]
        bins = shared_bins(cells, models)

        scatter_ax = fig.add_subplot(grid[1, column])
        hist_ax = fig.add_subplot(grid[0, column], sharex=scatter_ax)
        if key == "vector_strength":
            print("Cells in cell_model_comp_baseline:", len(cells))
        scatter_hist(cells, models, scatter_ax, hist_ax, behaviour_titles[key], bins)
        scatter_ax.set_xlabel(x_label)
        scatter_ax.set_ylabel(y_label)
        hist_ax.set_ylabel("Count")

    # panel letters
    for x_position, letter in ((0.09, 'A'), (0.375, 'B'), (0.6625, 'C')):
        fig.text(x_position, 0.925, letter, ha='center', va='center', rotation='horizontal', size=16, family='serif')

    plt.tight_layout()
    plt.savefig(consts.SAVE_FOLDER + "fit_baseline_comparison.pdf", transparent=True)
    plt.close()


def plot_cell_model_comp_burstiness(cell_behavior, model_behaviour):
    """Figure: cell vs. model scatter plots with marginal histograms for burstiness and CV.

    Both panels pass the cells' burstiness and the 'jet' colour map to
    ``scatter_hist``, presumably to colour the points. Each panel uses 20
    equal-width bins spanning the combined cell and model range. Saved as
    ``fit_burstiness_comparison.pdf`` in ``consts.SAVE_FOLDER``.
    """
    def shared_bins(cell_values, model_values, count=20):
        # equal-width bin edges covering both samples
        low = min(min(cell_values), min(model_values))
        high = max(max(cell_values), max(model_values))
        width = (high - low) / count
        return np.arange(low, high + width, width)

    fig = plt.figure(figsize=consts.FIG_SIZE_MEDIUM_WIDE)
    # Two columns; top row: marginal histograms, bottom row: scatter plots (3:7 height ratio).
    grid = fig.add_gridspec(2, 2, width_ratios=[5, 5], height_ratios=[3, 7],
                            left=0.1, right=0.9, bottom=0.1, top=0.9,
                            wspace=0.3, hspace=0.2)
    colour_map = 'jet'
    bursting = cell_behavior["Burstiness"]

    # Panel A: burstiness (labels are set before scatter_hist, as in the original layout)
    cells = bursting
    models = model_behaviour["Burstiness"]
    bins = shared_bins(cells, models)
    scatter_ax = fig.add_subplot(grid[1, 0])
    scatter_ax.set_xlabel("Cell [%ms]")
    scatter_ax.set_ylabel("Model [%ms]")
    hist_ax = fig.add_subplot(grid[0, 0], sharex=scatter_ax)
    hist_ax.set_ylabel("Count")
    scatter_hist(cells, models, scatter_ax, hist_ax, behaviour_titles["Burstiness"], bins, colour_map, bursting)

    # Panel B: coefficient of variation
    cells = cell_behavior["coefficient_of_variation"]
    models = model_behaviour["coefficient_of_variation"]
    bins = shared_bins(cells, models)
    scatter_ax = fig.add_subplot(grid[1, 1])
    hist_ax = fig.add_subplot(grid[0, 1], sharex=scatter_ax)
    scatter_hist(cells, models, scatter_ax, hist_ax, behaviour_titles["coefficient_of_variation"], bins,
                 colour_map, bursting)
    scatter_ax.set_xlabel("Cell")
    scatter_ax.set_ylabel("Model")
    hist_ax.set_ylabel("Count")

    plt.tight_layout()

    for x_position, letter in ((0.085, 'A'), (0.53, 'B')):
        fig.text(x_position, 0.925, letter, ha='center', va='center', rotation='horizontal', size=16, family='serif')

    plt.savefig(consts.SAVE_FOLDER + "fit_burstiness_comparison.pdf", transparent=True)
    plt.close()


def plot_cell_model_comp_adaption(cell_behavior, model_behaviour):
    """Plot cell-vs-model comparisons of the adaption measures and save the figure.

    Three columns are drawn. Each has a histogram on top (shared x axis) and a
    cell-vs-model scatter plot below:

    0. the f_inf slope,
    1. the f_0 slope, keeping only pairs where both values are below 25000,
    2. the f_0 / f_inf slope ratio, keeping only finite pairs where both
       ratios are below 60.

    The figure is written to
    ``consts.SAVE_FOLDER + "fit_adaption_comparison_with_ratio.pdf"``.

    Parameters
    ----------
    cell_behavior : mapping
        Behaviour values of the recorded cells. Must contain the keys
        "f_inf_slope" and "f_zero_slope".
    model_behaviour : mapping
        Behaviour values of the fitted models, with the same keys. Entries
        are assumed to be index-aligned with ``cell_behavior`` (one pair per
        cell); TODO confirm against get_behaviour_values.

    Side effects
    ------------
    While plotting, ``axes.formatter.limits`` is set to (-5, 3), so tick
    labels switch to scientific notation for values >= 1e3. Afterwards it is
    always set to (-5, 6), even if plotting fails. The figure is always closed.
    """
    fig = plt.figure(figsize=(8, 4))
    # Two rows (histogram above scatter, height ratio 3:7) and three equally
    # wide columns, one per compared quantity.
    gs = fig.add_gridspec(2, 3, width_ratios=[5, 5, 5], height_ratios=[3, 7],
                          left=0.1, right=0.95, bottom=0.1, top=0.9,
                          wspace=0.4, hspace=0.3)
    mpl.rc("axes.formatter", limits=(-5, 3))
    num_of_bins = 20

    try:
        # Column 0: f_inf slope, all values.
        cell = cell_behavior["f_inf_slope"]
        model = model_behaviour["f_inf_slope"]
        bins = _paired_equal_width_bins(cell, model, num_of_bins)

        ax = fig.add_subplot(gs[1, 0])
        ax_histx = fig.add_subplot(gs[0, 0], sharex=ax)
        scatter_hist(cell, model, ax, ax_histx, behaviour_titles["f_inf_slope"], bins)
        ax.set_xlabel(r"Cell [Hz]")
        ax.set_ylabel(r"Model [Hz]")
        ax_histx.set_ylabel("Count")

        # Column 1: f_0 slope. Drop pairs where either value is >= 25000 so
        # that a few extreme values do not dominate the axis range.
        cell = cell_behavior["f_zero_slope"]
        model = model_behaviour["f_zero_slope"]
        length_before = len(cell)
        cell, model = _filter_pairs_below(cell, model, 25000)
        print("removed {} values from f_zero_slope plot.".format(length_before - len(cell)))
        bins = _paired_equal_width_bins(cell, model, num_of_bins)

        ax = fig.add_subplot(gs[1, 1])
        ax_histx = fig.add_subplot(gs[0, 1], sharex=ax)
        scatter_hist(cell, model, ax, ax_histx, behaviour_titles["f_zero_slope"], bins)
        ax.set_xlabel("Cell [Hz]")
        ax.set_ylabel("Model [Hz]")
        ax_histx.set_ylabel("Count")

        # Column 2: f_0 / f_inf slope ratio. Deliberately computed from the
        # unfiltered slopes read from the dicts again, not from the 25000-filtered arrays.
        # A zero f_inf slope gives +/-inf or NaN here rather than raising.
        # Non-finite pairs are dropped before the "< 60" cut-off, which would
        # otherwise let -inf through.
        with np.errstate(divide="ignore", invalid="ignore"):
            cell_ratio = (np.asarray(cell_behavior["f_zero_slope"], dtype=float)
                          / np.asarray(cell_behavior["f_inf_slope"], dtype=float))
            model_ratio = (np.asarray(model_behaviour["f_zero_slope"], dtype=float)
                           / np.asarray(model_behaviour["f_inf_slope"], dtype=float))
        finite = np.isfinite(cell_ratio) & np.isfinite(model_ratio)
        cell_ratio, model_ratio = _filter_pairs_below(cell_ratio[finite], model_ratio[finite], 60)

        bins = calculate_bins(np.concatenate((cell_ratio, model_ratio)), num_of_bins)

        ax = fig.add_subplot(gs[1, 2])
        ax_histx = fig.add_subplot(gs[0, 2], sharex=ax)
        scatter_hist(cell_ratio, model_ratio, ax, ax_histx, r"$f_0$ / $f_{\infty}$ slope ratio", bins)
        ax.set_xlabel("Cell")
        ax.set_ylabel("Model")
        ax_histx.set_ylabel("Count")

        plt.tight_layout()
        plt.savefig(consts.SAVE_FOLDER + "fit_adaption_comparison_with_ratio.pdf", transparent=True)
    finally:
        # Always release the figure and reset the global formatter limits,
        # even if a step above raised, so later plots are not affected.
        plt.close(fig)
        mpl.rc("axes.formatter", limits=(-5, 6))


def _filter_pairs_below(cell_values, model_values, limit):
    """Return numpy arrays of the index-aligned pairs whose cell AND model values are < ``limit``.

    NaN values compare False and are therefore dropped as well.
    """
    cell_values = np.asarray(cell_values)
    model_values = np.asarray(model_values)
    keep = (cell_values < limit) & (model_values < limit)
    return cell_values[keep], model_values[keep]


def _paired_equal_width_bins(cell_values, model_values, num_of_bins):
    """Return histogram bin edges covering both value sets.

    The edges start at the joint minimum and advance in steps of
    (joint range / ``num_of_bins``) up to about the joint maximum, using
    ``np.arange``. Floating-point rounding may add one extra edge.
    """
    minimum = min(min(cell_values), min(model_values))
    maximum = max(max(cell_values), max(model_values))
    step = (maximum - minimum) / num_of_bins
    return np.arange(minimum, maximum + step, step)


def scatter_hist(cell_values, model_values, ax, ax_histx, behaviour, bins, cmap=None, color_values=None):
    """Draw a cell-vs-model scatter plot plus overlaid histograms of both value sets.

    Adapted from the matplotlib scatter/histogram example.

    ``ax`` receives:
    - a grey identity line spanning the joint value range, and
    - the scatter of ``cell_values`` (x) against ``model_values`` (y).

    Points are black, unless ``cmap`` is given; then they are coloured by
    ``color_values`` through ``cmap``.

    ``ax_histx`` receives:
    - the model histogram (``consts.COLOR_MODEL``, alpha 0.75),
    - the cell histogram on top of it (``consts.COLOR_DATA``, alpha 0.5),
      both using ``bins``, and
    - ``behaviour`` as its title.
    """
    lowest = min(min(cell_values), min(model_values))
    highest = max(max(cell_values), max(model_values))
    diagonal = (lowest, highest)
    ax.plot(diagonal, diagonal, color="grey")

    if cmap is not None:
        point_style = {"c": color_values, "cmap": cmap}
    else:
        point_style = {"color": "black"}
    ax.scatter(cell_values, model_values, **point_style)

    # Model first so the semi-transparent cell histogram is drawn on top of it.
    layers = (
        (model_values, consts.COLOR_MODEL, 0.75),
        (cell_values, consts.COLOR_DATA, 0.50),
    )
    for values, colour, opacity in layers:
        ax_histx.hist(values, bins=bins, color=colour, alpha=opacity)

    ax_histx.set_title(behaviour)


def calculate_bins(values, num_of_bins):
    """Return ``num_of_bins + 1`` bin edges whose bins are centred on an even grid.

    The step is ``(max - min) / (num_of_bins - 1)``. The first bin is centred
    on ``min(values)`` and the last on ``max(values)``, so the edges run from
    ``min - step / 2`` to ``max + step / 2``.

    If all values are equal, no spread-based step exists. A single unit-width
    bin centred on that value is returned instead; the original code raised
    ZeroDivisionError in this case.

    Raises
    ------
    ValueError
        If ``num_of_bins`` is smaller than 2 (the step would need a division
        by zero), or if ``values`` is empty (raised by ``np.min``).
    """
    if num_of_bins < 2:
        raise ValueError("num_of_bins must be at least 2, got {}".format(num_of_bins))

    minimum = np.min(values)
    maximum = np.max(values)
    if maximum == minimum:
        return np.array([minimum - 0.5, minimum + 0.5], dtype=float)

    step = (maximum - minimum) / (num_of_bins - 1)
    # linspace guarantees exactly num_of_bins + 1 edges; arange's floating-point
    # stop criterion does not.
    return np.linspace(minimum - 0.5 * step, maximum + 0.5 * step, num_of_bins + 1)


if __name__ == "__main__":
    # Only run the analysis when this file is executed as a script, not on import.
    main()