from CellData import CellData
from models.LIFACnoise import LifacNoiseModel
from stimuli.SinusoidalStepStimulus import SinusoidalStepStimulus
import numpy as np
import matplotlib.pyplot as plt
from warnings import warn
import functions as fu
import helperFunctions as hF
from os.path import join, exists
import pickle
from sys import stderr


class FICurve:
    """Base class for an f-I curve.

    For every stimulus value it holds three frequencies -- f_base, f_zero and
    f_inf -- plus a clipped-line fit of f_inf and a Boltzmann fit of f_zero.
    Subclasses supply the data by overriding ``calculate_all_frequency_points``
    and the trace/timing accessors that raise ``NotImplementedError`` here.
    """

    def __init__(self, stimulus_values, save_dir=None, recalculate=False):
        """Compute (or load from a pickle cache) the frequency points and fits.

        :param stimulus_values: stimulus values the curve is evaluated at.
        :param save_dir: directory of the cache file; ``None`` disables caching.
        :param recalculate: if True, never read the cache; recompute and overwrite it.
        """
        self.save_file_name = "fi_curve_values.pkl"
        self.stimulus_values = stimulus_values

        self.indices_f_baseline = []
        self.f_baseline_frequencies = []
        self.indices_f_inf = []
        self.f_inf_frequencies = []
        self.indices_f_zero = []
        self.f_zero_frequencies = []

        # clipped-line fit parameters: increase, offset
        self.f_inf_fit = []
        # Boltzmann fit parameters: f_max, f_min, k, x_zero
        self.f_zero_fit = []

        if save_dir is None:
            self.initialize()
        elif recalculate or not self.load_values(save_dir):
            # `or` short-circuits: with recalculate=True the cache is never read.
            self.initialize()
            self.save_values(save_dir)

    def initialize(self):
        """Compute all frequency points, then fit f_inf (clipped line) and f_zero (Boltzmann)."""
        self.calculate_all_frequency_points()
        self.f_inf_fit = hF.fit_clipped_line(self.stimulus_values, self.f_inf_frequencies)
        self.f_zero_fit = hF.fit_boltzmann(self.stimulus_values, self.f_zero_frequencies)

    def calculate_all_frequency_points(self):
        """Fill the f_base / f_inf / f_zero lists (one entry per stimulus value). Must be overridden."""
        raise NotImplementedError("NOT YET OVERRIDDEN FROM ABSTRACT CLASS")

    def get_f_baseline_frequencies(self):
        """Return the baseline frequency for each stimulus value."""
        return self.f_baseline_frequencies

    def get_f_inf_frequencies(self):
        """Return the f_inf frequency for each stimulus value."""
        return self.f_inf_frequencies

    def get_f_zero_frequencies(self):
        """Return the f_zero frequency for each stimulus value."""
        return self.f_zero_frequencies

    def get_f_inf_slope(self):
        """Return the slope (first parameter) of the f_inf fit, or None if there is no fit."""
        if len(self.f_inf_fit) > 0:
            return self.f_inf_fit[0]
        return None

    def get_f_zero_fit_slope_at_straight(self):
        """Return ``fu.full_boltzmann_straight_slope`` for the f_zero fit parameters."""
        fit_vars = self.f_zero_fit
        return fu.full_boltzmann_straight_slope(fit_vars[0], fit_vars[1], fit_vars[2], fit_vars[3])

    def get_f_zero_fit_slope_at_stimulus_value(self, stimulus_value):
        """Return the derivative of the f_zero Boltzmann fit at ``stimulus_value``."""
        fit_vars = self.f_zero_fit
        return fu.derivative_full_boltzmann(stimulus_value, fit_vars[0], fit_vars[1], fit_vars[2], fit_vars[3])

    def get_f_inf_frequency_at_stimulus_value(self, stimulus_value):
        """Evaluate the f_inf clipped-line fit at ``stimulus_value``."""
        return fu.clipped_line(stimulus_value, self.f_inf_fit[0], self.f_inf_fit[1])

    def get_f_zero_and_f_inf_intersection(self):
        """Return the stimulus value at which the f_zero and f_inf fits cross.

        Both fits are evaluated on a grid of step 0.0001 spanning the stimulus
        values. A sign change of their difference marks a crossing. If there
        are several crossings, the one whose f_inf value is closest to the
        median baseline frequency is returned.

        :raises ValueError: if the fits do not cross on the grid.
        """
        x_values = np.arange(min(self.stimulus_values), max(self.stimulus_values), 0.0001)
        fit_vars = self.f_zero_fit
        f_zero = fu.full_boltzmann(x_values, fit_vars[0], fit_vars[1], fit_vars[2], fit_vars[3])
        f_inf = fu.clipped_line(x_values, self.f_inf_fit[0], self.f_inf_fit[1])

        intersection_indices = np.argwhere(np.diff(np.sign(f_zero - f_inf))).flatten()
        if len(intersection_indices) == 0:
            raise ValueError("No intersection found!")
        if len(intersection_indices) == 1:
            return x_values[intersection_indices[0]]

        # Several crossings: choose the one closest to the median baseline frequency.
        f_baseline = np.median(self.f_baseline_frequencies)
        best_dist = np.inf
        best_idx = -1
        for idx in intersection_indices:
            dist = abs(fu.clipped_line(x_values[idx], self.f_inf_fit[0], self.f_inf_fit[1]) - f_baseline)
            if dist < best_dist:
                best_dist = dist
                best_idx = idx
        return x_values[best_idx]

    def get_f_zero_fit_slope_at_f_inf_fit_intersection(self):
        """Return the f_zero fit's derivative at the f_zero/f_inf fit intersection."""
        x = self.get_f_zero_and_f_inf_intersection()
        fit_vars = self.f_zero_fit
        return fu.derivative_full_boltzmann(x, fit_vars[0], fit_vars[1], fit_vars[2], fit_vars[3])

    def get_mean_time_and_freq_traces(self):
        """Return (time traces, mean frequency traces), one per stimulus value. Must be overridden."""
        raise NotImplementedError("NOT YET OVERRIDDEN FROM ABSTRACT CLASS")

    def get_time_and_freq_traces(self):
        """Return per-stimulus lists of single-trial (time, frequency) traces. Must be overridden."""
        raise NotImplementedError("NOT YET OVERRIDDEN FROM ABSTRACT CLASS")

    def get_sampling_interval(self):
        raise NotImplementedError("NOT YET OVERRIDDEN FROM ABSTRACT CLASS")

    def get_delay(self):
        raise NotImplementedError("NOT YET OVERRIDDEN FROM ABSTRACT CLASS")

    def get_stimulus_start(self):
        raise NotImplementedError("NOT YET OVERRIDDEN FROM ABSTRACT CLASS")

    def get_stimulus_end(self):
        raise NotImplementedError("NOT YET OVERRIDDEN FROM ABSTRACT CLASS")

    def get_stimulus_duration(self):
        """Return stimulus end minus stimulus start."""
        return self.get_stimulus_end() - self.get_stimulus_start()

    def plot_mean_frequency_curves(self, save_path=None):
        """Plot, for every stimulus value, the single-trial traces (gray) and their mean (black).

        Each figure is shown, or saved as
        ``save_path + "mean_frequency_contrast_<value>.png"`` (``save_path`` is
        concatenated as a string, so it should end with a separator).
        """
        time_traces, freq_traces = self.get_time_and_freq_traces()
        mean_times, mean_freqs = self.get_mean_time_and_freq_traces()
        for i, sv in enumerate(self.stimulus_values):
            for trial_time, trial_freq in zip(time_traces[i], freq_traces[i]):
                plt.plot(trial_time, trial_freq, color="gray", alpha=0.8)

            plt.plot(mean_times[i], mean_freqs[i], color="black")
            plt.xlabel("Time [s]")
            plt.ylabel("Frequency [Hz]")
            plt.title("Mean frequency at contrast {:.2f} ({:} trials)".format(sv, len(time_traces[i])))
            if save_path is None:
                plt.show()
            else:
                plt.savefig(save_path + "mean_frequency_contrast_{:.2f}.png".format(sv))
            plt.close()

    def plot_fi_curve(self, save_path=None):
        """Plot f_base, f_inf, f_zero and both fits; show it or save as ``save_path + "fi_curve.png"``."""
        # linspace instead of arange with a computed step: this also works when
        # all stimulus values are equal, where the step would be zero.
        x_values = np.linspace(min(self.stimulus_values), max(self.stimulus_values), 5000)

        plt.plot(self.stimulus_values, self.f_baseline_frequencies, color='blue', label='f_base')

        plt.plot(self.stimulus_values, self.f_inf_frequencies, 'o', color='green', label='f_inf')
        plt.plot(x_values, fu.clipped_line(x_values, self.f_inf_fit[0], self.f_inf_fit[1]),
                 color='darkgreen', label='f_inf_fit')

        plt.plot(self.stimulus_values, self.f_zero_frequencies, 'o', color='orange', label='f_zero')
        popt = self.f_zero_fit
        plt.plot(x_values, fu.full_boltzmann(x_values, popt[0], popt[1], popt[2], popt[3]),
                 color='red', label='f_0_fit')

        plt.legend()
        plt.ylabel("Frequency [Hz]")
        plt.xlabel("Stimulus value")

        if save_path is None:
            plt.show()
        else:
            plt.savefig(save_path + "fi_curve.png")
        plt.close()

    @staticmethod
    def plot_fi_curve_comparision(data_fi_curve, model_fi_curve, save_path=None):
        """Plot the cell curve, the model curve, and an overlay of both in three panels.

        The figure is shown, or saved as ``save_path + "fi_curve_comparision.png"``.
        """
        min_x = min(min(data_fi_curve.stimulus_values), min(model_fi_curve.stimulus_values))
        max_x = max(max(data_fi_curve.stimulus_values), max(model_fi_curve.stimulus_values))
        # Grid including max_x. linspace also handles min_x == max_x, where an
        # arange step would be zero.
        x_values = np.linspace(min_x, max_x, 5001)

        _, axes = plt.subplots(1, 3, sharex="all", sharey='all', figsize=(15, 6))
        f_base_color = ("blue", "deepskyblue")
        f_inf_color = ("green", "limegreen")
        f_zero_color = ("red", "orange")

        # Panels 0 (cell) and 1 (model): each curve on its own.
        for i, fi_curve in enumerate((data_fi_curve, model_fi_curve)):
            ax = axes[i]
            ax.plot(fi_curve.stimulus_values, fi_curve.get_f_baseline_frequencies(),
                    color=f_base_color[i], label='f_base')

            ax.plot(fi_curve.stimulus_values, fi_curve.get_f_inf_frequencies(),
                    'o', color=f_inf_color[i], label='f_inf')
            ax.plot(x_values, fu.clipped_line(x_values, fi_curve.f_inf_fit[0], fi_curve.f_inf_fit[1]),
                    color=f_inf_color[i], label='f_inf_fit')

            ax.plot(fi_curve.stimulus_values, fi_curve.get_f_zero_frequencies(),
                    'o', color=f_zero_color[i], label='f_zero')
            popt = fi_curve.f_zero_fit
            ax.plot(x_values, fu.full_boltzmann(x_values, popt[0], popt[1], popt[2], popt[3]),
                    color=f_zero_color[i], label='f_0_fit')

            ax.set_xlabel("Stimulus value - contrast")
            ax.legend()

        axes[0].set_title("cell")
        axes[0].set_ylabel("Frequency [Hz]")
        axes[1].set_title("model")

        # Panel 2: cell fits as lines, model points as markers.
        median_baseline = np.median(data_fi_curve.get_f_baseline_frequencies())
        axes[2].plot((min_x, max_x), (median_baseline, median_baseline), color=f_base_color[0], label="cell med base")
        axes[2].plot(model_fi_curve.stimulus_values, model_fi_curve.get_f_baseline_frequencies(),
                     'o', color=f_base_color[1], label='model base')

        axes[2].plot(x_values, fu.clipped_line(x_values, data_fi_curve.f_inf_fit[0], data_fi_curve.f_inf_fit[1]),
                     color=f_inf_color[0], label='f_inf_fit cell')
        axes[2].plot(model_fi_curve.stimulus_values, model_fi_curve.get_f_inf_frequencies(),
                     'o', color=f_inf_color[1], label='f_inf model')

        popt = data_fi_curve.f_zero_fit
        axes[2].plot(x_values, fu.full_boltzmann(x_values, popt[0], popt[1], popt[2], popt[3]),
                     color=f_zero_color[0], label='f_0_fit cell')
        axes[2].plot(model_fi_curve.stimulus_values, model_fi_curve.get_f_zero_frequencies(),
                     'o', color=f_zero_color[1], label='f_zero model')
        axes[2].set_title("cell model comparision")
        axes[2].set_xlabel("Stimulus value - contrast")
        axes[2].legend()

        if save_path is None:
            plt.show()
        else:
            plt.savefig(save_path + "fi_curve_comparision.png")
        plt.close()

    def plot_f_point_detections(self, save_path=None):
        raise NotImplementedError("NOT YET OVERRIDDEN FROM ABSTRACT CLASS")

    def save_values(self, save_directory):
        """Pickle stimulus values, frequency points and fit parameters into ``save_directory``.

        The ``indices_*`` lists are not stored.
        """
        values = {
            "stimulus_values": self.stimulus_values,
            "f_baseline_frequencies": self.f_baseline_frequencies,
            "f_inf_frequencies": self.f_inf_frequencies,
            "f_zero_frequencies": self.f_zero_frequencies,
            "f_inf_fit": self.f_inf_fit,
            "f_zero_fit": self.f_zero_fit,
        }

        with open(join(save_directory, self.save_file_name), "wb") as file:
            pickle.dump(values, file)

        print("Fi-Curve: Values saved!")

    def load_values(self, save_directory):
        """Load values written by ``save_values`` from ``save_directory``.

        If the cached stimulus values differ (as a set) from the current ones,
        a warning is written to stderr, but the cached values are used anyway.

        :return: True if a cache file was found and loaded, False otherwise.
        """
        file_path = join(save_directory, self.save_file_name)
        if not exists(file_path):
            print("Fi-Curve: No file to load")
            return False

        # NOTE: pickle.load can execute arbitrary code -- only load cache files
        # written by save_values, never files from untrusted sources.
        with open(file_path, "rb") as file:
            values = pickle.load(file)

        if set(values["stimulus_values"]) != set(self.stimulus_values):
            stderr.write("Fi-Curve:load_values() - Given stimulus values are different to the loaded ones!:\n "
                         "given: {}\n loaded: {}\n".format(str(self.stimulus_values), str(values["stimulus_values"])))

        self.stimulus_values = values["stimulus_values"]
        self.f_baseline_frequencies = values["f_baseline_frequencies"]
        self.f_inf_frequencies = values["f_inf_frequencies"]
        self.f_zero_frequencies = values["f_zero_frequencies"]
        self.f_inf_fit = values["f_inf_fit"]
        self.f_zero_fit = values["f_zero_fit"]
        print("Fi-Curve: Values loaded!")
        return True


class FICurveCellData(FICurve):
    """f-I curve whose frequency points are detected in recorded ``CellData``."""

    def __init__(self, cell_data: CellData, stimulus_values, save_dir=None, recalculate=False):
        # cell_data must be set before super().__init__, which may call
        # calculate_all_frequency_points() and therefore reads it.
        self.cell_data = cell_data
        super().__init__(stimulus_values, save_dir, recalculate)

    def calculate_all_frequency_points(self):
        """Detect f_zero, f_base and f_inf (and their indices) in each mean frequency trace.

        :raises ValueError: if a trace starts after the stimulus onset (not supported yet).
        """
        mean_frequencies = self.cell_data.get_mean_fi_curve_isi_frequencies()
        time_axes = self.cell_data.get_time_axes_fi_curve_mean_frequencies()
        stimulus_start = self.cell_data.get_stimulus_start()
        stimulus_duration = self.cell_data.get_stimulus_duration()
        sampling_interval = self.cell_data.get_sampling_interval()

        if len(mean_frequencies) == 0:
            warn("FICurve:calculate_all_frequency_points(): mean_frequencies is empty.\n"
                 "Was all_calculate_mean_isi_frequencies already called?")

        for i, frequency in enumerate(mean_frequencies):
            time = time_axes[i]

            # Traces that begin after the stimulus onset cannot provide a
            # baseline; handling them is still open.
            if time[0] > stimulus_start:
                raise ValueError("TODO: Deal with to strongly cut frequency traces in cell data! ")

            f_zero, f_zero_idx = hF.detect_f_zero_in_frequency_trace(time, frequency,
                                                                     stimulus_start, sampling_interval)
            self.f_zero_frequencies.append(f_zero)
            self.indices_f_zero.append(f_zero_idx)

            f_baseline, f_base_idx = hF.detect_f_baseline_in_freq_trace(time, frequency,
                                                                        stimulus_start, sampling_interval)
            self.f_baseline_frequencies.append(f_baseline)
            self.indices_f_baseline.append(f_base_idx)

            f_infinity, f_inf_idx = hF.detect_f_infinity_in_freq_trace(time, frequency, stimulus_start,
                                                                       stimulus_duration, sampling_interval)
            self.f_inf_frequencies.append(f_infinity)
            self.indices_f_inf.append(f_inf_idx)

    def get_mean_time_and_freq_traces(self):
        """Return the cell's per-stimulus time axes and mean ISI frequency traces."""
        return self.cell_data.get_time_axes_fi_curve_mean_frequencies(), self.cell_data.get_mean_fi_curve_isi_frequencies()

    def get_time_and_freq_traces(self):
        """Return per-stimulus lists of single-trial time and frequency traces computed from the spike times."""
        spiketimes = self.cell_data.get_fi_spiketimes()
        sampling_interval = self.cell_data.get_sampling_interval()
        time_traces = []
        freq_traces = []
        for trials in spiketimes:
            traces = [hF.calculate_time_and_frequency_trace(trial_spikes, sampling_interval)
                      for trial_spikes in trials]
            time_traces.append([time for time, _ in traces])
            freq_traces.append([freq for _, freq in traces])

        return time_traces, freq_traces

    def get_sampling_interval(self):
        return self.cell_data.get_sampling_interval()

    def get_delay(self):
        return self.cell_data.get_delay()

    def get_stimulus_start(self):
        return self.cell_data.get_stimulus_start()

    def get_stimulus_end(self):
        return self.cell_data.get_stimulus_end()

    def get_f_zero_inverse_at_frequency(self, frequency):
        """Return the inverse of the f_zero Boltzmann fit at ``frequency`` (currently unused)."""
        b_vars = self.f_zero_fit
        return fu.inverse_full_boltzmann(frequency, b_vars[0], b_vars[1], b_vars[2], b_vars[3])

    def get_f_infinity_frequency_at_stimulus_value(self, stimulus_value):
        """Evaluate the f_inf clipped-line fit at ``stimulus_value`` (currently unused)."""
        infty_vars = self.f_inf_fit
        return fu.clipped_line(stimulus_value, infty_vars[0], infty_vars[1])

    def plot_f_point_detections(self, save_path=None):
        """Plot each mean frequency trace with the windows and values of f_base, f_inf and f_zero.

        Traces that are empty or do not cover the whole stimulus are skipped.
        Figures are shown, or saved as ``save_path + "/detections_contrast_<value>.png"``.
        """
        # Keep the traces as a plain list: they can differ in length per
        # stimulus value, and np.array() on a ragged list raises ValueError on
        # NumPy >= 1.24. Only positional indexing is needed here.
        mean_frequencies = self.cell_data.get_mean_fi_curve_isi_frequencies()
        time_axes = self.cell_data.get_time_axes_fi_curve_mean_frequencies()
        sampling_interval = self.cell_data.get_sampling_interval()
        stim_start = self.cell_data.get_stimulus_start()
        stim_duration = self.cell_data.get_stimulus_duration()

        for i, c in enumerate(self.stimulus_values):
            time = time_axes[i]
            frequency = mean_frequencies[i]

            if len(time) == 0 or min(time) > stim_start \
                    or max(time) < stim_start + stim_duration:
                continue
            fig, ax = plt.subplots(1, 1, figsize=(8, 8))
            ax.set_title("Stimulus value: {:.2f}".format(c))
            ax.plot(time, frequency)
            start_idx, end_idx = hF.time_window_detect_f_baseline(time[0], stim_start, sampling_interval)
            ax.plot((time[start_idx], time[end_idx]), (self.f_baseline_frequencies[i], self.f_baseline_frequencies[i]),
                    label="f_base", color="deepskyblue")

            start_idx, end_idx = hF.time_window_detect_f_infinity(time[0], stim_start, stim_duration,
                                                                  sampling_interval)
            ax.plot((time[start_idx], time[end_idx]), (self.f_inf_frequencies[i], self.f_inf_frequencies[i]),
                    label="f_inf", color="limegreen")

            start_idx, end_idx = hF.time_window_detect_f_zero(time[0], stim_start, sampling_interval)
            ax.plot((time[start_idx], time[end_idx]), (self.f_zero_frequencies[i], self.f_zero_frequencies[i]),
                    label="f_zero", color="orange")

            plt.legend()
            if save_path is not None:
                plt.savefig(save_path + "/detections_contrast_{:.2f}.png".format(c))
            else:
                plt.show()

            plt.close()


class FICurveModel(FICurve):
    """f-I curve obtained by simulating a model with sinusoidal step stimuli.

    Results are not cached (no ``save_dir``); the model is simulated on construction.
    """
    # Step timing in model time units (presumably seconds -- plots label the
    # time axis "Time [s]"; confirm against the model).
    stim_duration = 0.5
    stim_start = 0.5
    # pre-stimulus period + step + a post-stimulus period as long as the pre-stimulus one
    total_simulation_time = stim_duration + 2 * stim_start

    def __init__(self, model, stimulus_values, eod_frequency, trials=5):
        """
        :param model: model to simulate (must offer simulate, get_adaption_trace,
            set_variable and get_sampling_interval).
        :param stimulus_values: contrasts of the step stimuli.
        :param eod_frequency: frequency passed to every SinusoidalStepStimulus.
        :param trials: number of simulations per stimulus value.
        """
        self.eod_frequency = eod_frequency
        self.model = model
        self.trials = trials
        self.spiketimes_array = np.zeros((len(stimulus_values), trials), dtype=list)
        self.mean_frequency_traces = []
        self.mean_time_traces = []
        self.set_model_adaption_to_baseline()
        super().__init__(stimulus_values)

    def set_model_adaption_to_baseline(self):
        """Simulate a zero-contrast stimulus and use the final adaption value as the model's ``a_zero``."""
        stimulus = SinusoidalStepStimulus(self.eod_frequency, 0, 0, 0)
        self.model.simulate(stimulus, 1)
        adaption = self.model.get_adaption_trace()
        self.model.set_variable("a_zero", adaption[-1])

    def calculate_all_frequency_points(self):
        """Simulate every stimulus value ``trials`` times and detect f_inf, f_zero and f_base.

        If the mean trace is empty or does not cover the whole stimulus, 0 is
        stored for all three frequencies and no detection indices are recorded.
        """
        sampling_interval = self.model.get_sampling_interval()
        # Reset every per-stimulus container so that a repeated call does not
        # append to (and misalign with) the results of an earlier one.
        self.f_inf_frequencies = []
        self.f_zero_frequencies = []
        self.f_baseline_frequencies = []
        self.indices_f_inf = []
        self.indices_f_zero = []
        self.indices_f_baseline = []
        self.mean_frequency_traces = []
        self.mean_time_traces = []

        for i, c in enumerate(self.stimulus_values):
            stimulus = SinusoidalStepStimulus(self.eod_frequency, c, self.stim_start, self.stim_duration)
            frequency_traces = []
            time_traces = []
            for j in range(self.trials):
                _, spiketimes = self.model.simulate(stimulus, self.total_simulation_time)
                self.spiketimes_array[i, j] = spiketimes
                trial_time, trial_frequency = hF.calculate_time_and_frequency_trace(spiketimes, sampling_interval)
                frequency_traces.append(trial_frequency)
                time_traces.append(trial_time)

            time, frequency = hF.calculate_mean_of_frequency_traces(time_traces, frequency_traces, sampling_interval)
            self.mean_frequency_traces.append(frequency)
            self.mean_time_traces.append(time)

            if len(time) == 0 or min(time) > self.stim_start \
                    or max(time) < self.stim_start + self.stim_duration:
                # Too few spikes to detect f_inf, f_0 and f_base.
                self.f_inf_frequencies.append(0)
                self.f_zero_frequencies.append(0)
                self.f_baseline_frequencies.append(0)
                continue

            f_inf, f_inf_idx = hF.detect_f_infinity_in_freq_trace(time, frequency, self.stim_start,
                                                                  self.stim_duration, sampling_interval)
            self.f_inf_frequencies.append(f_inf)
            self.indices_f_inf.append(f_inf_idx)

            f_zero, f_zero_idx = hF.detect_f_zero_in_frequency_trace(time, frequency, self.stim_start,
                                                                     sampling_interval)
            self.f_zero_frequencies.append(f_zero)
            self.indices_f_zero.append(f_zero_idx)

            f_baseline, f_base_idx = hF.detect_f_baseline_in_freq_trace(time, frequency, self.stim_start,
                                                                        sampling_interval)
            self.f_baseline_frequencies.append(f_baseline)
            self.indices_f_baseline.append(f_base_idx)

    def get_mean_time_and_freq_traces(self):
        return self.mean_time_traces, self.mean_frequency_traces

    def get_sampling_interval(self):
        return self.model.get_sampling_interval()

    def get_delay(self):
        return 0

    def get_stimulus_start(self):
        return self.stim_start

    def get_stimulus_end(self):
        return self.stim_start + self.stim_duration

    def get_time_and_freq_traces(self):
        """Return per-stimulus lists of single-trial time and frequency traces from the stored spike times."""
        sampling_interval = self.model.get_sampling_interval()
        time_traces = []
        freq_traces = []
        for v in range(len(self.stimulus_values)):
            traces = [hF.calculate_time_and_frequency_trace(s, sampling_interval)
                      for s in self.spiketimes_array[v]]
            time_traces.append([t for t, _ in traces])
            freq_traces.append([f for _, f in traces])
        return time_traces, freq_traces

    def plot_f_point_detections(self, save_path=None):
        """Plot each mean frequency trace with the windows and values of f_base, f_inf and f_zero.

        Traces that are empty or do not cover the whole stimulus are skipped.
        Figures are shown, or saved as ``save_path + "/detections_contrast_<value>.png"``.
        """
        sampling_interval = self.model.get_sampling_interval()

        for i, c in enumerate(self.stimulus_values):
            time = self.mean_time_traces[i]
            frequency = self.mean_frequency_traces[i]

            if len(time) == 0 or min(time) > self.stim_start \
                    or max(time) < self.stim_start + self.stim_duration:
                continue
            fig, ax = plt.subplots(1, 1, figsize=(8, 8))
            # same title as FICurveCellData.plot_f_point_detections
            ax.set_title("Stimulus value: {:.2f}".format(c))
            ax.plot(time, frequency)
            start_idx, end_idx = hF.time_window_detect_f_baseline(time[0], self.stim_start, sampling_interval)
            ax.plot((time[start_idx], time[end_idx]), (self.f_baseline_frequencies[i], self.f_baseline_frequencies[i]),
                    label="f_base", color="deepskyblue")

            start_idx, end_idx = hF.time_window_detect_f_infinity(time[0], self.stim_start, self.stim_duration,
                                                                  sampling_interval)
            ax.plot((time[start_idx], time[end_idx]), (self.f_inf_frequencies[i], self.f_inf_frequencies[i]),
                    label="f_inf", color="limegreen")

            start_idx, end_idx = hF.time_window_detect_f_zero(time[0], self.stim_start, sampling_interval)
            ax.plot((time[start_idx], time[end_idx]), (self.f_zero_frequencies[i], self.f_zero_frequencies[i]),
                    label="f_zero", color="orange")

            plt.legend()
            if save_path is not None:
                plt.savefig(save_path + "/detections_contrast_{:.2f}.png".format(c))
            else:
                plt.show()

            plt.close()


def get_fi_curve_class(data, stimulus_values, eod_freq=None, trials=5, save_dir=None, recalculate=False) -> FICurve:
    """Create the FICurve subclass that matches the type of ``data``.

    :param data: a ``CellData`` (recording) or a ``LifacNoiseModel`` (simulation).
    :param stimulus_values: stimulus values of the f-I curve.
    :param eod_freq: frequency for the model's step stimuli; required for models.
    :param trials: simulations per stimulus value (models only).
    :param save_dir: cache directory (cell data only; models are always simulated).
    :param recalculate: ignore and overwrite an existing cache (cell data only).
    :raises ValueError: if ``eod_freq`` is missing for a model, or ``data`` has an unsupported type.
    """
    if isinstance(data, CellData):
        return FICurveCellData(data, stimulus_values, save_dir, recalculate)
    if isinstance(data, LifacNoiseModel):
        if eod_freq is None:
            raise ValueError("The FiCurveModel needs the eod variable to work")
        return FICurveModel(data, stimulus_values, eod_freq, trials=trials)

    raise ValueError("Unknown type: Cannot find corresponding FiCurve class. Data was type:" + str(type(data)))