from parser.CellData import CellData
from models.LIFACnoise import LifacNoiseModel
from stimuli.SinusoidalStepStimulus import SinusoidalStepStimulus
import numpy as np
import matplotlib.pyplot as plt
from warnings import warn
from my_util import helperFunctions as hF, functions as fu
from os.path import join, exists
import pickle
from sys import stderr


class FICurve:
    """Base class for f-I curves (firing frequency as a function of a stimulus value).

    For every stimulus value a subclass detects three frequency points
    (``calculate_all_frequency_points``): f_baseline, f_zero and f_inf. The base
    class then fits f_inf with ``hF.fit_clipped_line`` and f_zero with
    ``hF.fit_boltzmann``.

    Subclasses must override: ``calculate_all_frequency_points``,
    ``get_mean_time_and_freq_traces``, ``get_time_and_freq_traces``,
    ``get_sampling_interval``, ``get_delay``, ``get_stimulus_start``,
    ``get_stimulus_end`` and ``plot_f_point_detections``.

    Computed values can be cached as a pickle file in a directory
    (see ``save_values`` / ``load_values``).
    """

    def __init__(self, stimulus_values, save_dir=None, recalculate=False):
        """Compute the f-I curve values or load them from a cache.

        :param stimulus_values: stimulus values (plots label them as contrasts)
            at which the curve is evaluated.
        :param save_dir: directory of the pickle cache; ``None`` disables caching.
        :param recalculate: if True, ignore an existing cache and overwrite it.
        """
        self.save_file_name = "fi_curve_values.pkl"
        self.stimulus_values = stimulus_values

        self.indices_f_baseline = []
        self.f_baseline_frequencies = []
        self.indices_f_inf = []
        self.f_inf_frequencies = []
        self.indices_f_zero = []
        self.f_zero_frequencies = []

        # increase, offset
        self.f_inf_fit = []
        # f_max, f_min, k, x_zero
        self.f_zero_fit = []

        if save_dir is None:
            self.initialize()
        elif recalculate or not self.load_values(save_dir):
            # Forced recalculation or no cache file: compute and store the values.
            # (load_values is not called at all when recalculate is True.)
            self.initialize()
            self.save_values(save_dir)

    def initialize(self):
        """Detect all frequency points and fit f_inf (clipped line) and f_zero (Boltzmann)."""
        self.calculate_all_frequency_points()
        self.f_inf_fit = hF.fit_clipped_line(self.stimulus_values, self.f_inf_frequencies)
        self.f_zero_fit = hF.fit_boltzmann(self.stimulus_values, self.f_zero_frequencies)

    def calculate_all_frequency_points(self):
        """Fill the f_baseline/f_zero/f_inf frequency lists; must be overridden."""
        raise NotImplementedError("NOT YET OVERRIDDEN FROM ABSTRACT CLASS")

    def get_f_baseline_frequencies(self):
        """Return the detected baseline frequencies, one per stimulus value."""
        return self.f_baseline_frequencies

    def get_f_inf_frequencies(self):
        """Return the detected f_inf frequencies, one per stimulus value."""
        return self.f_inf_frequencies

    def get_f_zero_frequencies(self):
        """Return the detected f_zero frequencies, one per stimulus value."""
        return self.f_zero_frequencies

    def get_f_inf_slope(self):
        """Return the 'increase' parameter of the f_inf clipped-line fit, or None if not fitted."""
        if len(self.f_inf_fit) > 0:
            return self.f_inf_fit[0]
        return None

    def get_f_zero_fit_slope_at_straight(self):
        """Return ``fu.full_boltzmann_straight_slope`` evaluated for the f_zero fit parameters."""
        fit_vars = self.f_zero_fit
        return fu.full_boltzmann_straight_slope(fit_vars[0], fit_vars[1], fit_vars[2], fit_vars[3])

    def get_f_zero_fit_slope_at_stimulus_value(self, stimulus_value):
        """Return the derivative of the fitted f_zero Boltzmann at ``stimulus_value``."""
        fit_vars = self.f_zero_fit
        return fu.derivative_full_boltzmann(stimulus_value, fit_vars[0], fit_vars[1], fit_vars[2], fit_vars[3])

    def get_f_inf_frequency_at_stimulus_value(self, stimulus_value):
        """Return the fitted f_inf clipped line evaluated at ``stimulus_value``."""
        return fu.clipped_line(stimulus_value, self.f_inf_fit[0], self.f_inf_fit[1])

    def get_f_zero_and_f_inf_intersection(self):
        """Return the stimulus value where the f_zero fit and the f_inf fit intersect.

        Both fits are sampled on a grid with step 0.0001 between the smallest
        and largest stimulus value, and sign changes of their difference are
        located. The returned value is the grid point just before the sign change.
        If there are several intersections, the one whose f_inf fit value is
        closest to the median baseline frequency is returned.

        :raises ValueError: if the fits do not intersect inside the stimulus range.
        """
        x_values = np.arange(min(self.stimulus_values), max(self.stimulus_values), 0.0001)
        fit_vars = self.f_zero_fit
        f_zero = fu.full_boltzmann(x_values, fit_vars[0], fit_vars[1], fit_vars[2], fit_vars[3])
        f_inf = fu.clipped_line(x_values, self.f_inf_fit[0], self.f_inf_fit[1])

        # np.diff(sign) is non-zero at index i when the sign changes between i and i+1.
        intersection_indicies = np.argwhere(np.diff(np.sign(f_zero - f_inf))).flatten()
        if len(intersection_indicies) > 1:
            f_baseline = np.median(self.f_baseline_frequencies)
            best_dist = np.inf
            best_idx = -1
            for idx in intersection_indicies:
                dist = abs(fu.clipped_line(x_values[idx], self.f_inf_fit[0], self.f_inf_fit[1]) - f_baseline)
                if dist < best_dist:
                    best_dist = dist
                    best_idx = idx

            return x_values[best_idx]

        elif len(intersection_indicies) == 0:
            raise ValueError("No intersection found!")
        else:
            return x_values[intersection_indicies[0]]

    def get_f_zero_fit_slope_at_f_inf_fit_intersection(self):
        """Return the f_zero fit derivative at the f_zero/f_inf fit intersection."""
        x = self.get_f_zero_and_f_inf_intersection()
        fit_vars = self.f_zero_fit
        return fu.derivative_full_boltzmann(x, fit_vars[0], fit_vars[1], fit_vars[2], fit_vars[3])

    def get_mean_time_and_freq_traces(self):
        """Return (time traces, mean frequency traces), one per stimulus value; must be overridden."""
        raise NotImplementedError("NOT YET OVERRIDDEN FROM ABSTRACT CLASS")

    def get_time_and_freq_traces(self):
        """Return per-trial (time traces, frequency traces) per stimulus value; must be overridden."""
        raise NotImplementedError("NOT YET OVERRIDDEN FROM ABSTRACT CLASS")

    def get_sampling_interval(self):
        """Return the sampling interval; must be overridden."""
        raise NotImplementedError("NOT YET OVERRIDDEN FROM ABSTRACT CLASS")

    def get_delay(self):
        """Return the delay; must be overridden."""
        raise NotImplementedError("NOT YET OVERRIDDEN FROM ABSTRACT CLASS")

    def get_stimulus_start(self):
        """Return the stimulus start time; must be overridden."""
        raise NotImplementedError("NOT YET OVERRIDDEN FROM ABSTRACT CLASS")

    def get_stimulus_end(self):
        """Return the stimulus end time; must be overridden."""
        raise NotImplementedError("NOT YET OVERRIDDEN FROM ABSTRACT CLASS")

    def get_stimulus_duration(self):
        """Return stimulus end minus stimulus start."""
        return self.get_stimulus_end() - self.get_stimulus_start()

    def plot_mean_frequency_curves(self, save_path=None):
        """Plot one figure per stimulus value: all trial traces (gray) and their mean (black).

        :param save_path: if None the figures are shown; otherwise it is used as a
            plain string prefix for the file name (so it should end with a separator).
        """
        time_traces, freq_traces = self.get_time_and_freq_traces()
        mean_times, mean_freqs = self.get_mean_time_and_freq_traces()
        for i, sv in enumerate(self.stimulus_values):

            for j in range(len(time_traces[i])):
                plt.plot(time_traces[i][j], freq_traces[i][j], color="gray", alpha=0.8)

            plt.plot(mean_times[i], mean_freqs[i], color="black")
            plt.xlabel("Time [s]")
            plt.ylabel("Frequency [Hz]")

            plt.title("Mean frequency at contrast {:.2f} ({:} trials)".format(sv, len(time_traces[i])))
            if save_path is None:
                plt.show()
            else:
                plt.savefig(save_path + "mean_frequency_contrast_{:.2f}.png".format(sv))
            plt.close()

    def plot_fi_curve(self, save_path=None):
        """Plot the detected f_base/f_inf/f_zero points together with the f_inf and f_zero fits.

        :param save_path: if None the figure is shown; otherwise it is used as a
            string prefix for "fi_curve.png".
        """
        min_x = min(self.stimulus_values)
        max_x = max(self.stimulus_values)
        step = (max_x - min_x) / 5000
        x_values = np.arange(min_x, max_x, step)

        plt.plot(self.stimulus_values, self.f_baseline_frequencies, color='blue', label='f_base')

        plt.plot(self.stimulus_values, self.f_inf_frequencies, 'o', color='green', label='f_inf')
        plt.plot(x_values, [fu.clipped_line(x, self.f_inf_fit[0], self.f_inf_fit[1]) for x in x_values],
                 color='darkgreen', label='f_inf_fit')

        plt.plot(self.stimulus_values, self.f_zero_frequencies, 'o', color='orange', label='f_zero')
        popt = self.f_zero_fit
        plt.plot(x_values, [fu.full_boltzmann(x, popt[0], popt[1], popt[2], popt[3]) for x in x_values],
                 color='red', label='f_0_fit')

        plt.legend()
        plt.ylabel("Frequency [Hz]")
        plt.xlabel("Stimulus value")

        if save_path is None:
            plt.show()
        else:
            plt.savefig(save_path + "fi_curve.png")
        plt.close()

    @staticmethod
    def plot_fi_curve_comparision(data_fi_curve, model_fi_curve, save_path=None):
        """Plot cell and model f-I curves side by side plus an overlay comparison.

        Panel 0 shows ``data_fi_curve``, panel 1 ``model_fi_curve``; panel 2 overlays
        the cell's fits and median baseline with the model's detected points.

        :param save_path: if None the figure is shown; otherwise it is used as a
            string prefix for "fi_curve_comparision.png".
        """
        min_x = min(min(data_fi_curve.stimulus_values), min(model_fi_curve.stimulus_values))
        max_x = max(max(data_fi_curve.stimulus_values), max(model_fi_curve.stimulus_values))
        step = (max_x - min_x) / 5000
        # + step so that the fitted curves reach max_x.
        x_values = np.arange(min_x, max_x+step, step)

        fig, axes = plt.subplots(1, 3, sharex="all", sharey='all', figsize=(15, 6))
        # Index 0: cell data, index 1: model.
        data_origin = (data_fi_curve, model_fi_curve)
        f_base_color = ("blue", "deepskyblue")
        f_inf_color = ("green", "limegreen")
        f_zero_color = ("red", "orange")
        for i in range(2):

            axes[i].plot(data_origin[i].stimulus_values, data_origin[i].get_f_baseline_frequencies(),
                         color=f_base_color[i], label='f_base')

            axes[i].plot(data_origin[i].stimulus_values, data_origin[i].get_f_inf_frequencies(),
                         'o', color=f_inf_color[i], label='f_inf')
            y_values = [fu.clipped_line(x, data_origin[i].f_inf_fit[0], data_origin[i].f_inf_fit[1]) for x in x_values]
            axes[i].plot(x_values, y_values, color=f_inf_color[i], label='f_inf_fit')

            axes[i].plot(data_origin[i].stimulus_values, data_origin[i].get_f_zero_frequencies(),
                         'o', color=f_zero_color[i], label='f_zero')
            popt = data_origin[i].f_zero_fit
            axes[i].plot(x_values, [fu.full_boltzmann(x, popt[0], popt[1], popt[2], popt[3]) for x in x_values],
                         color=f_zero_color[i], label='f_0_fit')

            axes[i].set_xlabel("Stimulus value - contrast")
            axes[i].legend()

        axes[0].set_title("cell")
        axes[0].set_ylabel("Frequency [Hz]")
        axes[1].set_title("model")

        median_baseline = np.median(data_fi_curve.get_f_baseline_frequencies())
        axes[2].plot((min_x, max_x), (median_baseline, median_baseline), color=f_base_color[0], label="cell med base")
        axes[2].plot(model_fi_curve.stimulus_values, model_fi_curve.get_f_baseline_frequencies(),
                     'o', color=f_base_color[1], label='model base')

        y_values = [fu.clipped_line(x, data_fi_curve.f_inf_fit[0], data_fi_curve.f_inf_fit[1]) for x in x_values]
        axes[2].plot(x_values, y_values, color=f_inf_color[0], label='f_inf_fit cell')
        axes[2].plot(model_fi_curve.stimulus_values, model_fi_curve.get_f_inf_frequencies(),
                     'o', color=f_inf_color[1], label='f_inf model')

        popt = data_fi_curve.f_zero_fit
        axes[2].plot(x_values, [fu.full_boltzmann(x, popt[0], popt[1], popt[2], popt[3]) for x in x_values],
                     color=f_zero_color[0], label='f_0_fit cell')
        axes[2].plot(model_fi_curve.stimulus_values, model_fi_curve.get_f_zero_frequencies(),
                     'o', color=f_zero_color[1], label='f_zero model')
        axes[2].set_title("cell model comparision")
        axes[2].set_xlabel("Stimulus value - contrast")
        axes[2].legend()

        if save_path is None:
            plt.show()
        else:
            plt.savefig(save_path + "fi_curve_comparision.png")
        plt.close()

    def plot_f_point_detections(self, save_path=None):
        """Plot where the frequency points were detected; must be overridden."""
        raise NotImplementedError("NOT YET OVERRIDDEN FROM ABSTRACT CLASS")

    def save_values(self, save_directory):
        """Pickle the stimulus values, detected frequencies and fits into ``save_directory``."""
        values = {
            "stimulus_values": self.stimulus_values,
            "f_baseline_frequencies": self.f_baseline_frequencies,
            "f_inf_frequencies": self.f_inf_frequencies,
            "f_zero_frequencies": self.f_zero_frequencies,
            "f_inf_fit": self.f_inf_fit,
            "f_zero_fit": self.f_zero_fit,
        }

        with open(join(save_directory, self.save_file_name), "wb") as file:
            pickle.dump(values, file)

        print("Fi-Curve: Values saved!")

    def load_values(self, save_directory):
        """Load values previously written by ``save_values`` from ``save_directory``.

        If the cached stimulus values differ from the given ones, a message is
        written to stderr and the cached values are used anyway.

        :return: True if the cache file existed and was loaded, False otherwise.
        """
        file_path = join(save_directory, self.save_file_name)
        if not exists(file_path):
            print("Fi-Curve: No file to load")
            return False

        # NOTE: pickle.load can execute arbitrary code. Only load cache files this
        # program wrote itself, never files from untrusted sources.
        # The with-block guarantees the file handle is closed.
        with open(file_path, "rb") as file:
            values = pickle.load(file)

        if set(values["stimulus_values"]) != set(self.stimulus_values):
            stderr.write("Fi-Curve:load_values() - Given stimulus values are different to the loaded ones!:\n "
                         "given: {}\n loaded: {}\n".format(str(self.stimulus_values), str(values["stimulus_values"])))

        self.stimulus_values = values["stimulus_values"]
        self.f_baseline_frequencies = values["f_baseline_frequencies"]
        self.f_inf_frequencies = values["f_inf_frequencies"]
        self.f_zero_frequencies = values["f_zero_frequencies"]
        self.f_inf_fit = values["f_inf_fit"]
        self.f_zero_fit = values["f_zero_fit"]
        print("Fi-Curve: Values loaded!")
        return True


class FICurveCellData(FICurve):
    """f-I curve computed from a recorded cell (``CellData``).

    All traces and timing information are read from ``self.cell_data``.
    """

    def __init__(self, cell_data: CellData, stimulus_values, save_dir=None, recalculate=False):
        """:param cell_data: the recording to analyse; other parameters as in ``FICurve``."""
        # Must be set before super().__init__(): that may call
        # calculate_all_frequency_points(), which reads self.cell_data.
        self.cell_data = cell_data
        super().__init__(stimulus_values, save_dir, recalculate)

    def calculate_all_frequency_points(self):
        """Detect f_zero, f_baseline and f_inf in every mean frequency trace of the cell.

        Values and indices are appended to the corresponding lists, one entry per trace.

        :raises ValueError: if a trace is empty or starts after stimulus onset
            (such traces are not supported yet).
        """
        mean_frequencies = self.cell_data.get_mean_fi_curve_isi_frequencies()
        time_axes = self.cell_data.get_time_axes_fi_curve_mean_frequencies()
        stimulus_start = self.cell_data.get_stimulus_start()
        stimulus_duration = self.cell_data.get_stimulus_duration()
        sampling_interval = self.cell_data.get_sampling_interval()

        if len(mean_frequencies) == 0:
            warn("FICurve:calculate_all_frequency_points(): mean_frequencies is empty.\n"
                 "Was all_calculate_mean_isi_frequencies already called?")

        for i, frequency_trace in enumerate(mean_frequencies):
            time_axis = time_axes[i]

            # An empty trace would otherwise crash with an IndexError on time_axis[0].
            if len(time_axis) == 0 or time_axis[0] > stimulus_start:
                # TODO: deal with too strongly cut traces, e.g. by appending
                # placeholder values (-1) to all frequency lists and continuing.
                raise ValueError("TODO: Deal with to strongly cut frequency traces in cell data! ")

            f_zero, f_zero_idx = hF.detect_f_zero_in_frequency_trace(time_axis, frequency_trace,
                                                                     stimulus_start, sampling_interval)
            self.f_zero_frequencies.append(f_zero)
            self.indices_f_zero.append(f_zero_idx)

            f_baseline, f_base_idx = hF.detect_f_baseline_in_freq_trace(time_axis, frequency_trace,
                                                                        stimulus_start, sampling_interval)
            self.f_baseline_frequencies.append(f_baseline)
            self.indices_f_baseline.append(f_base_idx)

            f_infinity, f_inf_idx = hF.detect_f_infinity_in_freq_trace(time_axis, frequency_trace,
                                                                       stimulus_start, stimulus_duration,
                                                                       sampling_interval)
            self.f_inf_frequencies.append(f_infinity)
            self.indices_f_inf.append(f_inf_idx)

    def get_mean_time_and_freq_traces(self):
        """Return the cell's mean-frequency time axes and mean ISI frequency traces."""
        return self.cell_data.get_time_axes_fi_curve_mean_frequencies(), self.cell_data.get_mean_fi_curve_isi_frequencies()

    def get_time_and_freq_traces(self):
        """Return per-trial time and ISI-frequency traces for every stimulus value.

        :return: ``(time_traces, freq_traces)``; each is a list with one entry per
            stimulus value, each entry a list with one trace per trial, computed
            from the cell's FI spike times with ``hF.calculate_time_and_frequency_trace``.
        """
        sampling_interval = self.cell_data.get_sampling_interval()
        time_traces = []
        freq_traces = []
        for trials_spiketimes in self.cell_data.get_fi_spiketimes():
            traces = [hF.calculate_time_and_frequency_trace(spikes, sampling_interval)
                      for spikes in trials_spiketimes]
            time_traces.append([time for time, _ in traces])
            freq_traces.append([isi_freq for _, isi_freq in traces])

        return time_traces, freq_traces

    def get_sampling_interval(self):
        """Return the cell's sampling interval."""
        return self.cell_data.get_sampling_interval()

    def get_delay(self):
        """Return the cell's delay."""
        return self.cell_data.get_delay()

    def get_stimulus_start(self):
        """Return the cell's stimulus start time."""
        return self.cell_data.get_stimulus_start()

    def get_stimulus_end(self):
        """Return the cell's stimulus end time."""
        return self.cell_data.get_stimulus_end()

    def get_f_zero_inverse_at_frequency(self, frequency):
        """Return the inverse of the f_zero Boltzmann fit at ``frequency`` (marked UNUSED)."""
        # UNUSED
        b_vars = self.f_zero_fit
        return fu.inverse_full_boltzmann(frequency, b_vars[0], b_vars[1], b_vars[2], b_vars[3])

    def get_f_infinity_frequency_at_stimulus_value(self, stimulus_value):
        """Return the f_inf clipped-line fit at ``stimulus_value`` (marked UNUSED)."""
        # UNUSED
        infty_vars = self.f_inf_fit
        return fu.clipped_line(stimulus_value, infty_vars[0], infty_vars[1])

    def plot_f_point_detections(self, save_path=None):
        """Plot each mean frequency trace with horizontal bars at the detected f_base, f_inf and f_zero.

        Traces that are empty or do not cover the whole stimulus window are skipped.

        :param save_path: if None the figures are shown; otherwise they are saved
            as "<save_path>/detections_contrast_<value>.png".
        """
        # Intentionally not converted with np.array(): the traces may differ in
        # length, and NumPy >= 1.24 raises ValueError for ragged array creation.
        mean_frequencies = self.cell_data.get_mean_fi_curve_isi_frequencies()
        time_axes = self.cell_data.get_time_axes_fi_curve_mean_frequencies()
        sampling_interval = self.cell_data.get_sampling_interval()
        stim_start = self.cell_data.get_stimulus_start()
        stim_duration = self.cell_data.get_stimulus_duration()

        for i, c in enumerate(self.stimulus_values):
            time = time_axes[i]
            frequency = mean_frequencies[i]

            if len(time) == 0 or min(time) > stim_start \
                    or max(time) < stim_start + stim_duration:
                continue
            _, ax = plt.subplots(1, 1, figsize=(8, 8))
            ax.set_title("Stimulus value: {:.2f}".format(c))
            ax.plot(time, frequency)
            start_idx, end_idx = hF.time_window_detect_f_baseline(time[0], stim_start, sampling_interval)
            ax.plot((time[start_idx], time[end_idx]), (self.f_baseline_frequencies[i], self.f_baseline_frequencies[i]),
                    label="f_base", color="deepskyblue")

            start_idx, end_idx = hF.time_window_detect_f_infinity(time[0], stim_start, stim_duration,
                                                                  sampling_interval)
            ax.plot((time[start_idx], time[end_idx]), (self.f_inf_frequencies[i], self.f_inf_frequencies[i]),
                    label="f_inf", color="limegreen")

            start_idx, end_idx = hF.time_window_detect_f_zero(time[0], stim_start, sampling_interval)
            ax.plot((time[start_idx], time[end_idx]), (self.f_zero_frequencies[i], self.f_zero_frequencies[i]),
                    label="f_zero", color="orange")

            plt.legend()
            if save_path is not None:
                plt.savefig(save_path + "/detections_contrast_{:.2f}.png".format(c))
            else:
                plt.show()

            plt.close()


class FICurveModel(FICurve):
    """f-I curve obtained by simulating a model with step stimuli.

    For every stimulus value the model is simulated ``trials`` times with a
    ``SinusoidalStepStimulus`` that starts at ``stim_start`` and lasts
    ``stim_duration`` (in the same time unit as the model's simulation time).
    The frequency points are then detected in the trial-averaged frequency trace.
    """
    stim_duration = 0.5
    stim_start = 0.5
    # Equal time before and after the stimulus.
    total_simulation_time = stim_duration + 2 * stim_start

    def __init__(self, model, stimulus_values, eod_frequency, trials=5):
        """:param model: object providing ``simulate``, ``get_adaption_trace``,
            ``set_variable`` and ``get_sampling_interval``.
        :param stimulus_values: stimulus values (contrasts) passed to the stimulus.
        :param eod_frequency: frequency passed to ``SinusoidalStepStimulus``.
        :param trials: number of simulations per stimulus value.
        """
        self.eod_frequency = eod_frequency
        self.model = model
        self.trials = trials
        # One spike-time sequence per (stimulus value, trial); object dtype so each
        # cell can hold a sequence of arbitrary length (same as the former dtype=list).
        self.spiketimes_array = np.zeros((len(stimulus_values), trials), dtype=object)
        self.mean_frequency_traces = []
        self.mean_time_traces = []
        self.set_model_adaption_to_baseline()
        super().__init__(stimulus_values)

    def set_model_adaption_to_baseline(self):
        """Simulate 1 time unit with a zero-contrast stimulus and store the final adaption as "a_zero".

        Presumably this lets the following simulations start from the adapted
        baseline state instead of an unadapted one.
        """
        stimulus = SinusoidalStepStimulus(self.eod_frequency, 0, 0, 0)
        self.model.simulate(stimulus, 1)
        adaption = self.model.get_adaption_trace()
        self.model.set_variable("a_zero", adaption[-1])

    def _misses_stimulus_window(self, time):
        """Return True if ``time`` is empty or does not span the whole stimulus window."""
        return len(time) == 0 or min(time) > self.stim_start \
            or max(time) < self.stim_start + self.stim_duration

    def calculate_all_frequency_points(self):
        """Simulate all trials and detect f_inf, f_zero and f_baseline for every stimulus value.

        If the mean trace does not cover the stimulus window (too few spikes),
        0 is stored as frequency and None as index for that stimulus value, so
        every result list stays aligned with ``self.stimulus_values``.
        """
        sampling_interval = self.model.get_sampling_interval()
        # Start from a clean state so a repeated call does not append to old results.
        self.f_inf_frequencies = []
        self.f_zero_frequencies = []
        self.f_baseline_frequencies = []
        self.indices_f_inf = []
        self.indices_f_zero = []
        self.indices_f_baseline = []
        self.mean_frequency_traces = []
        self.mean_time_traces = []

        for i, c in enumerate(self.stimulus_values):
            stimulus = SinusoidalStepStimulus(self.eod_frequency, c, self.stim_start, self.stim_duration)
            frequency_traces = []
            time_traces = []
            for j in range(self.trials):
                _, spiketimes = self.model.simulate(stimulus, self.total_simulation_time)
                self.spiketimes_array[i, j] = spiketimes
                trial_time, trial_frequency = hF.calculate_time_and_frequency_trace(spiketimes, sampling_interval)
                frequency_traces.append(trial_frequency)
                time_traces.append(trial_time)

            time, frequency = hF.calculate_mean_of_frequency_traces(time_traces, frequency_traces, sampling_interval)
            self.mean_frequency_traces.append(frequency)
            self.mean_time_traces.append(time)

            if self._misses_stimulus_window(time):
                # Too few spikes to calculate f_inf, f_0 and f_base.
                self.f_inf_frequencies.append(0)
                self.f_zero_frequencies.append(0)
                self.f_baseline_frequencies.append(0)
                self.indices_f_inf.append(None)
                self.indices_f_zero.append(None)
                self.indices_f_baseline.append(None)
                continue

            f_inf, f_inf_idx = hF.detect_f_infinity_in_freq_trace(time, frequency, self.stim_start, self.stim_duration, sampling_interval)
            self.f_inf_frequencies.append(f_inf)
            self.indices_f_inf.append(f_inf_idx)

            f_zero, f_zero_idx = hF.detect_f_zero_in_frequency_trace(time, frequency, self.stim_start, sampling_interval)
            self.f_zero_frequencies.append(f_zero)
            self.indices_f_zero.append(f_zero_idx)

            f_baseline, f_base_idx = hF.detect_f_baseline_in_freq_trace(time, frequency, self.stim_start, sampling_interval)
            self.f_baseline_frequencies.append(f_baseline)
            self.indices_f_baseline.append(f_base_idx)

    def get_mean_time_and_freq_traces(self):
        """Return the trial-averaged time and frequency traces, one per stimulus value."""
        return self.mean_time_traces, self.mean_frequency_traces

    def get_sampling_interval(self):
        """Return the model's sampling interval."""
        return self.model.get_sampling_interval()

    def get_delay(self):
        """Return 0: simulated traces have no delay."""
        return 0

    def get_stimulus_start(self):
        """Return the stimulus start time used for the simulations."""
        return self.stim_start

    def get_stimulus_end(self):
        """Return the stimulus end time used for the simulations."""
        return self.stim_start + self.stim_duration

    def get_time_and_freq_traces(self):
        """Return per-trial time and frequency traces for every stimulus value.

        :return: ``(time_traces, freq_traces)``; one list per stimulus value with
            one trace per trial, computed from the stored simulated spike times.
        """
        sampling_interval = self.model.get_sampling_interval()
        time_traces = []
        freq_traces = []
        for v in range(len(self.stimulus_values)):
            traces = [hF.calculate_time_and_frequency_trace(spikes, sampling_interval)
                      for spikes in self.spiketimes_array[v]]
            time_traces.append([t for t, _ in traces])
            freq_traces.append([f for _, f in traces])
        return time_traces, freq_traces

    def plot_f_point_detections(self, save_path=None):
        """Plot each mean frequency trace with horizontal bars at the detected f_base, f_inf and f_zero.

        Traces that are empty or do not cover the whole stimulus window are skipped.

        :param save_path: if None the figures are shown; otherwise they are saved
            as "<save_path>/detections_contrast_<value>.png".
        """
        sampling_interval = self.model.get_sampling_interval()

        for i, c in enumerate(self.stimulus_values):
            time = self.mean_time_traces[i]
            frequency = self.mean_frequency_traces[i]

            if self._misses_stimulus_window(time):
                continue
            _, ax = plt.subplots(1, 1, figsize=(8, 8))
            ax.set_title("Stimulus value: {:.2f}".format(c))
            ax.plot(time, frequency)
            start_idx, end_idx = hF.time_window_detect_f_baseline(time[0], self.stim_start, sampling_interval)
            ax.plot((time[start_idx], time[end_idx]), (self.f_baseline_frequencies[i], self.f_baseline_frequencies[i]),
                    label="f_base", color="deepskyblue")

            start_idx, end_idx = hF.time_window_detect_f_infinity(time[0], self.stim_start, self.stim_duration, sampling_interval)
            ax.plot((time[start_idx], time[end_idx]), (self.f_inf_frequencies[i], self.f_inf_frequencies[i]),
                    label="f_inf", color="limegreen")

            start_idx, end_idx = hF.time_window_detect_f_zero(time[0], self.stim_start, sampling_interval)
            ax.plot((time[start_idx], time[end_idx]), (self.f_zero_frequencies[i], self.f_zero_frequencies[i]),
                    label="f_zero", color="orange")

            plt.legend()
            if save_path is not None:
                plt.savefig(save_path + "/detections_contrast_{:.2f}.png".format(c))
            else:
                plt.show()

            plt.close()


def get_fi_curve_class(data, stimulus_values, eod_freq=None, trials=5, save_dir=None, recalculate=False) -> FICurve:
    """Create the FICurve subclass that matches the type of ``data``.

    :param data: a ``CellData`` recording or a ``LifacNoiseModel``.
    :param stimulus_values: stimulus values at which the curve is evaluated.
    :param eod_freq: passed to ``FICurveModel`` as ``eod_frequency``; required for
        models, ignored for cell data.
    :param trials: simulations per stimulus value (models only).
    :param save_dir: cache directory (cell data only; model curves are not cached).
    :param recalculate: ignore an existing cache (cell data only).
    :return: a ``FICurveCellData`` or ``FICurveModel`` instance.
    :raises ValueError: if ``eod_freq`` is missing for a model, or ``data`` has an
        unsupported type.
    """
    if isinstance(data, CellData):
        return FICurveCellData(data, stimulus_values, save_dir, recalculate)
    if isinstance(data, LifacNoiseModel):
        if eod_freq is None:
            raise ValueError("The FiCurveModel needs the eod variable to work")
        return FICurveModel(data, stimulus_values, eod_freq, trials=trials)

    raise ValueError("Unknown type: Cannot find corresponding FiCurve class. Data was type:" + str(type(data)))