from CellData import CellData
from models.LIFACnoise import LifacNoiseModel
from stimuli.SinusoidalStepStimulus import SinusoidalStepStimulus
import helperFunctions as hF
import numpy as np
import matplotlib.pyplot as plt
import pickle
from os.path import join, exists


class Baseline:
    """Abstract base for baseline-activity statistics of a cell or a model.

    Subclasses provide the raw data (spike times, traces, ...) by overriding the
    ``get_*`` methods; this class supplies shared helpers that compute the
    statistics from per-trial lists, plotting helpers, and pickle-based
    saving/loading of the computed values.

    Cached scalar values use ``-1`` (and ``serial_correlation`` an empty list)
    as the "not yet computed" marker.
    """

    def __init__(self):
        self.save_file_name = "baseline_values.pkl"
        self.baseline_frequency = -1
        self.serial_correlation = []
        self.vector_strength = -1
        self.coefficient_of_variation = -1
        self.burstiness = -1

    def get_baseline_frequency(self):
        raise NotImplementedError("NOT YET OVERRIDDEN FROM ABSTRACT CLASS")

    def get_serial_correlation(self, max_lag):
        raise NotImplementedError("NOT YET OVERRIDDEN FROM ABSTRACT CLASS")

    def get_vector_strength(self):
        raise NotImplementedError("NOT YET OVERRIDDEN FROM ABSTRACT CLASS")

    def get_coefficient_of_variation(self):
        raise NotImplementedError("NOT YET OVERRIDDEN FROM ABSTRACT CLASS")

    def get_burstiness(self):
        raise NotImplementedError("NOT YET OVERRIDDEN FROM ABSTRACT CLASS")

    def __get_burstiness__(self, eod_freq):
        """Return a burstiness score computed from the interspike intervals.

        The score is the fraction of ISIs shorter than 2.5 EOD periods,
        multiplied by the mean ISI in milliseconds.

        :param eod_freq: EOD frequency (presumably Hz) defining the burst threshold
        :return: the burstiness score, or 0 if there are no interspike intervals
        """
        isis = np.array(self.get_interspike_intervals())
        if len(isis) == 0:
            return 0

        # an ISI counts as burst-like when it is shorter than 2.5 EOD periods
        fullfilled = isis < (2.5 / eod_freq)
        perc_bursts = np.sum(fullfilled) / len(fullfilled)

        # ISIs are in seconds (the ISI histogram multiplies by 1000 to get ms)
        return perc_bursts * (np.mean(isis) * 1000)

    def get_interspike_intervals(self):
        raise NotImplementedError("NOT YET OVERRIDDEN FROM ABSTRACT CLASS")

    def get_spiketime_phases(self):
        raise NotImplementedError("NOT YET OVERRIDDEN FROM ABSTRACT CLASS")

    def plot_baseline(self, save_path=None, time_length=0.2):
        raise NotImplementedError("NOT YET OVERRIDDEN FROM ABSTRACT CLASS")

    @staticmethod
    def _get_baseline_frequency_given_data(spiketimes):
        """Return the median over trials of ``hF.calculate_mean_isi_freq``.

        :param spiketimes: list with one spike-time sequence per trial
        """
        base_freqs = [hF.calculate_mean_isi_freq(st) for st in spiketimes]
        return np.median(base_freqs)

    @staticmethod
    def _get_serial_correlation_given_data(max_lag, spikestimes):
        """Return the trial-averaged serial correlation for lags 1..max_lag.

        :param max_lag: largest lag passed to ``hF.calculate_serial_correlation``
        :param spikestimes: list with one spike-time sequence per trial
        :return: element-wise mean over trials (numpy array)
        """
        serial_cors = np.array([hF.calculate_serial_correlation(st, max_lag) for st in spikestimes])
        return np.mean(serial_cors, axis=0)

    @staticmethod
    def _get_vector_strength_given_data(times, eods, spiketimes, sampling_interval):
        """Return the mean over trials of the per-trial vector strength.

        :param times: per-trial time arrays
        :param eods: per-trial EOD traces, aligned with ``times``
        :param spiketimes: per-trial spike times
        :param sampling_interval: sampling interval of the traces
        """
        vs_per_trial = [
            hF.calculate_vector_strength_from_spiketimes(trial_time, trial_eod, trial_spikes, sampling_interval)
            for trial_time, trial_eod, trial_spikes in zip(times, eods, spiketimes)
        ]
        return np.mean(vs_per_trial)

    @staticmethod
    def _get_coefficient_of_variation_given_data(spiketimes):
        """Return the mean over trials of the ISI coefficient of variation.

        The CV is the standard deviation of the ISIs divided by their mean
        (computed per trial by ``hF.calculate_coefficient_of_variation``).
        """
        cvs = [hF.calculate_coefficient_of_variation(np.array(st)) for st in spiketimes]
        return np.mean(cvs)

    @staticmethod
    def _get_interspike_intervals_given_data(spiketimes):
        """Return all interspike intervals of all trials as one flat list."""
        isis = []
        for st in spiketimes:
            isis.extend(np.diff(np.array(st)))
        return isis

    @staticmethod
    def _plot_baseline_given_data(time, eod, v1, spiketimes, sampling_interval, eod_freq="", save_path=None, position=0.5, time_length=0.2):
        """Plot a window of the stimulus/EOD, the V1 trace with spikes, and the ISI frequency.

        :param time: time array of the trace
        :param eod: EOD/stimulus trace aligned with ``time``
        :param v1: V1 trace aligned with ``time``
        :param spiketimes: spike times on the same time axis as ``time``
        :param sampling_interval: sampling interval of the traces
        :param eod_freq: EOD frequency shown in the axis label (any value, converted with ``str``)
        :param save_path: path prefix; if given the figure is saved as ``<save_path>baseline.png``,
            otherwise it is shown
        :param position: relative start of the window within the trace (0..1)
        :param time_length: length of the plotted window
        """
        length_data_points = int(time_length / sampling_interval)

        start_idx = max(int(len(time) * position), 0)
        end_idx = min(int(len(time) * position + length_data_points) + 1, len(time))

        # Spikes are selected in [time[start_idx], time[end_idx]). When the window is
        # clamped to the end of the trace, time[end_idx] does not exist, so use one
        # sample past the last time point as the exclusive upper bound instead.
        window_start = time[start_idx]
        if end_idx < len(time):
            window_end = time[end_idx]
        else:
            window_end = time[-1] + sampling_interval

        spiketimes = np.array(spiketimes)
        spiketimes_part = spiketimes[(spiketimes >= window_start) & (spiketimes < window_end)]

        fig, axes = plt.subplots(3, 1, sharex="col", figsize=(12, 8))
        fig.suptitle("Baseline middle part ({:.2f} seconds)".format(time_length))
        axes[0].plot(time[start_idx:end_idx], eod[start_idx:end_idx])
        axes[0].set_ylabel("Stimulus [mV] - Freq:" + str(eod_freq))

        max_v1 = max(v1[start_idx:end_idx])
        axes[1].plot(time[start_idx:end_idx], v1[start_idx:end_idx])
        # mark spikes as dots at the top of the visible V1 trace
        axes[1].plot(spiketimes_part, [max_v1] * len(spiketimes_part), 'o', color='orange')
        axes[1].set_ylabel("V1-Trace [mV]")

        t, f = hF.calculate_time_and_frequency_trace(spiketimes_part, sampling_interval)
        axes[2].plot(t, f)
        axes[2].set_ylabel("ISI-Frequency [Hz]")
        axes[2].set_xlabel("Time [s]")

        if save_path is not None:
            plt.savefig(save_path + "baseline.png")
        else:
            plt.show()

        plt.close()

    @staticmethod
    def plot_isi_histogram_comparision(cell_isis, model_isis, save_path=None):
        """Plot normalized ISI histograms (in ms) of a cell and a model on top of each other.

        :param cell_isis: interspike intervals of the cell in seconds
        :param model_isis: interspike intervals of the model in seconds
        :param save_path: path prefix; if given the figure is saved, otherwise shown
        """
        cell_isis = np.array(cell_isis) * 1000
        model_isis = np.array(model_isis) * 1000
        # ``initial`` keeps this from raising when one of the inputs is empty
        maximum = max(np.max(cell_isis, initial=0), np.max(model_isis, initial=0))
        if maximum > 0:
            bins = np.arange(0, maximum * 1.01, 0.1)
        else:
            bins = np.arange(0, 1, 0.1)

        plt.title('Baseline ISIs')
        plt.xlabel('ISI in ms')
        plt.ylabel('Count')
        plt.hist(cell_isis, bins=bins, label="cell", alpha=0.5, density=True)
        plt.hist(model_isis, bins=bins, label="model", alpha=0.5, density=True)
        plt.legend()
        if save_path is not None:
            plt.savefig(save_path + "isi-histogram_comparision.png")
        else:
            plt.show()

        plt.close()

    def plot_polar_vector_strength(self, save_path=None):
        """Plot a polar histogram of the spike phases relative to the EOD."""
        phases = self.get_spiketime_phases()
        fig = plt.figure()
        ax = fig.add_subplot(111, polar=True)
        bins = np.arange(0, np.pi * 2, 0.1)
        ax.hist(phases, bins=bins)

        if save_path is not None:
            plt.savefig(save_path + "vector_strength_polar_plot.png")
        else:
            plt.show()

        plt.close()

    def plot_interspike_interval_histogram(self, save_path=None):
        """Plot a histogram of the interspike intervals in milliseconds."""
        isi = np.array(self.get_interspike_intervals()) * 1000  # change unit to milliseconds
        if len(isi) == 0:
            print("NON SPIKES IN BASELINE OF CELL/MODEL")
            plt.title('Baseline ISIs - NO SPIKES!')
            plt.xlabel('ISI in ms')
            plt.ylabel('Count')
            plt.hist(isi, bins=np.arange(0, 1, 0.1))

            if save_path is not None:
                plt.savefig(save_path + "isi-histogram.png")
            else:
                plt.show()

            plt.close()
            return
        maximum = max(isi)
        bins = np.arange(0, maximum * 1.01, 0.1)

        plt.title('Baseline ISIs')
        plt.xlabel('ISI in ms')
        plt.ylabel('Count')
        plt.hist(isi, bins=bins)

        if save_path is not None:
            plt.savefig(save_path + "isi-histogram.png")
        else:
            plt.show()

        plt.close()

    def plot_serial_correlation(self, max_lag, save_path=None):
        """Plot the serial correlation for lags 1..max_lag."""
        plt.title("Baseline Serial correlation")
        plt.xlabel("Lag")
        plt.ylabel("Correlation")
        plt.ylim((-1, 1))
        plt.plot(np.arange(1, max_lag + 1, 1), self.get_serial_correlation(max_lag))

        if save_path is not None:
            plt.savefig(save_path + "serial_correlation.png")
        else:
            plt.show()

        plt.close()

    def save_values(self, save_directory):
        """Compute all baseline values and pickle them into ``save_directory``.

        The serial correlation is stored for lags up to 10.
        """
        values = {
            "baseline_frequency": self.get_baseline_frequency(),
            "serial correlation": self.get_serial_correlation(max_lag=10),
            "vector strength": self.get_vector_strength(),
            "coefficient of variation": self.get_coefficient_of_variation(),
            "burstiness": self.get_burstiness(),
        }

        with open(join(save_directory, self.save_file_name), "wb") as file:
            pickle.dump(values, file)
        print("Baseline: Values saved!")

    def load_values(self, save_directory):
        """Load previously saved baseline values from ``save_directory`` into the cache.

        :return: True if the file existed and was loaded, False otherwise
        """
        file_path = join(save_directory, self.save_file_name)
        if not exists(file_path):
            print("Baseline: No file to load")
            return False

        # NOTE: pickle can execute arbitrary code while loading; only load files
        # written by save_values, never untrusted input.
        with open(file_path, "rb") as file:
            values = pickle.load(file)
        self.baseline_frequency = values["baseline_frequency"]
        self.serial_correlation = values["serial correlation"]
        self.vector_strength = values["vector strength"]
        self.coefficient_of_variation = values["coefficient of variation"]
        self.burstiness = values["burstiness"]
        print("Baseline: Values loaded!")
        return True


class BaselineCellData(Baseline):
    """Baseline statistics computed from the recorded baseline traces of a CellData object.

    Each value is computed on first access and cached on the instance.
    """

    def __init__(self, cell_data: CellData):
        super().__init__()
        self.data = cell_data

    def get_baseline_frequency(self):
        """Return the median over trials of the mean ISI frequency (cached)."""
        if self.baseline_frequency == -1:
            spiketimes = self.data.get_base_spikes()
            self.baseline_frequency = self._get_baseline_frequency_given_data(spiketimes)

        return self.baseline_frequency

    def get_vector_strength(self):
        """Return the trial-averaged vector strength of the spikes relative to the EOD (cached)."""
        if self.vector_strength == -1:
            times = self.data.get_base_traces(self.data.TIME)
            eods = self.data.get_base_traces(self.data.EOD)
            spiketimes = self.data.get_base_spikes()
            sampling_interval = self.data.get_sampling_interval()
            self.vector_strength = self._get_vector_strength_given_data(times, eods, spiketimes, sampling_interval)
        return self.vector_strength

    def get_serial_correlation(self, max_lag):
        """Return the serial correlation for lags 1..max_lag.

        The cache is recomputed only when it holds fewer than ``max_lag`` lags;
        a longer cache is sliced.
        """
        if len(self.serial_correlation) < max_lag:
            self.serial_correlation = self._get_serial_correlation_given_data(max_lag, self.data.get_base_spikes())
        return self.serial_correlation[:max_lag]

    def get_coefficient_of_variation(self):
        """Return the trial-averaged ISI coefficient of variation (cached)."""
        if self.coefficient_of_variation == -1:
            self.coefficient_of_variation = self._get_coefficient_of_variation_given_data(self.data.get_base_spikes())
        return self.coefficient_of_variation

    def get_interspike_intervals(self):
        """Return the interspike intervals of all baseline trials as one flat list."""
        return self._get_interspike_intervals_given_data(self.data.get_base_spikes())

    def get_spiketime_phases(self):
        """Return the phase (0..2*pi) of every baseline spike within its EOD period."""
        times = self.data.get_base_traces(self.data.TIME)
        spiketimes = self.data.get_base_spikes()
        eods = self.data.get_base_traces(self.data.EOD)
        sampling_interval = self.data.get_sampling_interval()

        phase_list = []
        for trace_time, trace_eod, trial_spikes in zip(times, eods, spiketimes):
            # Spike times are on the same time axis as the traces (they are compared
            # directly against time values when plotting), so the sample index is the
            # offset from the trace start divided by the sampling interval.
            spiketime_indices = np.array(
                np.around((np.array(trial_spikes) - trace_time[0]) / sampling_interval), dtype=int)
            rel_spikes, eod_durs = hF.eods_around_spikes(trace_time, trace_eod, spiketime_indices)

            phase_list.extend((rel_spikes / eod_durs) * 2 * np.pi)

        return phase_list

    def get_burstiness(self):
        """Return the burstiness score using the cell's EOD frequency (cached)."""
        if self.burstiness == -1:
            self.burstiness = self.__get_burstiness__(self.data.get_eod_frequency())
        return self.burstiness

    def plot_baseline(self, save_path=None, position=0.5, time_length=0.2):
        """Plot EOD, V1 trace, spikes and ISI frequency of the first baseline trial."""
        time = self.data.get_base_traces(self.data.TIME)[0]
        eod = self.data.get_base_traces(self.data.EOD)[0]
        v1_trace = self.data.get_base_traces(self.data.V1)[0]
        spiketimes = self.data.get_base_spikes()[0]

        self._plot_baseline_given_data(time, eod, v1_trace, spiketimes,
                                       self.data.get_sampling_interval(),
                                       "{:.0f}".format(self.data.get_eod_frequency()),
                                       save_path, position, time_length)


class BaselineModel(Baseline):
    """Baseline statistics of a LifacNoiseModel driven by an unmodulated EOD stimulus.

    On construction the model's ``a_zero`` is set to its adapted baseline value,
    then ``trials`` simulations of ``simulation_time`` are run. All statistics
    are computed lazily from those simulated spike trains and cached.
    """

    # duration of each simulated trial (presumably seconds)
    simulation_time = 30

    def __init__(self, model: LifacNoiseModel, eod_frequency, trials=1):
        super().__init__()
        self.model = model
        self.eod_frequency = eod_frequency
        self.set_model_adaption_to_baseline()

        sampling_interval = model.get_sampling_interval()
        self.stimulus = SinusoidalStepStimulus(eod_frequency, 0)
        self.eod = self.stimulus.as_array(0, self.simulation_time, sampling_interval)
        self.time = np.arange(0, self.simulation_time, sampling_interval)

        self.v1_traces = []
        self.spiketimes = []
        for _ in range(trials):
            v1_trace, trial_spikes = model.simulate(self.stimulus, self.simulation_time)
            self.v1_traces.append(v1_trace)
            self.spiketimes.append(trial_spikes)

    def set_model_adaption_to_baseline(self):
        """Simulate the model for 1 time unit and use the final adaption value as ``a_zero``."""
        # NOTE(review): the extra zero arguments presumably disable the step / modulation — confirm
        stimulus = SinusoidalStepStimulus(self.eod_frequency, 0, 0, 0)
        self.model.simulate(stimulus, 1)
        adaption = self.model.get_adaption_trace()
        self.model.set_variable("a_zero", adaption[-1])

    def get_baseline_frequency(self):
        """Return the median over trials of the mean ISI frequency (cached)."""
        if self.baseline_frequency == -1:
            self.baseline_frequency = self._get_baseline_frequency_given_data(self.spiketimes)
        return self.baseline_frequency

    def get_vector_strength(self):
        """Return the trial-averaged vector strength of the spikes relative to the EOD (cached)."""
        if self.vector_strength == -1:
            # every trial shares the same time axis and EOD stimulus
            times = [self.time] * len(self.spiketimes)
            eods = [self.eod] * len(self.spiketimes)
            sampling_interval = self.model.get_sampling_interval()
            self.vector_strength = self._get_vector_strength_given_data(times, eods, self.spiketimes, sampling_interval)

        return self.vector_strength

    def get_serial_correlation(self, max_lag):
        """Return the serial correlation for lags 1..max_lag.

        The cache is recomputed only when it holds fewer than ``max_lag`` lags;
        a longer cache is sliced (same policy as BaselineCellData).
        """
        if len(self.serial_correlation) < max_lag:
            self.serial_correlation = self._get_serial_correlation_given_data(max_lag, self.spiketimes)
        return self.serial_correlation[:max_lag]

    def get_coefficient_of_variation(self):
        """Return the trial-averaged ISI coefficient of variation (cached)."""
        if self.coefficient_of_variation == -1:
            self.coefficient_of_variation = self._get_coefficient_of_variation_given_data(self.spiketimes)
        return self.coefficient_of_variation

    def get_interspike_intervals(self):
        """Return the interspike intervals of all simulated trials as one flat list."""
        return self._get_interspike_intervals_given_data(self.spiketimes)

    def get_burstiness(self):
        """Return the burstiness score using the model's EOD frequency (cached)."""
        if self.burstiness == -1:
            self.burstiness = self.__get_burstiness__(self.eod_frequency)
        return self.burstiness

    def get_spiketime_phases(self):
        """Return the phase (0..2*pi) of every simulated spike within its EOD period."""
        sampling_interval = self.model.get_sampling_interval()
        start_time = self.time[0]

        phase_list = []
        for trial_spikes in self.spiketimes:
            # sample index = offset from the start of the time axis / sampling interval
            spiketime_indices = np.array(
                np.around((np.array(trial_spikes) - start_time) / sampling_interval), dtype=int)
            rel_spikes, eod_durs = hF.eods_around_spikes(self.time, self.eod, spiketime_indices)

            phase_list.extend((rel_spikes / eod_durs) * 2 * np.pi)

        return phase_list

    def plot_baseline(self, save_path=None, position=0.5, time_length=0.2):
        """Plot EOD, V1 trace, spikes and ISI frequency of the first simulated trial."""
        self._plot_baseline_given_data(self.time, self.eod, self.v1_traces[0], self.spiketimes[0],
                                       self.model.get_sampling_interval(), "{:.0f}".format(self.eod_frequency),
                                       save_path, position, time_length)


def get_baseline_class(data, eod_freq=None, trials=1) -> Baseline:
    """Build the Baseline implementation that matches the type of ``data``.

    :param data: a CellData (recorded cell) or a LifacNoiseModel (simulated cell)
    :param eod_freq: EOD frequency for the model baseline; required when ``data`` is a model
    :param trials: number of simulated trials for the model baseline
    :return: a BaselineCellData or BaselineModel instance
    :raises ValueError: if ``eod_freq`` is missing for a model, or ``data`` has an unsupported type
    """
    if isinstance(data, CellData):
        baseline = BaselineCellData(data)
    elif not isinstance(data, LifacNoiseModel):
        raise ValueError("Unknown type: Cannot find corresponding Baseline class. data was type:" + str(type(data)))
    elif eod_freq is None:
        raise ValueError("The EOD frequency is needed for the BaselineModel Class.")
    else:
        baseline = BaselineModel(data, eod_freq, trials=trials)
    return baseline