from CellData import CellData
from models.LIFACnoise import LifacNoiseModel
from stimuli.SinusoidalStepStimulus import SinusoidalStepStimulus
import numpy as np
import matplotlib.pyplot as plt
from warnings import warn
import functions as fu
import helperFunctions as hF


class FICurve:
    """Abstract base class for frequency-intensity (FI) curves.

    For every stimulus value three frequencies are stored: a baseline
    frequency (``f_baseline_frequencies``), an f_zero frequency
    (``f_zero_frequencies``) and an f_inf frequency (``f_inf_frequencies``).
    After they have been computed, f_inf is fitted with a clipped line and
    f_zero with a Boltzmann function (see ``initialize``).

    Subclasses must implement ``calculate_all_frequency_points`` (which fills
    the three frequency lists) and the trace / timing accessors that raise
    ``NotImplementedError`` here.
    """

    def __init__(self, stimulus_values):
        self.stimulus_values = stimulus_values

        # One entry per stimulus value, filled by calculate_all_frequency_points().
        self.f_baseline_frequencies = []
        self.f_inf_frequencies = []
        self.f_zero_frequencies = []

        # Clipped-line fit parameters of f_inf: increase, offset
        self.f_inf_fit = []
        # Boltzmann fit parameters of f_zero: f_max, f_min, k, x_zero
        self.f_zero_fit = []

        self.initialize()

    def initialize(self):
        """Compute all frequency points, then fit f_inf (clipped line) and f_zero (Boltzmann)."""
        self.calculate_all_frequency_points()
        self.f_inf_fit = hF.fit_clipped_line(self.stimulus_values, self.f_inf_frequencies)
        self.f_zero_fit = hF.fit_boltzmann(self.stimulus_values, self.f_zero_frequencies)

    def calculate_all_frequency_points(self):
        """Fill the f_baseline, f_inf and f_zero frequency lists; must be overridden."""
        raise NotImplementedError("NOT YET OVERRIDDEN FROM ABSTRACT CLASS")

    def get_f_baseline_frequencies(self):
        """Return the baseline frequency of every stimulus value."""
        return self.f_baseline_frequencies

    def get_f_inf_frequencies(self):
        """Return the f_inf frequency of every stimulus value."""
        return self.f_inf_frequencies

    def get_f_zero_frequencies(self):
        """Return the f_zero frequency of every stimulus value."""
        return self.f_zero_frequencies

    def get_f_inf_slope(self):
        """Return the slope ("increase") of the f_inf clipped-line fit, or None if there is no fit."""
        if len(self.f_inf_fit) > 0:
            return self.f_inf_fit[0]
        return None

    def get_f_zero_fit_slope_at_straight(self):
        """Return ``fu.full_boltzmann_straight_slope`` evaluated with the f_zero fit parameters."""
        fit_vars = self.f_zero_fit
        return fu.full_boltzmann_straight_slope(fit_vars[0], fit_vars[1], fit_vars[2], fit_vars[3])

    def get_f_zero_fit_slope_at_stimulus_value(self, stimulus_value):
        """Return the derivative of the f_zero Boltzmann fit at ``stimulus_value``."""
        fit_vars = self.f_zero_fit
        return fu.derivative_full_boltzmann(stimulus_value, fit_vars[0], fit_vars[1], fit_vars[2], fit_vars[3])

    def get_f_inf_frequency_at_stimulus_value(self, stimulus_value):
        """Return the f_inf clipped-line fit evaluated at ``stimulus_value``."""
        return fu.clipped_line(stimulus_value, self.f_inf_fit[0], self.f_inf_fit[1])

    def get_f_zero_and_f_inf_intersection(self, resolution=0.0001):
        """Return the stimulus value at which the f_zero and f_inf fits intersect.

        Both fits are evaluated on a grid from the smallest to the largest
        stimulus value with spacing ``resolution``. A sign change of their
        difference marks an intersection. If several are found, the one whose
        f_inf value is closest to the median baseline frequency is returned.

        :param resolution: grid spacing of the search; must be > 0 (default 0.0001).
        :return: the grid point just before the detected sign change.
        :raises ValueError: if ``resolution`` is not positive or no intersection is found.
        """
        if resolution <= 0:
            raise ValueError("resolution must be positive, got {}".format(resolution))

        x_values = np.arange(min(self.stimulus_values), max(self.stimulus_values), resolution)
        fit_vars = self.f_zero_fit
        f_zero = fu.full_boltzmann(x_values, fit_vars[0], fit_vars[1], fit_vars[2], fit_vars[3])
        f_inf = np.asarray(fu.clipped_line(x_values, self.f_inf_fit[0], self.f_inf_fit[1]))

        # Index k is reported when the sign of (f_zero - f_inf) differs between k and k+1.
        intersection_indices = np.argwhere(np.diff(np.sign(f_zero - f_inf))).flatten()

        if len(intersection_indices) == 0:
            raise ValueError("No intersection found!")
        if len(intersection_indices) == 1:
            return x_values[intersection_indices[0]]

        # Several crossings: choose the one whose f_inf value lies closest to the median baseline.
        f_baseline = np.median(self.f_baseline_frequencies)
        distances = np.abs(f_inf[intersection_indices] - f_baseline)
        return x_values[intersection_indices[np.argmin(distances)]]

    def get_f_zero_fit_slope_at_f_inf_fit_intersection(self):
        """Return the f_zero fit's derivative at the f_zero / f_inf fit intersection.

        :raises ValueError: propagated if no intersection is found.
        """
        x = self.get_f_zero_and_f_inf_intersection()
        fit_vars = self.f_zero_fit
        return fu.derivative_full_boltzmann(x, fit_vars[0], fit_vars[1], fit_vars[2], fit_vars[3])

    def get_mean_time_and_freq_traces(self):
        """Return (time traces, frequency traces) of the trial mean per stimulus value; must be overridden."""
        raise NotImplementedError("NOT YET OVERRIDDEN FROM ABSTRACT CLASS")

    def get_time_and_freq_traces(self):
        """Return per-stimulus lists of per-trial (time, frequency) traces; must be overridden."""
        raise NotImplementedError("NOT YET OVERRIDDEN FROM ABSTRACT CLASS")

    def get_sampling_interval(self):
        raise NotImplementedError("NOT YET OVERRIDDEN FROM ABSTRACT CLASS")

    def get_delay(self):
        raise NotImplementedError("NOT YET OVERRIDDEN FROM ABSTRACT CLASS")

    def get_stimulus_start(self):
        raise NotImplementedError("NOT YET OVERRIDDEN FROM ABSTRACT CLASS")

    def get_stimulus_end(self):
        raise NotImplementedError("NOT YET OVERRIDDEN FROM ABSTRACT CLASS")

    def get_stimulus_duration(self):
        """Return stimulus end minus stimulus start."""
        return self.get_stimulus_end() - self.get_stimulus_start()

    def plot_mean_frequency_curves(self, save_path=None):
        """Plot all trial frequency traces (gray) and their mean (black), one figure per stimulus value.

        :param save_path: prefix for the saved PNG files; if None the figures are shown instead.
        """
        time_traces, freq_traces = self.get_time_and_freq_traces()
        mean_times, mean_freqs = self.get_mean_time_and_freq_traces()
        for i, sv in enumerate(self.stimulus_values):

            for j in range(len(time_traces[i])):
                plt.plot(time_traces[i][j], freq_traces[i][j], color="gray", alpha=0.8)

            plt.plot(mean_times[i], mean_freqs[i], color="black")
            plt.xlabel("Time [s]")
            plt.ylabel("Frequency [Hz]")

            plt.title("Mean frequency at contrast {:.2f} ({:} trials)".format(sv, len(time_traces[i])))
            if save_path is None:
                plt.show()
            else:
                plt.savefig(save_path + "mean_frequency_contrast_{:.2f}.png".format(sv))
            plt.close()

    def plot_fi_curve(self, save_path=None):
        """Plot the f_base, f_inf and f_zero points together with the f_inf and f_zero fits.

        :param save_path: prefix for ``fi_curve.png``; if None the figure is shown instead.
        """
        # linspace includes the upper end and, unlike arange with a computed
        # step, does not fail when all stimulus values are equal (step == 0).
        x_values = np.linspace(min(self.stimulus_values), max(self.stimulus_values), 5000)

        plt.plot(self.stimulus_values, self.f_baseline_frequencies, color='blue', label='f_base')

        plt.plot(self.stimulus_values, self.f_inf_frequencies, 'o', color='green', label='f_inf')
        plt.plot(x_values, [fu.clipped_line(x, self.f_inf_fit[0], self.f_inf_fit[1]) for x in x_values],
                 color='darkgreen', label='f_inf_fit')

        plt.plot(self.stimulus_values, self.f_zero_frequencies, 'o', color='orange', label='f_zero')
        popt = self.f_zero_fit
        plt.plot(x_values, [fu.full_boltzmann(x, popt[0], popt[1], popt[2], popt[3]) for x in x_values],
                 color='red', label='f_0_fit')

        plt.legend()
        plt.ylabel("Frequency [Hz]")
        plt.xlabel("Stimulus value")

        if save_path is None:
            plt.show()
        else:
            plt.savefig(save_path + "fi_curve.png")
        plt.close()

    @staticmethod
    def plot_fi_curve_comparision(data_fi_curve, model_fi_curve, save_path=None):
        """Plot a cell FI-curve, a model FI-curve and an overlay of both in three panels.

        :param data_fi_curve: FICurve of the recorded cell.
        :param model_fi_curve: FICurve of the model.
        :param save_path: prefix for ``fi_curve_comparision.png``; if None the figure is shown instead.
        """
        min_x = min(min(data_fi_curve.stimulus_values), min(model_fi_curve.stimulus_values))
        max_x = max(max(data_fi_curve.stimulus_values), max(model_fi_curve.stimulus_values))
        # Same grid construction as plot_fi_curve: endpoint included, no zero step.
        x_values = np.linspace(min_x, max_x, 5001)

        fig, axes = plt.subplots(1, 3, sharex="all", sharey='all', figsize=(15, 6))
        data_origin = (data_fi_curve, model_fi_curve)
        f_base_color = ("blue", "deepskyblue")
        f_inf_color = ("green", "limegreen")
        f_zero_color = ("red", "orange")

        # Panels 0 and 1: cell and model, each with its own points and fits.
        for ax, fi_curve, base_c, inf_c, zero_c in zip(axes[:2], data_origin, f_base_color,
                                                       f_inf_color, f_zero_color):
            ax.plot(fi_curve.stimulus_values, fi_curve.get_f_baseline_frequencies(),
                    color=base_c, label='f_base')

            ax.plot(fi_curve.stimulus_values, fi_curve.get_f_inf_frequencies(),
                    'o', color=inf_c, label='f_inf')
            y_values = [fu.clipped_line(x, fi_curve.f_inf_fit[0], fi_curve.f_inf_fit[1]) for x in x_values]
            ax.plot(x_values, y_values, color=inf_c, label='f_inf_fit')

            ax.plot(fi_curve.stimulus_values, fi_curve.get_f_zero_frequencies(),
                    'o', color=zero_c, label='f_zero')
            popt = fi_curve.f_zero_fit
            ax.plot(x_values, [fu.full_boltzmann(x, popt[0], popt[1], popt[2], popt[3]) for x in x_values],
                    color=zero_c, label='f_0_fit')

            ax.set_xlabel("Stimulus value - contrast")
            ax.legend()

        axes[0].set_title("cell")
        axes[0].set_ylabel("Frequency [Hz]")
        axes[1].set_title("model")

        # Panel 2: cell fits (lines) against model points.
        median_baseline = np.median(data_fi_curve.get_f_baseline_frequencies())
        axes[2].plot((min_x, max_x), (median_baseline, median_baseline), color=f_base_color[0], label="cell med base")
        axes[2].plot(model_fi_curve.stimulus_values, model_fi_curve.get_f_baseline_frequencies(),
                     'o', color=f_base_color[1], label='model base')

        y_values = [fu.clipped_line(x, data_fi_curve.f_inf_fit[0], data_fi_curve.f_inf_fit[1]) for x in x_values]
        axes[2].plot(x_values, y_values, color=f_inf_color[0], label='f_inf_fit cell')
        axes[2].plot(model_fi_curve.stimulus_values, model_fi_curve.get_f_inf_frequencies(),
                     'o', color=f_inf_color[1], label='f_inf model')

        popt = data_fi_curve.f_zero_fit
        axes[2].plot(x_values, [fu.full_boltzmann(x, popt[0], popt[1], popt[2], popt[3]) for x in x_values],
                     color=f_zero_color[0], label='f_0_fit cell')
        axes[2].plot(model_fi_curve.stimulus_values, model_fi_curve.get_f_zero_frequencies(),
                     'o', color=f_zero_color[1], label='f_zero model')
        axes[2].set_title("cell model comparision")
        axes[2].set_xlabel("Stimulus value - contrast")
        axes[2].legend()

        if save_path is None:
            plt.show()
        else:
            plt.savefig(save_path + "fi_curve_comparision.png")
        plt.close()

    def plot_f_point_detections(self, save_path=None):
        """Plot the detection windows of the f-points on the mean traces; must be overridden."""
        raise NotImplementedError("NOT YET OVERRIDDEN FROM ABSTRACT CLASS")


class FICurveCellData(FICurve):
    """FI-curve computed from the recorded responses of a cell.

    All traces, spike times and stimulus timing are read from the wrapped
    ``CellData`` object.
    """

    def __init__(self, cell_data: CellData, stimulus_values):
        # Must be set before super().__init__(): the base constructor calls
        # initialize() -> calculate_all_frequency_points(), which reads cell_data.
        self.cell_data = cell_data
        super().__init__(stimulus_values)

    def calculate_all_frequency_points(self):
        """Detect f_zero, f_baseline and f_inf on every mean ISI-frequency trace of the cell.

        :raises ValueError: if a mean trace is empty or starts after the stimulus start.
        """
        mean_frequencies = self.cell_data.get_mean_fi_curve_isi_frequencies()
        time_axes = self.cell_data.get_time_axes_fi_curve_mean_frequencies()
        stimulus_start = self.cell_data.get_stimulus_start()
        stimulus_duration = self.cell_data.get_stimulus_duration()
        sampling_interval = self.cell_data.get_sampling_interval()

        if len(mean_frequencies) == 0:
            warn("FICurveCellData.calculate_all_frequency_points(): mean_frequencies is empty.\n"
                 "Was all_calculate_mean_isi_frequencies already called?")

        for i, frequency in enumerate(mean_frequencies):
            time = time_axes[i]

            # An empty axis used to surface as an IndexError on time[0].
            if len(time) == 0 or time[0] > stimulus_start:
                # TODO: traces cut after the stimulus start could instead be
                # marked with placeholder values (e.g. -1) and skipped.
                raise ValueError("TODO: Deal with to strongly cut frequency traces in cell data! ")

            f_zero = hF.detect_f_zero_in_frequency_trace(time, frequency, stimulus_start, sampling_interval)
            self.f_zero_frequencies.append(f_zero)
            f_baseline = hF.detect_f_baseline_in_freq_trace(time, frequency, stimulus_start, sampling_interval)
            self.f_baseline_frequencies.append(f_baseline)
            f_infinity = hF.detect_f_infinity_in_freq_trace(time, frequency, stimulus_start,
                                                            stimulus_duration, sampling_interval)
            self.f_inf_frequencies.append(f_infinity)

    def get_mean_time_and_freq_traces(self):
        """Return the cell's mean time axes and mean ISI-frequency traces, one per stimulus value."""
        return self.cell_data.get_time_axes_fi_curve_mean_frequencies(), self.cell_data.get_mean_fi_curve_isi_frequencies()

    def get_time_and_freq_traces(self):
        """Return per-stimulus lists of per-trial (time, frequency) traces computed from the spike times."""
        sampling_interval = self.cell_data.get_sampling_interval()
        time_traces = []
        freq_traces = []
        for contrast_spiketimes in self.cell_data.get_fi_spiketimes():
            traces = [hF.calculate_time_and_frequency_trace(trial, sampling_interval)
                      for trial in contrast_spiketimes]
            time_traces.append([time for time, _ in traces])
            freq_traces.append([freq for _, freq in traces])

        return time_traces, freq_traces

    def get_sampling_interval(self):
        return self.cell_data.get_sampling_interval()

    def get_delay(self):
        return self.cell_data.get_delay()

    def get_stimulus_start(self):
        return self.cell_data.get_stimulus_start()

    def get_stimulus_end(self):
        return self.cell_data.get_stimulus_end()

    def get_f_zero_inverse_at_frequency(self, frequency):
        """Return the inverse of the f_zero Boltzmann fit at ``frequency`` (currently unused)."""
        b_vars = self.f_zero_fit
        return fu.inverse_full_boltzmann(frequency, b_vars[0], b_vars[1], b_vars[2], b_vars[3])

    def get_f_infinity_frequency_at_stimulus_value(self, stimulus_value):
        """Return the f_inf clipped-line fit at ``stimulus_value`` (currently unused)."""
        infty_vars = self.f_inf_fit
        return fu.clipped_line(stimulus_value, infty_vars[0], infty_vars[1])

    def plot_f_point_detections(self, save_path=None):
        """Plot every mean trace with the windows and values of f_base, f_inf and f_zero.

        Traces that do not span the whole stimulus window are skipped.

        :param save_path: directory for ``detections_contrast_<c>.png``; if None the figures are shown.
        """
        # Kept as a plain list: the traces may differ in length per contrast, and
        # np.array() on ragged input raises ValueError in NumPy >= 1.24.
        mean_frequencies = self.cell_data.get_mean_fi_curve_isi_frequencies()
        time_axes = self.cell_data.get_time_axes_fi_curve_mean_frequencies()
        sampling_interval = self.cell_data.get_sampling_interval()
        stim_start = self.cell_data.get_stimulus_start()
        stim_duration = self.cell_data.get_stimulus_duration()
        stim_end = stim_start + stim_duration

        for i, c in enumerate(self.stimulus_values):
            time = time_axes[i]
            frequency = mean_frequencies[i]

            if len(time) == 0 or min(time) > stim_start or max(time) < stim_end:
                continue
            fig, ax = plt.subplots(1, 1, figsize=(8, 8))
            ax.plot(time, frequency)
            start_idx, end_idx = hF.time_window_detect_f_baseline(time[0], stim_start, sampling_interval)
            ax.plot((time[start_idx], time[end_idx]), (self.f_baseline_frequencies[i], self.f_baseline_frequencies[i]),
                    label="f_base", color="deepskyblue")

            start_idx, end_idx = hF.time_window_detect_f_infinity(time[0], stim_start, stim_duration,
                                                                  sampling_interval)
            ax.plot((time[start_idx], time[end_idx]), (self.f_inf_frequencies[i], self.f_inf_frequencies[i]),
                    label="f_inf", color="limegreen")

            start_idx, end_idx = hF.time_window_detect_f_zero(time[0], stim_start, sampling_interval)
            ax.plot((time[start_idx], time[end_idx]), (self.f_zero_frequencies[i], self.f_zero_frequencies[i]),
                    label="f_zero", color="orange")

            ax.legend()
            if save_path is not None:
                plt.savefig(save_path + "/detections_contrast_{:.2f}.png".format(c))
            else:
                plt.show()

            plt.close()


class FICurveModel(FICurve):
    """FI-curve obtained by simulating a model with one SinusoidalStepStimulus per stimulus value.

    Each stimulus value is simulated ``trials`` times. The per-trial frequency
    traces are averaged, and the f-points are detected on the mean trace.
    """
    # Stimulus timing used for every simulation (presumably seconds; the plots label time in [s]).
    stim_duration = 0.5
    stim_start = 0.5
    # The period after the stimulus is as long as the one before it (stim_start).
    total_simulation_time = stim_duration + 2 * stim_start

    def __init__(self, model, stimulus_values, eod_frequency, trials=5):
        self.eod_frequency = eod_frequency
        self.model = model
        self.trials = trials
        # Object array: entry [i, j] holds the spike times of trial j at stimulus_values[i].
        self.spiketimes_array = np.zeros((len(stimulus_values), trials), dtype=list)
        self.mean_frequency_traces = []
        self.mean_time_traces = []
        # Must come last: the base constructor runs the simulations via initialize().
        super().__init__(stimulus_values)

    def calculate_all_frequency_points(self):
        """Simulate every stimulus value and detect f_inf, f_zero and f_baseline on the mean traces.

        All per-stimulus result lists, including the mean traces, are reset
        first. A repeated call (e.g. via initialize()) therefore does not leave
        stale entries that no longer line up with ``stimulus_values``.
        Stimulus values whose mean trace does not span the whole stimulus
        window get 0 for all three frequencies.
        """
        sampling_interval = self.model.get_sampling_interval()
        self.f_inf_frequencies = []
        self.f_zero_frequencies = []
        self.f_baseline_frequencies = []
        # These were previously never reset and kept growing on repeated calls.
        self.mean_frequency_traces = []
        self.mean_time_traces = []

        stim_end = self.stim_start + self.stim_duration
        for i, contrast in enumerate(self.stimulus_values):
            stimulus = SinusoidalStepStimulus(self.eod_frequency, contrast, self.stim_start, self.stim_duration)
            frequency_traces = []
            time_traces = []
            for j in range(self.trials):
                _, spiketimes = self.model.simulate_fast(stimulus, self.total_simulation_time)
                self.spiketimes_array[i, j] = spiketimes
                trial_time, trial_frequency = hF.calculate_time_and_frequency_trace(spiketimes, sampling_interval)
                frequency_traces.append(trial_frequency)
                time_traces.append(trial_time)

            time, frequency = hF.calculate_mean_of_frequency_traces(time_traces, frequency_traces, sampling_interval)
            self.mean_frequency_traces.append(frequency)
            self.mean_time_traces.append(time)

            # Only detect f-points when the mean trace spans the complete stimulus window.
            if len(time) == 0 or min(time) > self.stim_start or max(time) < stim_end:
                print("Too few spikes to calculate f_inf, f_0 and f_base")
                self.f_inf_frequencies.append(0)
                self.f_zero_frequencies.append(0)
                self.f_baseline_frequencies.append(0)
                continue

            f_inf = hF.detect_f_infinity_in_freq_trace(time, frequency, self.stim_start, self.stim_duration, sampling_interval)
            self.f_inf_frequencies.append(f_inf)

            f_zero = hF.detect_f_zero_in_frequency_trace(time, frequency, self.stim_start, sampling_interval)
            self.f_zero_frequencies.append(f_zero)

            f_baseline = hF.detect_f_baseline_in_freq_trace(time, frequency, self.stim_start, sampling_interval)
            self.f_baseline_frequencies.append(f_baseline)

    def get_mean_time_and_freq_traces(self):
        """Return the mean time and frequency traces computed in calculate_all_frequency_points()."""
        return self.mean_time_traces, self.mean_frequency_traces

    def get_sampling_interval(self):
        return self.model.get_sampling_interval()

    def get_delay(self):
        # The simulated stimulus has no delay.
        return 0

    def get_stimulus_start(self):
        return self.stim_start

    def get_stimulus_end(self):
        return self.stim_start + self.stim_duration

    def get_time_and_freq_traces(self):
        """Return per-stimulus lists of per-trial (time, frequency) traces from the stored spike times."""
        sampling_interval = self.model.get_sampling_interval()
        time_traces = []
        freq_traces = []
        for v in range(len(self.stimulus_values)):
            traces = [hF.calculate_time_and_frequency_trace(s, sampling_interval)
                      for s in self.spiketimes_array[v]]
            time_traces.append([t for t, _ in traces])
            freq_traces.append([f for _, f in traces])
        return time_traces, freq_traces

    def plot_f_point_detections(self, save_path=None):
        """Plot every mean trace with the windows and values of f_base, f_inf and f_zero.

        Traces that do not span the whole stimulus window are skipped.

        :param save_path: directory for ``detections_contrast_<c>.png``; if None the figures are shown.
        """
        sampling_interval = self.model.get_sampling_interval()
        stim_end = self.stim_start + self.stim_duration

        for i, c in enumerate(self.stimulus_values):
            time = self.mean_time_traces[i]
            frequency = self.mean_frequency_traces[i]

            if len(time) == 0 or min(time) > self.stim_start or max(time) < stim_end:
                continue
            fig, ax = plt.subplots(1, 1, figsize=(8, 8))
            ax.plot(time, frequency)
            start_idx, end_idx = hF.time_window_detect_f_baseline(time[0], self.stim_start, sampling_interval)
            ax.plot((time[start_idx], time[end_idx]), (self.f_baseline_frequencies[i], self.f_baseline_frequencies[i]),
                    label="f_base", color="deepskyblue")

            start_idx, end_idx = hF.time_window_detect_f_infinity(time[0], self.stim_start, self.stim_duration,
                                                                  sampling_interval)
            ax.plot((time[start_idx], time[end_idx]), (self.f_inf_frequencies[i], self.f_inf_frequencies[i]),
                    label="f_inf", color="limegreen")

            start_idx, end_idx = hF.time_window_detect_f_zero(time[0], self.stim_start, sampling_interval)
            ax.plot((time[start_idx], time[end_idx]), (self.f_zero_frequencies[i], self.f_zero_frequencies[i]),
                    label="f_zero", color="orange")

            ax.legend()
            if save_path is not None:
                plt.savefig(save_path + "/detections_contrast_{:.2f}.png".format(c))
            else:
                plt.show()

            plt.close()


def get_fi_curve_class(data, stimulus_values, eod_freq=None) -> FICurve:
    """Create the FICurve subclass that matches the type of ``data``.

    :param data: a CellData (recorded cell) or a LifacNoiseModel (simulated model).
    :param stimulus_values: stimulus values of the FI-curve.
    :param eod_freq: EOD frequency passed to FICurveModel; required when ``data`` is a model.
    :return: a FICurveCellData or FICurveModel instance.
    :raises ValueError: if ``eod_freq`` is missing for a model or ``data`` has an unsupported type.
    """
    if isinstance(data, CellData):
        return FICurveCellData(data, stimulus_values)
    if isinstance(data, LifacNoiseModel):
        if eod_freq is None:
            raise ValueError("The FiCurveModel needs the eod variable to work")
        return FICurveModel(data, stimulus_values, eod_freq)

    raise ValueError("Unknown type: Cannot find corresponding FICurve class. Data was type:" + str(type(data)))