import numpy as np
import os
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from scipy.optimize import curve_fit
from scipy.stats import multivariate_normal, pearsonr

from analysis import get_parameter_values, get_filtered_fit_info, parameter_correlations, get_behaviour_values
from fitting.ModelFit import get_best_fit
from my_util import functions as fu
from experiments.Baseline import BaselineModel
from experiments.FiCurve import FICurveModel
from models.LIFACnoise import LifacNoiseModel
from Figures_results import create_correlation_plot
import Figure_constants as consts

LOG_TRANSFORM = {"v_offset": False, 'input_scaling': True, 'dend_tau': True, 'tau_a': True, 'delta_a': True,
                 'refractory_period': False, 'noise_strength': True, 'mem_tau': True}

behaviour_titles = {"baseline_frequency": "Base Rate", "Burstiness": "Burst", "coefficient_of_variation": "CV",
                    "serial_correlation": "SC", "vector_strength": "VS",
                    "f_inf_slope": r"$f_{\infty}$ Slope", "f_zero_slope": r"$f_0$ Slope"}

parameter_titles = {"input_scaling": r"$\alpha$", "delta_a": r"$\Delta_A$",
                    "mem_tau": r"$\tau_m$", "noise_strength": r"$\sqrt{2D}$",
                    "refractory_period": "$t_{ref}$", "tau_a": r"$\tau_A$",
                    "v_offset": r"$I_{Bias}$", "dend_tau": r"$\tau_{dend}$"}


recalculate = False
num_of_models = 100


def main():
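    """Regenerate all figures from previously fitted and saved data.

    The code after the early return is the full pipeline: it draws random models
    from the fitted parameter distributions and, if `recalculate` is set,
    re-simulates their behaviour before plotting.
    """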

    rerun_all_images()
    return  # only regenerate the figures; skip the full pipeline below

    folder = "results/final_2/"
    fit_infos = get_filtered_fit_info(folder, filter=True)
    goal_eodf = 800
    param_values = get_parameter_values(fit_infos, scaled=True, goal_eodf=goal_eodf)

    # plots 1
    keys, means, cov_matrix = calculate_means_and_covariances(param_values)
    par_list = draw_random_models(1000, keys, means, cov_matrix, seed=1)
    parameter_correlation_plot(par_list, fit_infos)
    plot_distributions_with_set_fits(param_values)

    if recalculate:
        keys, means, cov_matrix = calculate_means_and_covariances(param_values)
        par_list = draw_random_models(num_of_models, keys, means, cov_matrix)

        behaviour = model_behaviour_distributions(par_list, eodf=goal_eodf)

        save_behaviour(behaviour, par_list)
    else:
        behaviour, par_list = load_behaviour()
    create_behaviour_distributions(behaviour, fit_infos)
    compare_distribution_random_vs_fitted_params(par_list, param_values)


def rerun_all_images():
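    """Recreate all figures without re-simulating: parameter correlations, parameter
    distributions with the set Gaussian fits, behaviour distributions (loaded from
    disk) and the drawn vs. fitted parameter comparison."""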

    folder = "results/final_2/"
    fit_infos = get_filtered_fit_info(folder, filter=True)
    goal_eodf = 800
    param_values = get_parameter_values(fit_infos, scaled=True, goal_eodf=goal_eodf)

    keys, means, cov_matrix = calculate_means_and_covariances(param_values)
    par_list = draw_random_models(1000, keys, means, cov_matrix, seed=1)
    parameter_correlation_plot(par_list, fit_infos)
    plot_distributions_with_set_fits(param_values)

    behaviour, par_list = load_behaviour()
    create_behaviour_distributions(behaviour, fit_infos)
    compare_distribution_random_vs_fitted_params(par_list, param_values)


def compare_distribution_random_vs_fitted_params(par_list, scaled_param_values):
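    """Histogram each model parameter of the randomly drawn models (black) on top of
    the fitted models (model colour) and save the comparison figure."""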
    labels = ["input_scaling", "v_offset", "mem_tau", "noise_strength",
              "tau_a", "delta_a", "dend_tau", "refractory_period"]
    x_labels = ["[cm]", "[mV]", "[ms]", r"[mV$\sqrt{s}$]", "[ms]", "[mVms]", "[ms]", "[ms]"]
    model_parameter_values = {}
    for l in labels:
        model_parameter_values[l] = []

    for params in par_list:
        for l in labels:
            model_parameter_values[l].append(params[l])

    fig, axes = plt.subplots(4, 2, gridspec_kw={"left": 0.1, "hspace": 0.5}, figsize=consts.FIG_SIZE_LARGE_HIGH)
    axes_flat = axes.flatten()
    for i, l in enumerate(labels):
        rand_model_values = model_parameter_values[l]
        fitted_model_values = scaled_param_values[l]

        if "ms" in x_labels[i]:
            rand_model_values = np.array(rand_model_values) * 1000
            fitted_model_values = np.array(fitted_model_values) * 1000

        min_v = min(min(rand_model_values), min(fitted_model_values)) * 0.95
        max_v = max(max(rand_model_values), max(fitted_model_values)) * 1.05
        # limit = 50000
        # if max_v > limit:
        #     print("For {} the max value was limited to {},  {} values were excluded!".format(l, limit, np.sum(np.array(cell_b_values[l]) > limit)))
        #     max_v = limit
        values = list(rand_model_values)
        values.extend(fitted_model_values)
        bins = calculate_bins(values, 30)
        axes_flat[i].hist(fitted_model_values, bins=bins, color=consts.COLOR_MODEL, alpha=0.75, density=True)
        axes_flat[i].hist(rand_model_values, bins=bins, color="black", alpha=0.5, density=True)
        axes_flat[i].set_xlabel(parameter_titles[l] + " " + x_labels[i])
        axes_flat[i].set_yticks([])
        axes_flat[i].set_yticklabels([])

    fig.text(0.03, 0.5, 'Density', ha='center', va='center', rotation='vertical', size=12)  # shared y label
    plt.tight_layout()

    consts.set_figure_labels(xoffset=-2.5, yoffset=0)
    fig.label_axes()

    plt.savefig(consts.SAVE_FOLDER + "compare_parameter_dist_random_models.pdf")
    plt.close()


def parameter_correlation_plot(par_list, fits_info):
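    """Plot the parameter correlation matrices of the fitted models (left) and of the
    randomly drawn models (right) above a shared colour scale."""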

    fig = plt.figure(tight_layout=True, figsize=consts.FIG_SIZE_MEDIUM_WIDE)
    gs = gridspec.GridSpec(2, 2, width_ratios=(1, 1), height_ratios=(5, 1), hspace=0.3, wspace=0.05)
    # fig, axes = plt.subplots(1, 2, figsize=consts.FIG_SIZE_MEDIUM_WIDE)

    labels, corr_values, corrected_p_values = parameter_correlations(fits_info)
    par_labels = [parameter_titles[l] for l in labels]
    img = create_correlation_plot(fig.add_subplot(gs[0, 0]), par_labels, corr_values, corrected_p_values,
                                  "Fitted Models", y_label=True)

    rand_labels, rand_corr_values, rand_corrected_p_values = parameter_correlations_from_par_list(par_list)
    par_labels = [parameter_titles[l] for l in rand_labels]
    # Inflate the p-values so that no correlation is marked as significant in the drawn-models panel.
    img = create_correlation_plot(fig.add_subplot(gs[0, 1]), par_labels, rand_corr_values,
                                  rand_corrected_p_values * 10e50, "Drawn Models", y_label=False)

    consts.set_figure_labels(xoffset=-2.5, yoffset=1.5)
    fig.label_axes()

    ax_col = fig.add_subplot(gs[1, :])
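    # Pseudo-colorbar: an image of values running from -1 to 1 in 201 steps, with the
    # tick labels mapping pixel columns back to correlation coefficients.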
    data = [np.arange(-1, 1.001, 0.01)] * 10
    ax_col.set_xticks([0, 25, 50, 75, 100, 125, 150, 175, 200])
    ax_col.set_xticklabels([-1, -0.75, -0.5, -0.25, 0, 0.25, 0.5, 0.75, 1])
    ax_col.set_yticks([])
    ax_col.imshow(data)
    ax_col.set_xlabel("Correlation Coefficients")

    plt.savefig(consts.SAVE_FOLDER + "rand_parameter_correlations_comparison.pdf")
    plt.close()


def parameter_correlations_from_par_list(par_list):
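    """Compute pairwise Pearson correlations between the parameters of the given
    parameter sets; returns labels, correlation matrix and Bonferroni-corrected p-values."""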
    labels = ["input_scaling", "v_offset", "mem_tau", "noise_strength",
              "tau_a", "delta_a", "dend_tau", "refractory_period"]

    parameter_values = {}
    for l in labels:
        parameter_values[l] = []

    for params in par_list:
        for l in labels:
            parameter_values[l].append(params[l])

    corr_values = np.zeros((len(labels), len(labels)))
    p_values = np.ones((len(labels), len(labels)))

    for i in range(len(labels)):
        for j in range(len(labels)):
            c, p = pearsonr(parameter_values[labels[i]], parameter_values[labels[j]])
            corr_values[i, j] = c
            p_values[i, j] = p

    # Bonferroni correction: scale by the number of unique parameter pairs,
    # sum(range(n)) == n * (n - 1) / 2
    corrected_p_values = p_values * sum(range(len(labels)))

    return labels, corr_values, corrected_p_values


def model_behaviour_distributions(par_list, eodf=800):
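    """Simulate every drawn parameter set (baseline and FI-curve protocol) and collect
    the behaviour measures listed in `behaviour_titles`."""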
    behaviour = {}

    for key in behaviour_titles.keys():
        behaviour[key] = []

    for i, parset in enumerate(par_list):

        model = LifacNoiseModel(parset)
        baseline = BaselineModel(model, eodf)

        behaviour["baseline_frequency"].append(baseline.get_baseline_frequency())
        behaviour["Burstiness"].append(baseline.get_burstiness())
        behaviour["coefficient_of_variation"].append(baseline.get_coefficient_of_variation())
        behaviour["serial_correlation"].append(baseline.get_serial_correlation(1)[0])
        behaviour["vector_strength"].append(baseline.get_vector_strength())

        fi_curve = FICurveModel(model, np.arange(-0.3, 0.301, 0.1), eodf)
        behaviour["f_inf_slope"].append(fi_curve.f_inf_fit[0])
        behaviour["f_zero_slope"].append(fi_curve.get_f_zero_fit_slope_at_straight())

        print("{:} of {:}".format(i + 1, len(par_list)))

    return behaviour


def save_behaviour(behaviour, par_list):
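    """Save the behaviour measures and the corresponding parameter sets as .npy files
    (one data array plus one key array each)."""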
    # save behaviour:
    keys = np.array(sorted(behaviour.keys()))
    data_points = len(behaviour[keys[0]])
    data = np.zeros((len(keys), data_points))
    for i, k in enumerate(keys):
        k_data = np.array(behaviour[k])
        data[i, :] = k_data

    np.save("data/random_model_behaviour_data.npy", data)
    np.save("data/random_model_behaviour_keys.npy", keys)

    # save parameter list:

    par_keys = np.array(sorted(par_list[0].keys()))
    num_models = len(par_list)

    pars_data = np.zeros((num_models, len(par_keys)))

    for i, params in enumerate(par_list):
        params_array = np.array([params[k] for k in par_keys])
        pars_data[i, :] = params_array

    np.save("data/random_model_parameter_data.npy", pars_data)
    np.save("data/random_model_parameter_keys.npy", par_keys)


def load_behaviour():
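    """Load the behaviour measures and parameter sets written by `save_behaviour`."""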
    data = np.load("data/random_model_behaviour_data.npy")
    keys = np.load("data/random_model_behaviour_keys.npy")
    behaviour = {}
    for i, k in enumerate(keys):
        behaviour[k] = data[i, :]

    pars_data = np.load("data/random_model_parameter_data.npy")
    par_keys = np.load("data/random_model_parameter_keys.npy")
    par_list = []

    for i in range(len(pars_data[:, 0])):
        param_dict = {}
        for j, k in enumerate(par_keys):
            param_dict[k] = pars_data[i, j]
        par_list.append(param_dict)

    return behaviour, par_list


def create_behaviour_distributions(drawn_model_behaviour, fits_info):
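    """Histogram the behaviour of the drawn models against the behaviour of the cells
    and save the figure."""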
    fig, axes = plt.subplots(4, 2, gridspec_kw={"left": 0.1, "hspace": 0.5}, figsize=consts.FIG_SIZE_LARGE_HIGH)
    cell_behaviour, fitted_model_behaviour = get_behaviour_values(fits_info)
    labels = ['baseline_frequency', 'serial_correlation', 'vector_strength', 'Burstiness', 'coefficient_of_variation', 'f_inf_slope', 'f_zero_slope']
    unit = ["[Hz]", "", "", "[%ms]", "", "[Hz]", "[Hz]"]

    axes_flat = axes.flatten()
    for i, l in enumerate(labels):
        # bins span only the range of the drawn models; cell values outside it are not shown
        bins = calculate_bins(drawn_model_behaviour[l], 20)
        axes_flat[i].hist(drawn_model_behaviour[l], bins=bins, density=True, color=consts.COLOR_MODEL, alpha=0.75)
        axes_flat[i].hist(cell_behaviour[l], bins=bins, density=True, color=consts.COLOR_DATA, alpha=0.5)
        axes_flat[i].set_xlabel(behaviour_titles[l] + " " + unit[i])
        axes_flat[i].set_yticks([])
        axes_flat[i].set_yticklabels([])
    axes_flat[-1].set_visible(False)

    plt.tight_layout()

    consts.set_figure_labels(xoffset=-2.5, yoffset=0)
    fig.label_axes()
    fig.text(0.03, 0.5, 'Density', ha='center', va='center', rotation='vertical', size=12)  # shared y label

    plt.savefig(consts.SAVE_FOLDER + "random_models_behaviour_dist.pdf")
    plt.close()


def test_plot_models(par_list, eodf):
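    """Debug helper: show the ISI histogram and FI curve of every parameter set."""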

    for pars in par_list:
        baseline = BaselineModel(LifacNoiseModel(pars), eodf)
        baseline.plot_interspike_interval_histogram()

        fi_curve = FICurveModel(LifacNoiseModel(pars), np.arange(-0.3, 0.31, 0.1), eodf)
        fi_curve.plot_fi_curve()


def calculate_means_and_covariances(param_values):
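    """Build the multivariate Gaussian over the (log-)transformed parameters:
    means and standard deviations are taken from the hand-set Gaussian fits
    (`get_gauss_fits`), correlations from the transformed fitted values."""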
    transformed_values = {}
    keys = sorted(param_values.keys())
    for key in keys:
        if LOG_TRANSFORM[key]:
            transformed_values[key] = np.log(np.array(param_values[key]))
        else:
            transformed_values[key] = np.array(param_values[key])
    transformed_fits = get_gauss_fits()
    means = np.array([transformed_fits[k][1] for k in keys])

    cov_matrix = np.zeros((len(keys), len(keys)))

    for i, k1 in enumerate(keys):
        for j, k2 in enumerate(keys):
            cor, p = pearsonr(transformed_values[k1], transformed_values[k2])
            cov_matrix[i, j] = cor * transformed_fits[k1][2] * transformed_fits[k2][2]

    return keys, means, cov_matrix


def draw_random_models(num_of_models, keys, means, cov_matrix, seed=None):
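    """Draw parameter sets from the multivariate Gaussian in transformed space and
    map log-transformed parameters back with exp()."""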
    # random_state=None (the default) falls back to numpy's global random state
    transformed_model_params = multivariate_normal.rvs(means, cov_matrix, num_of_models, random_state=seed)

    drawn_parameters = []

    for par_set in transformed_model_params:
        retransformed_parameters = {}

        for i, k in enumerate(keys):
            if LOG_TRANSFORM[k]:
                retransformed_parameters[k] = np.exp(par_set[i])
            else:
                retransformed_parameters[k] = par_set[i]

        drawn_parameters.append(retransformed_parameters)

    return drawn_parameters


def get_gauss_fits():
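    """Hand-set Gaussian fits (amplitude, mean, sigma) to the transformed parameter
    distributions; their means and sigmas are used when drawing random models."""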
    # TODO NOT NORMED TO INTEGRAL OF 1 !!!!!!!
    transformed_gauss_fits = {}
    # fit parameter: amplitude, mean, sigma
    transformed_gauss_fits["delta_a"] = [0.52555418, -2.17583514, 0.658713652]       # tweak
    transformed_gauss_fits["dend_tau"] = [0.90518987, -5.509343763, 0.3593178]       # good
    transformed_gauss_fits["mem_tau"] = [0.85176348, -6.2468377, 0.42126255]        # good
    transformed_gauss_fits["input_scaling"] = [0.57239028, 5., 0.6]  # [0.37239028, 5.92264105, 1.77342945]  # tweak
    transformed_gauss_fits["noise_strength"] = [0.62216977, -3.49622807, 0.58081673]# good
    transformed_gauss_fits["tau_a"] = [0.82351638, -2.39879173, 0.45725644]        # good
    transformed_gauss_fits["v_offset"] = [1.33749859e-02, -1.91220096e+01, 1.71068108e+01]  # good
    transformed_gauss_fits["refractory_period"] = [1.370406256, 9.14715386e-04, 2.33470418e-04]

    return transformed_gauss_fits


def plot_distributions_with_set_fits(param_values):
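    """Plot the (log-)transformed fitted parameter distributions together with the
    hand-set Gaussian fits and save the figure."""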

    fig, axes = plt.subplots(4, 2, gridspec_kw={"left": 0.1, "hspace": 0.5}, figsize=consts.FIG_SIZE_LARGE_HIGH)

    gauss_fits = get_gauss_fits()
    bin_number = 20

    labels = ["input_scaling", "v_offset", "mem_tau", "noise_strength",
              "tau_a", "delta_a", "dend_tau", "refractory_period"]
    x_labels = ["[ln(cm)]", "[mV]", "[ln(s)]", r"[ln(mV$\sqrt{s}$)]", "[ln(s)]", "[ln(mVs)]", "[ln(s)]", "[ms]"]
    for i, key in enumerate(labels):
        k = i % 2
        m = int(i/2)

        values = param_values[key]
        if LOG_TRANSFORM[key]:
            values = np.log(np.array(param_values[key]))

        # x stays in the (log-)transformed fit units for evaluating the Gaussian;
        # plot_x is converted to ms for display where the axis label uses ms.
        x = np.arange(min(values), max(values), (max(values) - min(values)) / 100)
        plot_x = np.copy(x)
        if "ms" in x_labels[i]:
            values = np.array(values) * 1000
            plot_x *= 1000

        gauss_param = gauss_fits[key]

        bins = calculate_bins(values, bin_number)

        axes[m, k].hist(values, bins=bins, density=True, alpha=0.75, color=consts.COLOR_MODEL)
        axes[m, k].plot(plot_x, fu.gauss(x, gauss_param[0], gauss_param[1], gauss_param[2]), color="black")
        axes[m, k].set_xlabel(parameter_titles[key] + " " + x_labels[i])
        axes[m, k].set_yticklabels([])
        axes[m, k].set_yticks([])

    plt.tight_layout()

    consts.set_figure_labels(xoffset=-2.5, yoffset=0)
    fig.label_axes()
    fig.text(0.03, 0.5, 'Density', ha='center', va='center', rotation='vertical', size=12)  # shared y label

    plt.savefig(consts.SAVE_FOLDER + "parameter_distribution_with_gauss_fits.pdf")
    plt.close()


def calculate_bins(values, num_of_bins):
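    """Return equally spaced bin edges covering `values`, shifted by half a bin width
    so that the minimum and maximum fall on bin centres."""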
    minimum = np.min(values)
    maximum = np.max(values)
    step = (maximum - minimum) / (num_of_bins-1)

    bins = np.arange(minimum-0.5*step, maximum + step, step)
    return bins


def plot_distributions(param_values):
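    """Exploratory helper: histogram each parameter and its log-transform, fit a
    Gaussian to each histogram and print the fit parameters."""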
    bin_number = 30
    fig, axes = plt.subplots(len(param_values.keys()), 2)
    for i, key in enumerate(sorted(param_values.keys())):

        # normal hist:
        values = param_values[key]
        normal, n_bins, patches = axes[i, 0].hist(values, bins=calculate_bins(values, bin_number), density=True)
        axes[i, 0].set_title(key)

        # fit gauss:
        bin_width = np.mean(np.diff(n_bins))
        middle_of_bins = n_bins + bin_width / 2
        axes[i, 0].plot(middle_of_bins[:-1], normal, 'o')
        try:
            n_gauss_pars = fit_gauss(middle_of_bins[:-1], normal)
            x = np.arange(min(param_values[key]), max(param_values[key]),
                          (max(param_values[key]) - min(param_values[key])) / 100)
            axes[i, 0].plot(x, fu.gauss(x, n_gauss_pars[0], n_gauss_pars[1], n_gauss_pars[2]))
            print(key, ": normal:", n_gauss_pars)
        except RuntimeError:
            pass  # Gaussian fit did not converge; skip the overlay

        # log transformed:
        if key != "v_offset":
            log_values = np.log(np.array(param_values[key]))
            log_trans, l_bins, patches = axes[i, 1].hist(log_values, bins=bin_number, density=True)
            bin_width = np.mean(np.diff(l_bins))
            middle_of_bins = l_bins + bin_width / 2
            axes[i, 1].plot(middle_of_bins[:-1], log_trans, 'o')
            try:
                l_gauss_pars = fit_gauss(middle_of_bins[:-1], log_trans)
                x = np.arange(min(log_values), max(log_values),
                              (max(log_values) - min(log_values)) / 100)
                axes[i, 1].plot(x, fu.gauss(x, l_gauss_pars[0], l_gauss_pars[1], l_gauss_pars[2]))
                print(key, ": log:", l_gauss_pars)
            except RuntimeError:
                pass  # Gaussian fit did not converge; skip the overlay

    plt.tight_layout()
    plt.show()
    plt.close()


def fit_gauss(x_values, y_values):
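    """Fit a Gaussian to (x, y) samples, using the mean and std of x and the maximum
    of y as the initial guess."""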
    mean_v = np.mean(x_values)
    std_v = np.std(x_values)
    amp = max(y_values)
    popt, pcov = curve_fit(fu.gauss, x_values, y_values, p0=(amp, mean_v, std_v))

    return popt


def get_parameter_distributions(folder, param_keys=None):
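    """Collect the final parameter values of the best fit of every cell in `folder`."""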
    if param_keys is None:
        param_keys = ["v_offset", 'input_scaling', 'dend_tau', 'tau_a', 'delta_a',
                      'refractory_period', 'noise_strength', 'mem_tau']
    parameter_values = {}

    for key in param_keys:
        parameter_values[key] = []

    for cell in sorted(os.listdir(folder)):
        fit = get_best_fit(os.path.join(folder, cell))

        final_params = fit.get_final_parameters()

        for key in param_keys:
            parameter_values[key].append(final_params[key])

    return parameter_values


if __name__ == '__main__':
    main()