import os
from models.LIFACnoise import LifacNoiseModel
from stimuli.SinusoidalStepStimulus import SinusoidalStepStimulus
from Baseline import get_baseline_class
from FiCurve import get_fi_curve_class
from CellData import CellData
import helperFunctions as hF
import numpy as np
import functions as fu
import matplotlib.pyplot as plt


def get_best_fit(folder_path, use_comparable_error=True):
    """Return a ModelFit for the entry of ``folder_path`` with the lowest error.

    Every entry listed in ``folder_path`` is wrapped in a ModelFit and scored.

    :param folder_path: directory whose entries are individual fit folders.
    :param use_comparable_error: if truthy, rank by
        ``ModelFit.comparable_error()``; otherwise rank by
        ``ModelFit.get_fit_routine_error()``.
    :return: a new ModelFit for the best entry. When no entry has an error
        strictly below infinity (e.g. the folder is empty) the returned
        ModelFit points at ``os.path.join(folder_path, "")``.
    """
    best_name = ""
    lowest_error = np.inf

    for entry in os.listdir(folder_path):
        candidate = ModelFit(os.path.join(folder_path, entry))
        entry_error = (candidate.comparable_error() if use_comparable_error
                       else candidate.get_fit_routine_error())

        # Only a strictly smaller error replaces the current best, so the
        # first of several equal minima wins and NaN errors are never chosen.
        if not entry_error < lowest_error:
            continue
        lowest_error, best_name = entry_error, entry

    return ModelFit(os.path.join(folder_path, best_name))


class ModelFit:
    """Accessor for the files stored in one fit-result folder.

    The instance only remembers the folder path and the expected file names;
    every getter reads from disk when it is called.
    """

    # (key in the value table, scale the absolute difference is divided by)
    # used by comparable_error(); order matches the summation order.
    _BEHAVIOUR_ERROR_SCALES = (
        ("baseline_frequency", 5),
        ("vector_strength", 0.1),
        ("serial_correlation", 0.1),
        ("Burstiness", 0.05),
        ("coefficient_of_variation", 0.1),
        ("f_inf_slope", 5),
    )

    def __init__(self, folder_path):
        """Store the folder path and the names of the files expected in it.

        :param folder_path: path of the fit-result folder.
        """
        self.path = folder_path
        self.parameter_file_name = "parameters_info.txt"
        self.value_file = "value_comparision.tsv"
        self.fi_comp_img = "fi_curve_comparision.png"
        self.isi_hist_img = "isi-histogram.png"
        self.isi_hist_comp_img = "isi-histogram_comparision.png"

        self.model_f_inf_file = "model_fi_inf_values.npy"
        self.cell_f_inf_file = "cell_fi_inf_values.npy"
        self.model_f_zero_file = "model_fi_zero_values.npy"
        self.cell_f_zero_file = "cell_fi_zero_values.npy"

    def _read_parameter_entry(self, key, missing_message):
        """Return the evaluated value of the ``key`` line in the parameter file.

        The parameter file is read line by line; each line is split on tabs
        and the first line whose first field equals ``key`` is used.

        :param key: first field of the wanted line, e.g. "final_parameters:".
        :param missing_message: text printed (followed by the folder path)
            when no line matches.
        :return: the evaluated second field, or an empty dict if not found.
        """
        par_file_path = os.path.join(self.path, self.parameter_file_name)
        with open(par_file_path, 'r') as par_file:
            for line in par_file:
                fields = line.strip().split('\t')

                if fields[0] == key:
                    # NOTE(review): eval() executes whatever the file
                    # contains. Only use on fit folders from a trusted
                    # source; consider ast.literal_eval if the stored text is
                    # always a plain dict literal.
                    return eval(fields[1])

        print(missing_message, self.path)
        return {}

    def get_final_parameters(self):
        """Return the "final_parameters:" entry of the parameter file.

        :return: the evaluated entry, or ``{}`` (after printing a notice)
            when the file has no such line.
        """
        return self._read_parameter_entry("final_parameters:",
                                          "Final parameters not found! - ")

    def get_start_parameters(self):
        """Return the "start_parameters:" entry of the parameter file.

        The entry is parsed the same way as the final parameters.

        :return: the evaluated entry, or ``{}`` (after printing a notice)
            when the file has no such line.
        """
        return self._read_parameter_entry("start_parameters:",
                                          "Start parameters not found! - ")

    def get_behaviour_values(self):
        """Read the tab-separated value comparison table.

        The first line is treated as a header and skipped. Every following
        non-blank line provides a name, the cell value and the model value.

        :return: tuple ``(cell_values, model_values)`` of dicts mapping the
            name in the first column to the float in the second and third
            column respectively.
        """
        values_file_path = os.path.join(self.path, self.value_file)
        cell_values = {}
        model_values = {}
        with open(values_file_path, 'r') as val_file:
            next(val_file, None)  # ignore headers
            for line in val_file:
                if not line.strip():
                    # blank lines (e.g. a trailing newline) carry no values
                    continue
                fields = line.strip().split('\t')
                cell_values[fields[0]] = float(fields[1])
                model_values[fields[0]] = float(fields[2])

        return cell_values, model_values

    def get_fi_curve_comparision_image(self):
        """Return the path of the fi-curve comparison image.

        :raises FileNotFoundError: if the image does not exist.
        """
        path = os.path.join(self.path, self.fi_comp_img)
        if not os.path.exists(path):
            raise FileNotFoundError("Fi-curve comparision image is missing. - " + self.path)
        return path

    def get_isi_histogram_image(self):
        """Return the path of the ISI-histogram image.

        :raises FileNotFoundError: if the image does not exist.
        """
        path = os.path.join(self.path, self.isi_hist_img)
        if not os.path.exists(path):
            raise FileNotFoundError("Isi-histogram image is missing. - " + self.path)
        return path

    def get_error_value(self):
        """Return the text after the last underscore of the folder path.

        Unlike get_fit_routine_error() the result is returned as a string and
        is taken from the full path, not only from the folder name.
        """
        return self.path.split("_")[-1]

    def get_model(self):
        """Return a LifacNoiseModel built from the final parameters."""
        return LifacNoiseModel(self.get_final_parameters())

    def get_cell_path(self):
        """Return the stripped first line of "cell_data_path.txt"."""
        with open(os.path.join(self.path, "cell_data_path.txt"), "r") as f:
            return f.readline().strip()

    def get_cell_data(self):
        """Return a CellData object for the path from get_cell_path()."""
        return CellData(self.get_cell_path())

    def get_model_f_inf_values(self):
        """Load the stored f_inf values of the model (numpy ``.npy`` file)."""
        return np.load(os.path.join(self.path, self.model_f_inf_file))

    def get_model_f_zero_values(self):
        """Load the stored f_zero values of the model (numpy ``.npy`` file)."""
        return np.load(os.path.join(self.path, self.model_f_zero_file))

    def get_cell_f_inf_values(self):
        """Load the stored f_inf values of the cell (numpy ``.npy`` file)."""
        return np.load(os.path.join(self.path, self.cell_f_inf_file))

    def get_cell_f_zero_values(self):
        """Load the stored f_zero values of the cell (numpy ``.npy`` file)."""
        return np.load(os.path.join(self.path, self.cell_f_zero_file))

    def get_fit_routine_error(self):
        """Return the number after the last underscore of the folder name.

        :raises ValueError: if that part of the name is not a valid float.
        """
        foldername = os.path.basename(self.path)
        return float(foldername.split("_")[-1])

    @staticmethod
    def _mean_scaled_abs_difference(model_values, cell_values):
        """Sum ``|cell - model| / 10`` over paired values, divided by the model count."""
        total = 0
        for m_value, c_value in zip(model_values, cell_values):
            total += abs(c_value - m_value) / 10
        return total / len(model_values)

    def comparable_error(self):
        """Compute an error value from the stored cell and model results.

        The error is the sum of
        * the absolute cell/model difference of each behaviour value, divided
          by its scale in ``_BEHAVIOUR_ERROR_SCALES``,
        * the mean scaled absolute difference of the f_inf values,
        * the mean scaled absolute difference of the f_zero values.

        :return: the summed error.
        :raises KeyError: if a behaviour value is missing from the table.
        """
        cell_values, model_values = self.get_behaviour_values()

        error = 0
        for key, scale in self._BEHAVIOUR_ERROR_SCALES:
            error += abs(cell_values[key] - model_values[key]) / scale

        error += self._mean_scaled_abs_difference(self.get_model_f_inf_values(),
                                                  self.get_cell_f_inf_values())
        # The model f_zero values must come from the model file; comparing
        # the cell values with themselves would always contribute zero.
        error += self._mean_scaled_abs_difference(self.get_model_f_zero_values(),
                                                  self.get_cell_f_zero_values())

        return error

    def generate_master_plot(self, save_path=None):
        """Create a four-panel figure comparing the cell with its fitted model.

        Panels: ISI histograms, fi-curve comparison, frequency trace of one
        stimulus contrast, and a table of the stored behaviour values.

        :param save_path: if None the figure is shown; otherwise it is saved
            to ``save_path + <cell name> + "_master_plot.pdf"`` (plain string
            concatenation, so a directory needs its trailing separator) and
            the figure is closed afterwards.
        """
        model = self.get_model()
        cell = self.get_cell_data()

        fig, axes = plt.subplots(4, 1, figsize=(8, 12))

        # --- isi histogram ---
        axes[0].set_title("ISI-Histogram")
        axes[0].set_xlim((0, 50))
        bins = np.arange(0, 50, 0.1)
        for data, name in zip((cell, model), ("cell", "model")):
            base = get_baseline_class(data, cell.get_eod_frequency(), trials=5)
            # presumably converts seconds to milliseconds - TODO confirm
            isis = np.array(base.get_interspike_intervals()) * 1000
            axes[0].hist(isis, bins=bins, label=name, alpha=0.5, density=True)

        axes[0].legend()

        # --- fi_curve ---
        fi_curve = get_fi_curve_class(cell, cell.get_fi_contrasts(), save_dir=cell.get_data_path())
        f_inf_slope = fi_curve.get_f_inf_slope()
        contrasts = np.array(fi_curve.stimulus_values)
        if f_inf_slope < 0:
            # flip the sign of the contrasts and recompute the cell fi-curve
            contrasts = contrasts * -1
            fi_curve_cell = get_fi_curve_class(cell, contrasts)
            print("cell: {} , FI-Curve has saved contrasts that give negative f_inf slope!".format(cell.get_data_path()))
        else:
            fi_curve_cell = fi_curve

        fi_curve_model = get_fi_curve_class(model, contrasts, eod_freq=cell.get_eod_frequency(), trials=15)

        min_x = min(min(fi_curve_cell.stimulus_values), min(fi_curve_model.stimulus_values))
        max_x = max(max(fi_curve_cell.stimulus_values), max(fi_curve_model.stimulus_values))
        step = (max_x - min_x) / 5000
        x_values = np.arange(min_x, max_x + step, step)

        # colour pairs: (cell, model)
        f_base_color = ("blue", "deepskyblue")
        f_inf_color = ("green", "limegreen")
        f_zero_color = ("red", "orange")

        # baseline
        median_baseline = np.median(fi_curve_cell.get_f_baseline_frequencies())
        axes[1].plot((min_x, max_x), (median_baseline, median_baseline), color=f_base_color[0], label="cell med base")
        axes[1].plot(fi_curve_model.stimulus_values, fi_curve_model.get_f_baseline_frequencies(),
                     'o', color=f_base_color[1], label='model base')

        # f_inf: fitted line of the cell against the model points
        y_values = [fu.clipped_line(x, fi_curve_cell.f_inf_fit[0], fi_curve_cell.f_inf_fit[1]) for x in x_values]
        axes[1].plot(x_values, y_values, color=f_inf_color[0], label='f_inf_fit cell')
        axes[1].plot(fi_curve_model.stimulus_values, fi_curve_model.get_f_inf_frequencies(),
                     'o', color=f_inf_color[1], label='f_inf model')

        # f_zero: fitted boltzmann of the cell against the model points
        popt = fi_curve_cell.f_zero_fit
        axes[1].plot(x_values, [fu.full_boltzmann(x, popt[0], popt[1], popt[2], popt[3]) for x in x_values],
                     color=f_zero_color[0], label='f_0_fit cell')
        axes[1].plot(fi_curve_model.stimulus_values, fi_curve_model.get_f_zero_frequencies(),
                     'o', color=f_zero_color[1], label='f_zero model')
        axes[1].set_title("cell model comparision")
        axes[1].set_xlabel("Stimulus value - contrast")
        axes[1].legend()

        # --- comparision of f_zero_curve ---
        # use the contrast closest to half of the largest contrast
        max_contrast = max(contrasts)
        test_contrast = 0.5 * max_contrast
        diff_contrasts = np.abs(contrasts - test_contrast)
        f_zero_curve_contrast_idx = np.argmin(diff_contrasts)

        # model: mean frequency trace over 10 simulated repetitions
        stimulus = SinusoidalStepStimulus(cell.get_eod_frequency(), contrasts[f_zero_curve_contrast_idx],
                                         start_time=0, duration=cell.get_stimulus_duration())
        freq_traces = []
        time_traces = []
        for _ in range(10):
            _, spikes = model.simulate(stimulus, cell.get_time_end() - cell.get_time_start(), cell.get_time_start())
            time, freq = hF.calculate_time_and_frequency_trace(spikes, model.get_sampling_interval())
            freq_traces.append(freq)
            time_traces.append(time)

        time, freq = hF.calculate_mean_of_frequency_traces(time_traces, freq_traces, model.get_sampling_interval())

        cell_times, cell_freqs = fi_curve_cell.get_mean_time_and_freq_traces()
        trace_times = cell_times[f_zero_curve_contrast_idx]
        trace_freqs = cell_freqs[f_zero_curve_contrast_idx]
        axes[2].plot(trace_times, trace_freqs)
        # NOTE(review): the model trace is shifted by 0.005 on the time axis;
        # the reason is not visible in this file.
        axes[2].plot(np.array(time) + 0.005, freq)
        axes[2].set_title("blue: cell, orange: model")
        axes[2].set_xlim(-0.15, 0.35)

        # Scale the y-axis to the cell frequencies inside the visible time
        # window. The defaults cover the whole trace so a trace that starts
        # after -0.15 or ends before 0.35 still yields a valid slice.
        start_idx = 0
        end_idx = len(trace_times)
        for idx, t in enumerate(trace_times):
            if t < -0.15:
                start_idx = idx
            elif t > 0.35:
                end_idx = idx
                break
        visible_freqs = trace_freqs[start_idx:end_idx]
        if len(visible_freqs) > 0:
            axes[2].set_ylim(0.9 * min(visible_freqs), 1.1 * max(visible_freqs))

        # --- value table ---
        cell_values, model_values = self.get_behaviour_values()

        collabel = sorted(cell_values.keys())
        clust_data = [[cell_values[k] for k in collabel],
                      [model_values[k] for k in collabel]]

        axes[3].axis('tight')
        axes[3].axis('off')
        axes[3].table(cellText=clust_data, colLabels=collabel, rowLabels=("cell", "model"), loc='center')
        fig.suptitle(cell.get_cell_name() + "_comp_err: {:.2f}".format(self.comparable_error()))
        plt.tight_layout()
        if save_path is None:
            plt.show()
        else:
            plt.savefig(save_path + cell.get_cell_name() + "_master_plot.pdf")
            # release the figure so batch runs do not accumulate open figures
            plt.close(fig)