from parser.CellData import CellData
from models.LIFACnoise import LifacNoiseModel
from stimuli.SinusoidalStepStimulus import SinusoidalStepStimulus
from my_util import helperFunctions as hF
import numpy as np
import matplotlib.pyplot as plt
import pickle
from os.path import join, exists


class Baseline:
    """
    Abstract base class for baseline firing statistics of a cell or a model.

    Subclasses supply the data by overriding the ``get_*`` and ``plot_baseline`` methods.
    This class provides:
    - shared computations that operate on lists of per-trial spike-time arrays,
    - plotting helpers,
    - pickle-based saving/loading of the computed values.
    """

    def __init__(self):
        self.save_file_name = "baseline_values.pkl"
        # -1 (scalars) and [] (serial correlation) mark values that were not computed/loaded yet.
        self.baseline_frequency = -1
        self.serial_correlation = []
        self.vector_strength = -1
        self.coefficient_of_variation = -1
        self.burstiness = -1

    def get_baseline_frequency(self):
        raise NotImplementedError("NOT YET OVERRIDDEN FROM ABSTRACT CLASS")

    def get_serial_correlation(self, max_lag):
        raise NotImplementedError("NOT YET OVERRIDDEN FROM ABSTRACT CLASS")

    def get_vector_strength(self):
        raise NotImplementedError("NOT YET OVERRIDDEN FROM ABSTRACT CLASS")

    def get_coefficient_of_variation(self):
        raise NotImplementedError("NOT YET OVERRIDDEN FROM ABSTRACT CLASS")

    def get_burstiness(self):
        raise NotImplementedError("NOT YET OVERRIDDEN FROM ABSTRACT CLASS")

    def __get_burstiness__(self, eod_freq):
        """
        Compute a burstiness score from the pooled interspike intervals.

        :param eod_freq: EOD frequency; ISIs shorter than 2.5 EOD periods (2.5 / eod_freq)
            count as burst intervals.
        :return: fraction of burst ISIs multiplied by the mean ISI scaled by 1000
            (milliseconds if the ISIs are in seconds), or 0 when there are no ISIs.
        """
        isis = np.array(self.get_interspike_intervals())
        if len(isis) == 0:
            return 0

        fullfilled = isis < (2.5 / eod_freq)
        perc_bursts = np.sum(fullfilled) / len(fullfilled)

        return perc_bursts * (np.mean(isis) * 1000)

    def get_interspike_intervals(self):
        raise NotImplementedError("NOT YET OVERRIDDEN FROM ABSTRACT CLASS")

    def get_spiketime_phases(self):
        raise NotImplementedError("NOT YET OVERRIDDEN FROM ABSTRACT CLASS")

    def plot_baseline(self, save_path=None, time_length=0.2):
        raise NotImplementedError("NOT YET OVERRIDDEN FROM ABSTRACT CLASS")

    @staticmethod
    def _get_baseline_frequency_given_data(spiketimes):
        """Return the median over trials of the per-trial mean-ISI frequency."""
        base_freqs = [hF.calculate_mean_isi_freq(st) for st in spiketimes]
        return np.median(base_freqs)

    @staticmethod
    def _get_serial_correlation_given_data(max_lag, spikestimes):
        """Return the per-lag serial correlation (lags 1..max_lag) averaged over trials."""
        serial_cors = np.array([hF.calculate_serial_correlation(st, max_lag) for st in spikestimes])
        return np.mean(serial_cors, axis=0)

    @staticmethod
    def _get_vector_strength_given_data(times, eods, spiketimes, sampling_interval):
        """Return the vector strength averaged over trials (times/eods/spiketimes are per trial)."""
        vs_per_trial = [
            hF.calculate_vector_strength_from_spiketimes(time, eod, st, sampling_interval)
            for time, eod, st in zip(times, eods, spiketimes)
        ]
        return np.mean(vs_per_trial)

    @staticmethod
    def _get_coefficient_of_variation_given_data(spiketimes):
        """Return the coefficient of variation of the ISIs (std / mean) averaged over trials."""
        cvs = [hF.calculate_coefficient_of_variation(np.array(st)) for st in spiketimes]
        return np.mean(cvs)

    @staticmethod
    def _get_interspike_intervals_given_data(spiketimes):
        """Return the interspike intervals of all trials pooled into one flat list."""
        isis = []
        for st in spiketimes:
            isis.extend(np.diff(np.array(st)))
        return isis

    @staticmethod
    def _plot_baseline_given_data(time, eod, v1, spiketimes, sampling_interval, eod_freq="", save_path=None, position=0.5, time_length=0.2):
        """
        Plot a window of one baseline trial in three panels:
        - stimulus/EOD,
        - V1 trace with spike markers,
        - ISI frequency.

        :param time: time trace of the trial
        :param eod: EOD/stimulus trace sampled at ``time``
        :param v1: V1 trace sampled at ``time``
        :param spiketimes: spike times of the trial, in the same time frame as ``time``
        :param sampling_interval: sampling interval of the traces
        :param eod_freq: EOD frequency as a string, appended to the stimulus axis label
        :param save_path: prefix of the output file; the figure is shown instead when None
        :param position: relative start of the window within the trace (0 = start, 1 = end)
        :param time_length: length of the plotted window (same unit as ``time``)
        """
        length_data_points = int(time_length / sampling_interval)

        # Window [start_idx, end_idx) clamped to the trace; start_idx must index an existing sample.
        start_idx = min(max(int(len(time) * position), 0), len(time) - 1)
        end_idx = min(int(len(time) * position + length_data_points) + 1, len(time))

        spiketimes = np.array(spiketimes)
        # time[end_idx] does not exist when the window reaches the end of the trace.
        upper_bound = time[end_idx] if end_idx < len(time) else np.inf
        spiketimes_part = spiketimes[(spiketimes >= time[start_idx]) & (spiketimes < upper_bound)]

        fig, axes = plt.subplots(3, 1, sharex="col", figsize=(12, 8))
        fig.suptitle("Baseline middle part ({:.2f} seconds)".format(time_length))
        axes[0].plot(time[start_idx:end_idx], eod[start_idx:end_idx])
        axes[0].set_ylabel("Stimulus [mV] - Freq:" + eod_freq)

        # Spikes are drawn as markers at the height of the largest V1 value in the window.
        max_v1 = np.max(v1[start_idx:end_idx])
        axes[1].plot(time[start_idx:end_idx], v1[start_idx:end_idx])
        axes[1].plot(spiketimes_part, np.full(len(spiketimes_part), max_v1),
                     'o', color='orange')
        axes[1].set_ylabel("V1-Trace [mV]")

        t, f = hF.calculate_time_and_frequency_trace(spiketimes_part, sampling_interval)
        axes[2].plot(t, f)
        axes[2].set_ylabel("ISI-Frequency [Hz]")
        axes[2].set_xlabel("Time [s]")

        if save_path is not None:
            plt.savefig(save_path + "baseline.png")
        else:
            plt.show()

        plt.close()

    @staticmethod
    def plot_isi_histogram_comparision(cell_isis, model_isis, save_path=None):
        """
        Overlay the normalized ISI histograms (in ms) of a cell and a model.

        :param cell_isis: interspike intervals of the cell (multiplied by 1000 for the ms axis)
        :param model_isis: interspike intervals of the model (same unit as ``cell_isis``)
        :param save_path: prefix of the output file; the figure is shown instead when None
        """
        cell_isis = np.array(cell_isis) * 1000
        model_isis = np.array(model_isis) * 1000
        all_isis = np.concatenate((cell_isis, model_isis))
        # max() of an empty sequence raises; fall back to a 1 ms range when there are no ISIs at all.
        maximum = np.max(all_isis) if all_isis.size > 0 else 1.0
        bins = np.arange(0, maximum * 1.01, 0.1)

        plt.title('Baseline ISIs')
        plt.xlabel('ISI in ms')
        plt.ylabel('Count')
        plt.hist(cell_isis, bins=bins, label="cell", alpha=0.5, density=True)
        plt.hist(model_isis, bins=bins, label="model", alpha=0.5, density=True)
        plt.legend()
        if save_path is not None:
            plt.savefig(save_path + "isi-histogram_comparision.png")
        else:
            plt.show()

        plt.close()

    def plot_polar_vector_strength(self, save_path=None):
        """Plot a polar histogram of the spike phases relative to the EOD cycle."""
        phases = self.get_spiketime_phases()
        fig = plt.figure()
        ax = fig.add_subplot(111, polar=True)
        bins = np.arange(0, np.pi * 2, 0.1)
        ax.hist(phases, bins=bins)

        if save_path is not None:
            plt.savefig(save_path + "vector_strength_polar_plot.png")
        else:
            plt.show()

        plt.close()

    def plot_interspike_interval_histogram(self, save_path=None):
        """Plot the histogram of the pooled interspike intervals in ms (flags the no-spike case)."""
        isi = np.array(self.get_interspike_intervals()) * 1000  # change unit to milliseconds
        if len(isi) == 0:
            print("NON SPIKES IN BASELINE OF CELL/MODEL")
            title = 'Baseline ISIs - NO SPIKES!'
            bins = np.arange(0, 1, 0.1)
        else:
            title = 'Baseline ISIs'
            bins = np.arange(0, max(isi) * 1.01, 0.1)

        plt.title(title)
        plt.xlabel('ISI in ms')
        plt.ylabel('Count')
        plt.hist(isi, bins=bins)

        if save_path is not None:
            plt.savefig(save_path + "isi-histogram.png")
        else:
            plt.show()

        plt.close()

    def plot_serial_correlation(self, max_lag, save_path=None):
        """Plot the serial correlation for lags 1..max_lag."""
        plt.title("Baseline Serial correlation")
        plt.xlabel("Lag")
        plt.ylabel("Correlation")
        plt.ylim((-1, 1))
        plt.plot(np.arange(1, max_lag + 1, 1), self.get_serial_correlation(max_lag))

        if save_path is not None:
            plt.savefig(save_path + "serial_correlation.png")
        else:
            plt.show()

        plt.close()

    def save_values(self, save_directory, max_lag=10):
        """
        Compute all baseline values and pickle them to ``save_directory/self.save_file_name``.

        :param save_directory: directory the pickle file is written to
        :param max_lag: number of serial-correlation lags to store (default 10)
        """
        values = {
            "baseline_frequency": self.get_baseline_frequency(),
            "serial correlation": self.get_serial_correlation(max_lag=max_lag),
            "vector strength": self.get_vector_strength(),
            "coefficient of variation": self.get_coefficient_of_variation(),
            "burstiness": self.get_burstiness(),
        }

        with open(join(save_directory, self.save_file_name), "wb") as file:
            pickle.dump(values, file)
        print("Baseline: Values saved!")

    def load_values(self, save_directory):
        """
        Load previously saved baseline values into this instance's cache.

        :param save_directory: directory containing ``self.save_file_name``
        :return: True if the file existed and was loaded, False otherwise
        """
        file_path = join(save_directory, self.save_file_name)
        if not exists(file_path):
            print("Baseline: No file to load")
            return False

        # NOTE: pickle.load can execute arbitrary code; only load files written by save_values.
        with open(file_path, "rb") as file:
            values = pickle.load(file)
        self.baseline_frequency = values["baseline_frequency"]
        self.serial_correlation = values["serial correlation"]
        self.vector_strength = values["vector strength"]
        self.coefficient_of_variation = values["coefficient of variation"]
        self.burstiness = values["burstiness"]
        print("Baseline: Values loaded!")
        return True


class BaselineCellData(Baseline):
    """Baseline statistics computed from the baseline trials stored in a ``CellData`` object."""

    def __init__(self, cell_data: CellData):
        super().__init__()
        self.data = cell_data

    def get_baseline_frequency(self):
        """Return the median over trials of the mean-ISI frequency (cached after the first call)."""
        if self.baseline_frequency == -1:
            spiketimes = self.data.get_base_spikes()
            self.baseline_frequency = self._get_baseline_frequency_given_data(spiketimes)

        return self.baseline_frequency

    def get_vector_strength(self):
        """Return the vector strength averaged over trials (cached after the first call)."""
        if self.vector_strength == -1:
            times = self.data.get_base_traces(self.data.TIME)
            eods = self.data.get_base_traces(self.data.EOD)
            spiketimes = self.data.get_base_spikes()
            sampling_interval = self.data.get_sampling_interval()
            self.vector_strength = self._get_vector_strength_given_data(times, eods, spiketimes, sampling_interval)
        return self.vector_strength

    def get_serial_correlation(self, max_lag):
        """
        Return the serial correlation for lags 1..max_lag.

        The cached result is reused whenever it already covers ``max_lag`` lags.
        """
        if len(self.serial_correlation) < max_lag:
            self.serial_correlation = self._get_serial_correlation_given_data(max_lag, self.data.get_base_spikes())
        return self.serial_correlation[:max_lag]

    def get_coefficient_of_variation(self):
        """Return the ISI coefficient of variation averaged over trials (cached after the first call)."""
        if self.coefficient_of_variation == -1:
            self.coefficient_of_variation = self._get_coefficient_of_variation_given_data(self.data.get_base_spikes())
        return self.coefficient_of_variation

    def get_interspike_intervals(self):
        """Return the ISIs of all baseline trials pooled into one list."""
        return self._get_interspike_intervals_given_data(self.data.get_base_spikes())

    def get_spiketime_phases(self):
        """Return the phase (0..2*pi) of every spike within its surrounding EOD cycle, pooled over trials."""
        times = self.data.get_base_traces(self.data.TIME)
        spiketimes = self.data.get_base_spikes()
        eods = self.data.get_base_traces(self.data.EOD)
        sampling_interval = self.data.get_sampling_interval()

        phase_list = []
        for time, eod, trial_spikes in zip(times, eods, spiketimes):
            # Spike times share the time frame of the trace (they are compared directly with
            # time values when plotting), so the sample index is (t - trace start) / dt.
            spiketime_indices = np.array(np.around((np.array(trial_spikes) - time[0]) / sampling_interval), dtype=int)
            rel_spikes, eod_durs = hF.eods_around_spikes(time, eod, spiketime_indices)

            phase_list.extend((rel_spikes / eod_durs) * 2 * np.pi)

        return phase_list

    def get_burstiness(self):
        """Return the burstiness score using the cell's EOD frequency (cached after the first call)."""
        if self.burstiness == -1:
            self.burstiness = self.__get_burstiness__(self.data.get_eod_frequency())
        return self.burstiness

    def plot_baseline(self, save_path=None, position=0.5, time_length=0.2):
        """Plot EOD, V1 trace with spikes and ISI frequency of the first baseline trial."""
        time = self.data.get_base_traces(self.data.TIME)[0]
        eod = self.data.get_base_traces(self.data.EOD)[0]
        v1_trace = self.data.get_base_traces(self.data.V1)[0]
        spiketimes = self.data.get_base_spikes()[0]

        self._plot_baseline_given_data(time, eod, v1_trace, spiketimes,
                                       self.data.get_sampling_interval(), "{:.0f}".format(self.data.get_eod_frequency()), save_path, position, time_length)


class BaselineModel(Baseline):
    """
    Baseline statistics of a ``LifacNoiseModel`` driven by an unmodulated EOD.

    Construction has a side effect on the passed model: its ``a_zero`` is set to the
    adaption value reached after a short baseline simulation. It then runs ``trials``
    simulations of length ``simulation_time``.
    """

    simulation_time = 30

    def __init__(self, model: LifacNoiseModel, eod_frequency, trials=1):
        super().__init__()
        self.model = model
        self.eod_frequency = eod_frequency
        self.set_model_adaption_to_baseline()

        self.stimulus = SinusoidalStepStimulus(eod_frequency, 0)
        self.eod = self.stimulus.as_array(0, self.simulation_time, model.get_sampling_interval())
        self.time = np.arange(0, self.simulation_time, model.get_sampling_interval())

        self.v1_traces = []
        self.spiketimes = []
        for _ in range(trials):
            v1, trial_spikes = model.simulate(self.stimulus, self.simulation_time)
            self.v1_traces.append(v1)
            self.spiketimes.append(trial_spikes)

    def set_model_adaption_to_baseline(self):
        """Simulate 1 time unit of baseline and store the final adaption value as the model's ``a_zero``."""
        stimulus = SinusoidalStepStimulus(self.eod_frequency, 0, 0, 0)
        self.model.simulate(stimulus, 1)
        adaption = self.model.get_adaption_trace()
        self.model.set_variable("a_zero", adaption[-1])

    def get_baseline_frequency(self):
        """Return the median over trials of the mean-ISI frequency (cached after the first call)."""
        if self.baseline_frequency == -1:
            self.baseline_frequency = self._get_baseline_frequency_given_data(self.spiketimes)
        return self.baseline_frequency

    def get_vector_strength(self):
        """Return the vector strength averaged over trials (cached after the first call)."""
        if self.vector_strength == -1:
            # All trials share the same time axis and stimulus.
            times = [self.time] * len(self.spiketimes)
            eods = [self.eod] * len(self.spiketimes)
            sampling_interval = self.model.get_sampling_interval()
            self.vector_strength = self._get_vector_strength_given_data(times, eods, self.spiketimes, sampling_interval)

        return self.vector_strength

    def get_serial_correlation(self, max_lag):
        """Return the serial correlation for lags 1..max_lag (recomputed whenever max_lag changes)."""
        if len(self.serial_correlation) != max_lag:
            self.serial_correlation = self._get_serial_correlation_given_data(max_lag, self.spiketimes)
        return self.serial_correlation

    def get_coefficient_of_variation(self):
        """Return the ISI coefficient of variation averaged over trials (cached after the first call)."""
        if self.coefficient_of_variation == -1:
            self.coefficient_of_variation = self._get_coefficient_of_variation_given_data(self.spiketimes)
        return self.coefficient_of_variation

    def get_interspike_intervals(self):
        """Return the ISIs of all simulated trials pooled into one list."""
        return self._get_interspike_intervals_given_data(self.spiketimes)

    def get_burstiness(self):
        """Return the burstiness score using the stimulus EOD frequency (cached after the first call)."""
        if self.burstiness == -1:
            self.burstiness = self.__get_burstiness__(self.eod_frequency)
        return self.burstiness

    def get_spiketime_phases(self):
        """Return the phase (0..2*pi) of every spike within its surrounding EOD cycle, pooled over trials."""
        sampling_interval = self.model.get_sampling_interval()

        phase_list = []
        for trial_spikes in self.spiketimes:
            # Sample index of a spike is (t - trace start) / dt; self.time starts at 0 here.
            spiketime_indices = np.array(np.around((np.array(trial_spikes) - self.time[0]) / sampling_interval), dtype=int)
            rel_spikes, eod_durs = hF.eods_around_spikes(self.time, self.eod, spiketime_indices)

            phase_list.extend((rel_spikes / eod_durs) * 2 * np.pi)

        return phase_list

    def plot_baseline(self, save_path=None, position=0.5, time_length=0.2):
        """Plot EOD, V1 trace with spikes and ISI frequency of the first simulated trial."""
        self._plot_baseline_given_data(self.time, self.eod, self.v1_traces[0], self.spiketimes[0],
                                       self.model.get_sampling_interval(), "{:.0f}".format(self.eod_frequency),
                                       save_path, position, time_length)


def get_baseline_class(data, eod_freq=None, trials=1) -> Baseline:
    """
    Wrap ``data`` in the matching Baseline implementation.

    :param data: a ``CellData`` (recorded cell) or a ``LifacNoiseModel`` (model)
    :param eod_freq: EOD frequency; required when ``data`` is a model, ignored otherwise
    :param trials: number of simulated trials (model only)
    :return: a ``BaselineCellData`` or ``BaselineModel`` instance
    :raises ValueError: for a model without ``eod_freq`` or for an unsupported ``data`` type
    """
    if isinstance(data, CellData):
        baseline = BaselineCellData(data)
    elif isinstance(data, LifacNoiseModel):
        if eod_freq is None:
            raise ValueError("The EOD frequency is needed for the BaselineModel Class.")
        baseline = BaselineModel(data, eod_freq, trials=trials)
    else:
        raise ValueError("Unknown type: Cannot find corresponding Baseline class. data was type:" + str(type(data)))

    return baseline