import argparse
import os
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import pearsonr

from ModelFit import get_best_fit
from Baseline import BaselineModel
from FiCurve import FICurveModel
from CellData import CellData
from stimuli.SinusoidalStepStimulus import SinusoidalStepStimulus


def main():
    """Load all fit results from a fixed folder and show the summary plots.

    In order, this shows: box plots of the relative model errors, the
    correlation matrix of the cell behaviour values, the correlation matrix
    of the fitted parameters, histograms of the fitted parameters and
    histograms of the cell and model behaviour values.
    """
    # NOTE: the command line interface (a positional "dir" argument checked
    # with os.path.isdir) and the call sensitivity_analysis(dir_path,
    # max_models=3) are currently disabled; the results folder is hard-coded.
    dir_path = "results/invivo_results/"

    fits_info = get_fit_info(dir_path)

    create_boxplots(calculate_percent_errors(fits_info))

    # Correlations between the behaviour values measured in the cells.
    behaviour_result = behaviour_correlations(fits_info, model_values=False)
    create_correlation_plot(*behaviour_result)

    # Correlations between the fitted model parameters.
    parameter_result = parameter_correlations(fits_info)
    create_correlation_plot(*parameter_result)

    parameter_values = get_parameter_values(fits_info)
    create_parameter_distributions(parameter_values)

    cell_behaviour, model_behaviour = get_behaviour_values(fits_info)
    create_behaviour_distributions(cell_behaviour, model_behaviour)


def calculate_percent_errors(fits_info):
    """Compute the relative error of each model behaviour against the cell.

    Args:
        fits_info: mapping of cell name to a sequence whose index 1 holds the
            model behaviour values and index 2 the cell behaviour values,
            both as dicts keyed by behaviour name.

    Returns:
        Dict mapping behaviour name to a list of ``(model - cell) / cell``
        values (a fraction, not multiplied by 100), with cells visited in
        sorted name order. When the cell value is 0 the entry is 0 if the
        model value is 0 as well; otherwise a message is printed and the
        cell is skipped for that behaviour.
    """
    errors = {}

    for cell_name in sorted(fits_info):
        entry = fits_info[cell_name]
        for behaviour, model_value in entry[1].items():
            behaviour_errors = errors.setdefault(behaviour, [])
            reference = entry[2][behaviour]

            if reference != 0:
                behaviour_errors.append((model_value - reference) / reference)
            elif model_value == 0:
                # Both values are zero: the model matches the cell exactly.
                behaviour_errors.append(0)
            else:
                print("Cannot calc % error if reference is 0")

    return errors


def get_parameter_values(fits_info):
    """Collect the fitted value of each model parameter across all cells.

    Args:
        fits_info: mapping of cell name to a sequence whose index 0 is the
            dict of fitted parameters.

    Returns:
        Dict mapping each of the eight parameter names to the list of its
        values, with cells visited in sorted name order. The dict is empty
        when ``fits_info`` is empty.
    """
    # Listed in alphabetical order, which is also the insertion order of the
    # keys in the result.
    par_keys = ("delta_a", "dend_tau", "input_scaling", "mem_tau",
                "noise_strength", "refractory_period", "tau_a", "v_offset")

    parameter_values = {}
    for cell_name in sorted(fits_info):
        fitted = fits_info[cell_name][0]
        for name in par_keys:
            parameter_values.setdefault(name, []).append(fitted[name])

    return parameter_values


def get_behaviour_values(fits_info):
    """Collect the behaviour values of the cells and of their models.

    Args:
        fits_info: mapping of cell name to a sequence whose index 1 holds the
            model behaviour values and index 2 the cell behaviour values,
            both as dicts keyed by behaviour name.

    Returns:
        Tuple ``(cell_values, model_values)``. Each is a dict mapping
        behaviour name to a list of values, with cells visited in sorted
        name order. The behaviour names are taken from the model dict.
    """
    from_cells = {}
    from_models = {}

    for cell_name in sorted(fits_info):
        entry = fits_info[cell_name]
        for behaviour, model_value in entry[1].items():
            from_models.setdefault(behaviour, []).append(model_value)
            from_cells.setdefault(behaviour, []).append(entry[2][behaviour])

    return from_cells, from_models


def behaviour_correlations(fits_info, model_values=True):
    """Compute pairwise Pearson correlations between the behaviour values.

    Args:
        fits_info: fit information as accepted by ``get_behaviour_values``.
        model_values: if True the values produced by the models are
            correlated, otherwise the values measured in the cells.

    Returns:
        Tuple ``(labels, corr_values, corrected_p_values)``: the sorted
        behaviour names, the matrix of correlation coefficients and the
        matrix of p-values multiplied by the number of distinct label pairs
        (a Bonferroni-style correction; the result is not clipped to 1).
    """
    from_cells, from_models = get_behaviour_values(fits_info)
    behaviour_values = from_models if model_values else from_cells

    labels = sorted(behaviour_values)
    n_labels = len(labels)
    corr_values = np.zeros((n_labels, n_labels))
    p_values = np.ones((n_labels, n_labels))

    for row, row_label in enumerate(labels):
        for col, col_label in enumerate(labels):
            corr_values[row, col], p_values[row, col] = pearsonr(
                behaviour_values[row_label], behaviour_values[col_label])

    # sum(range(n)) == n * (n - 1) / 2, the number of distinct pairs.
    n_comparisons = sum(range(n_labels))

    return labels, corr_values, p_values * n_comparisons


def parameter_correlations(fits_info):
    """Compute pairwise Pearson correlations between the fitted parameters.

    Args:
        fits_info: fit information as accepted by ``get_parameter_values``.

    Returns:
        Tuple ``(labels, corr_values, corrected_p_values)``: the sorted
        parameter names, the matrix of correlation coefficients and the
        matrix of p-values multiplied by the number of distinct label pairs
        (a Bonferroni-style correction; the result is not clipped to 1).
    """
    parameter_values = get_parameter_values(fits_info)

    labels = sorted(parameter_values)
    size = len(labels)
    corr_values = np.zeros((size, size))
    p_values = np.ones((size, size))

    for row, row_label in enumerate(labels):
        row_data = parameter_values[row_label]
        for col, col_label in enumerate(labels):
            coefficient, p_value = pearsonr(row_data, parameter_values[col_label])
            corr_values[row, col] = coefficient
            p_values[row, col] = p_value

    # sum(range(n)) == n * (n - 1) / 2, the number of distinct pairs.
    pair_count = sum(range(size))
    corrected_p_values = p_values * pair_count

    return labels, corr_values, corrected_p_values


def get_fit_info(folder):
    """Load the best fit of every entry found in ``folder``.

    Args:
        folder: path of the directory whose entries are each passed to
            ``get_best_fit``.

    Returns:
        Dict mapping entry name to the list
        ``[final parameters, model behaviour values, cell behaviour values]``.
    """
    fits_info = {}

    for cell_name in os.listdir(folder):
        best_fit = get_best_fit(os.path.join(folder, cell_name))
        cell_behaviour, model_behaviour = best_fit.get_behaviour_values()
        final_parameters = best_fit.get_final_parameters()
        # Note the order: the model values are stored before the cell values.
        fits_info[cell_name] = [final_parameters, model_behaviour, cell_behaviour]

    return fits_info


def create_correlation_plot(labels, correlations, p_values):
    """Show a heat map of the significant correlations between quantities.

    The colour of a cell encodes the correlation coefficient if its p-value
    is below 0.05 and 0 otherwise. The number written into each cell is
    always the original coefficient, including cells that were blanked out.

    Args:
        labels: names of the quantities, used as tick labels on both axes.
        correlations: square array of correlation coefficients.
        p_values: array of (corrected) p-values with the same shape as
            ``correlations``.
    """
    # Keep only the coefficients that are significant at the 5 % level.
    significant = abs(p_values) < 0.05
    cleaned_cors = np.zeros(correlations.shape)
    cleaned_cors[significant] = correlations[significant]

    fig, ax = plt.subplots()
    im = ax.imshow(cleaned_cors, vmin=-1, vmax=1)

    cbar = ax.figure.colorbar(im, ax=ax)
    cbar.ax.set_ylabel("Correlation coefficient", rotation=-90, va="bottom")

    # We want to show all ticks...
    ax.set_xticks(np.arange(len(labels)))
    ax.set_yticks(np.arange(len(labels)))
    # ... and label them with the respective list entries
    ax.set_xticklabels(labels)
    ax.set_yticklabels(labels)

    # Rotate the tick labels and set their alignment.
    plt.setp(ax.get_xticklabels(), rotation=45, ha="right",
             rotation_mode="anchor")

    # Loop over data dimensions and create text annotations.
    for i in range(len(labels)):
        for j in range(len(labels)):
            ax.text(j, i, "{:.2f}".format(correlations[i, j]),
                    ha="center", va="center", color="w")

    fig.tight_layout()
    plt.show()
    # Release the figure, as the other plotting functions in this file do;
    # otherwise figures pile up when this function is called repeatedly.
    plt.close(fig)


def create_boxplots(errors):
    """Show one box plot per behaviour for the given relative errors.

    Args:
        errors: dict mapping behaviour name to a list of error values. The
            boxes are ordered by behaviour name and each tick label also
            states how many values the box contains.
    """
    behaviours = sorted(errors)

    labels = ["{}_n:{}".format(name, len(errors[name])) for name in behaviours]
    y_values = [errors[name] for name in behaviours]

    # Box plots are drawn at positions 1..n, so the ticks go there as well.
    tick_positions = np.arange(1, len(y_values) + 1, 1)

    plt.boxplot(y_values)
    plt.xticks(tick_positions, labels, rotation=45)
    plt.tight_layout()
    plt.show()
    plt.close()


def create_parameter_distributions(par_values):
    """Show a histogram of every fitted parameter in a 4 x 2 grid of axes.

    Args:
        par_values: dict mapping parameter name to the list of its values.
            Eight parameters are expected; a message is printed otherwise.
            Parameters are plotted in sorted name order.
    """
    n_bins = 15

    fig, axes = plt.subplots(4, 2)

    if len(par_values) != 8:
        print("not eight parameters")

    axes_flat = axes.flatten()
    for i, label in enumerate(sorted(par_values)):
        values = par_values[label]
        lowest, highest = min(values), max(values)

        # Widen the range by 5 % of the magnitude of each end. Scaling the
        # ends by 0.95 / 1.05 instead would shrink the range for negative
        # values and drop data from the histogram.
        min_v = lowest - 0.05 * abs(lowest)
        max_v = highest + 0.05 * abs(highest)
        if max_v <= min_v:
            # Degenerate range (all values are 0): use a unit-wide window.
            min_v, max_v = min_v - 0.5, min_v + 0.5

        # linspace gives exactly n_bins bins and cannot fail on a zero step.
        bins = np.linspace(min_v, max_v, n_bins + 1)
        axes_flat[i].hist(values, bins=bins)
        axes_flat[i].set_title(label)

    plt.tight_layout()
    plt.show()
    plt.close()


def create_behaviour_distributions(cell_b_values, model_b_values):
    """Show overlaid histograms of cell and model values for each behaviour.

    The histograms are drawn in a 4 x 2 grid of axes, one behaviour per
    axis in sorted name order, with shared bins for cell and model values.
    The upper end of the bins is capped at 50000; when the cap applies, a
    message reports how many cell and model values lie above it.

    Args:
        cell_b_values: dict mapping behaviour name to the values of the cells.
        model_b_values: dict with the same keys holding the model values.
    """
    n_bins = 15
    limit = 50000

    fig, axes = plt.subplots(4, 2)

    axes_flat = axes.flatten()
    for i, label in enumerate(sorted(cell_b_values)):
        cell_values = cell_b_values[label]
        model_values = model_b_values[label]

        lowest = min(min(cell_values), min(model_values))
        highest = max(max(cell_values), max(model_values))

        # Widen the range by 5 % of the magnitude of each end. Scaling the
        # ends by 0.95 / 1.05 instead would shrink the range for negative
        # values and drop data from the histogram.
        min_v = lowest - 0.05 * abs(lowest)
        max_v = highest + 0.05 * abs(highest)

        if max_v > limit:
            # Both histograms use the same bins, so count both data sets.
            excluded = int(np.sum(np.array(cell_values) > limit)
                           + np.sum(np.array(model_values) > limit))
            print("For {} the max value was limited to {},  {} values were excluded!".format(label, limit, excluded))
            max_v = limit

        if max_v <= min_v:
            # Degenerate range (all values 0, or everything above the limit).
            min_v, max_v = min_v - 0.5, min_v + 0.5

        # linspace gives exactly n_bins bins and cannot fail on a zero step.
        bins = np.linspace(min_v, max_v, n_bins + 1)
        axes_flat[i].hist(cell_values, bins=bins, alpha=0.5)
        axes_flat[i].hist(model_values, bins=bins, alpha=0.5)
        axes_flat[i].set_title(label)

    plt.tight_layout()
    plt.show()
    plt.close()


def sensitivity_analysis(dir_path, par_range=(0.5, 1.6, 0.1), contrast_range=(-0.3, 0.4, 0.1), parameters=None, behaviours=None, max_models=None):
    """Vary single model parameters and plot how the model behaviour changes.

    For every fitted model, each parameter in ``parameters`` is multiplied in
    turn by every factor from ``par_range``. For each variation "v_offset" is
    recomputed with ``find_v_offset``, the requested behaviours are measured
    and the results are passed to ``plot_sensitivity_analysis``.

    Args:
        dir_path: directory whose entries (in sorted order) are each passed
            to ``get_best_fit``.
        par_range: ``(start, stop, step)`` given to ``np.arange`` to build the
            multiplicative factors applied to a parameter.
        contrast_range: ``(start, stop, step)`` given to ``np.arange`` to
            build the contrasts handed to ``FICurveModel``.
        parameters: names of the parameters to vary; defaults to seven
            parameters (all fitted ones except "v_offset").
        behaviours: names of the behaviours to record; defaults to the seven
            names handled below. Any other name keeps an all-zero response.
        max_models: if not None, only this many entries of ``dir_path`` are
            used.

    Returns:
        None. The results are only plotted.
    """
    models = []
    eods = []
    base_freqs = []

    # Load the model, the EOD frequency and the cell's baseline frequency of
    # every fit; count is raised before the check, so exactly max_models
    # entries are processed.
    count = 0
    for item in sorted(os.listdir(dir_path)):
        count += 1
        if max_models is not None and count > max_models:
            break
        cell_folder = os.path.join(dir_path, item)

        results = get_best_fit(cell_folder)

        models.append(results.get_model())
        eods.append(CellData(results.get_cell_path()).get_eod_frequency())
        # Only the cell's values are used here; `model` is not read.
        cell, model = results.get_behaviour_values()
        base_freqs.append(cell["baseline_frequency"])

    if parameters is None:
        parameters = ["input_scaling", "delta_a", "mem_tau", "noise_strength",
                      "refractory_period", "tau_a", "dend_tau"]

    if behaviours is None:
        behaviours = ["burstiness", "coefficient_of_variation", "serial_correlation",
                      "vector_strength", "f_inf_slope", "f_zero_slope", "f_zero_middle"]

    # One entry per model: {parameter: {behaviour: array with one value per factor}}
    model_behaviour_responses = []

    contrasts = np.arange(contrast_range[0], contrast_range[1], contrast_range[2])
    factors = np.arange(par_range[0], par_range[1], par_range[2])
    for model, eod, base_freq in zip(models, eods, base_freqs):
        par_responses = {}
        for par in parameters:
            par_responses[par] = {}
            for b in behaviours:
                par_responses[par][b] = np.zeros(len(factors))
            for i, factor in enumerate(factors):

                # Work on a copy so the scaling is always relative to the
                # fitted value of the original model.
                model_copy = model.get_model_copy()
                model_copy.set_variable(par, model.get_parameters()[par] * factor)
                print("{} at {}, ({} of {})".format(par, model.get_parameters()[par] * factor, i+1, len(factors)))
                # Recompute v_offset for the changed parameter from the cell's
                # baseline frequency.
                # NOTE(review): the second argument 0 is presumably the
                # stimulus contrast - confirm against SinusoidalStepStimulus.
                base_stimulus = SinusoidalStepStimulus(eod, 0)
                v_offset = model_copy.find_v_offset(base_freq, base_stimulus)
                model_copy.set_variable("v_offset", v_offset)

                # Behaviours measured on the baseline activity.
                baseline = BaselineModel(model_copy, eod, trials=3)
                print(baseline.get_baseline_frequency())
                if "burstiness" in behaviours:
                    par_responses[par]["burstiness"][i] = baseline.get_burstiness()
                if "coefficient_of_variation" in behaviours:
                    par_responses[par]["coefficient_of_variation"][i] = baseline.get_coefficient_of_variation()
                if "serial_correlation" in behaviours:
                    # Only the first returned value is stored.
                    par_responses[par]["serial_correlation"][i] = baseline.get_serial_correlation(1)[0]
                if "vector_strength" in behaviours:
                    par_responses[par]["vector_strength"][i] = baseline.get_vector_strength()

                # Behaviours measured on the FI curve over all contrasts.
                fi_curve = FICurveModel(model_copy, contrasts, eod, trials=20)

                if "f_inf_slope" in behaviours:
                    par_responses[par]["f_inf_slope"][i] = fi_curve.get_f_inf_slope()
                if "f_zero_slope" in behaviours:
                    par_responses[par]["f_zero_slope"][i] = fi_curve.get_f_zero_fit_slope_at_straight()
                if "f_zero_middle" in behaviours:
                    # NOTE(review): index 3 of f_zero_fit is taken as the
                    # "middle" value - confirm against FICurveModel.
                    par_responses[par]["f_zero_middle"][i] = fi_curve.f_zero_fit[3]

        model_behaviour_responses.append(par_responses)

    print("sensitivity analysis done!")
    plot_sensitivity_analysis(model_behaviour_responses, behaviours, parameters, factors)


def plot_sensitivity_analysis(responses, behaviours, parameters, factors):
    """Plot every behaviour against the scaling factor of every parameter.

    The figure is a grid with one row per behaviour and one column per
    parameter; each axis holds one line per model. The x axis is shared by
    all axes and the y axis within each row.

    Args:
        responses: one dict per model of the form
            ``{parameter: {behaviour: values}}`` with one value per factor.
        behaviours: behaviour names, one grid row each.
        parameters: parameter names, one grid column each.
        factors: the scaling factors used as x values.
    """
    # squeeze=False keeps axes two-dimensional even for a single behaviour or
    # a single parameter, so axes[i, j] is always valid.
    fig, axes = plt.subplots(len(behaviours), len(parameters), sharex="all", sharey="row",
                             figsize=(8, 8), squeeze=False)

    for i, behaviour in enumerate(behaviours):
        for j, par in enumerate(parameters):
            ax = axes[i, j]

            for model_responses in responses:
                ax.plot(factors, model_responses[par][behaviour])

            # Label only the first column and the first row of the grid.
            if j == 0:
                ax.set_ylabel("{}".format(behaviour))
            if i == 0:
                ax.set_title("{}".format(par))

    plt.tight_layout()
    plt.show()
    plt.close()


# Run the analysis only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()