from os.path import isdir, exists, join
from warnings import warn
import pyrelacs.DataLoader as Dl
from models.AbstractModel import AbstractModel

UNKNOWN = -1    # data format could not be determined
DAT_FORMAT = 0  # relacs .dat recording directory
NIX_FORMAT = 1  # NIX file (.nix)
MODEL = 2       # in-memory model instance (AbstractModel)


class AbstractParser:
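    """Interface for data parsers.

    Subclasses provide access to baseline and FI-curve recordings from a
    specific source (relacs .dat directories, NIX files, or simulated models).
    Methods a source cannot provide keep raising NotImplementedError;
    callers should check the *_available() methods first.
    """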

    def cell_get_metadata(self):
        raise NotImplementedError("NOT YET OVERRIDDEN FROM ABSTRACT CLASS")

    def get_baseline_traces(self):
        raise NotImplementedError("NOT YET OVERRIDDEN FROM ABSTRACT CLASS")

    def get_baseline_spiketimes(self):
        raise NotImplementedError("NOT YET OVERRIDDEN FROM ABSTRACT CLASS")

    def get_fi_curve_traces(self):
        raise NotImplementedError("NOT YET OVERRIDDEN FROM ABSTRACT CLASS")

    def get_fi_curve_spiketimes(self):
        raise NotImplementedError("NOT YET OVERRIDDEN FROM ABSTRACT CLASS")

    def get_fi_frequency_traces(self):
        raise NotImplementedError("NOT YET OVERRIDDEN FROM ABSTRACT CLASS")

    def get_sampling_interval(self):
        raise NotImplementedError("NOT YET OVERRIDDEN FROM ABSTRACT CLASS")

    def get_recording_times(self):
        raise NotImplementedError("NOT YET OVERRIDDEN FROM ABSTRACT CLASS")

    def traces_available(self) -> bool:
        raise NotImplementedError("NOT YET OVERRIDDEN FROM ABSTRACT CLASS")

    def spiketimes_available(self) -> bool:
        raise NotImplementedError("NOT YET OVERRIDDEN FROM ABSTRACT CLASS")

    def frequencies_available(self) -> bool:
        raise NotImplementedError("NOT YET OVERRIDDEN FROM ABSTRACT CLASS")


class DatParser(AbstractParser):
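    """Parser for relacs .dat recordings.

    Expects a directory containing stimuli.dat, fispikes1.dat and
    basespikes1.dat. Sampling interval and FI recording times are read
    lazily on first access.
    """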

    def __init__(self, dir_path):
        self.base_path = dir_path
        self.fi_file = join(self.base_path, "fispikes1.dat")
        self.baseline_file = join(self.base_path, "basespikes1.dat")
        self.stimuli_file = join(self.base_path, "stimuli.dat")
        self.__test_data_file_existence__()

        self.fi_recording_times = []
        self.sampling_interval = -1

    def traces_available(self) -> bool:
        return True

    def frequencies_available(self) -> bool:
        return False

    def spiketimes_available(self) -> bool:
        return True

    def get_sampling_interval(self):
        if self.sampling_interval == -1:
            self.__read_sampling_interval__()

        return self.sampling_interval

    def get_recording_times(self):
        if len(self.fi_recording_times) == 0:
            self.__read_fi_recording_times__()
        return self.fi_recording_times

    def get_baseline_traces(self):
        return self.__get_traces__("BaselineActivity")

    def get_baseline_spiketimes(self):
        spiketimes = []
        warn("Spiketimes don't fit time-wise to the baseline traces. Causes different vector strength angle per recording.")

        for metadata, key, data in Dl.iload(self.baseline_file):
            spikes = data[:, 0]
            spiketimes.append(spikes)

        return spiketimes

    def get_fi_curve_traces(self):
        return self.__get_traces__("FICurve")

    def get_fi_frequency_traces(self):
        raise NotImplementedError("Not possible in .dat data type.\n"
                                  "Please check availability with the x_available functions.")

    # TODO clean up/ rewrite
    def get_fi_curve_spiketimes(self):
        spiketimes = []
        pre_intensities = []
        pre_durations = []
        intensities = []
        trans_amplitudes = []
        pre_duration = -1
        index = -1
        skip = False
        trans_amplitude = float('nan')
        for metadata, key, data in Dl.iload(self.fi_file):
            if len(metadata) != 0:

                metadata_index = 0

                # A '----- Control ...' section marks the start of a new repro run
                # and carries the pre-stimulus settings.
                if '----- Control --------------------------------------------------------' in metadata[0].keys():
                    metadata_index = 1
                    # [:-2] strips the unit suffix (e.g. "ms", "mV") from the value string.
                    pre_duration = float(metadata[0]["----- Pre-Intensities ------------------------------------------------"]["preduration"][:-2])
                    trans_amplitude = float(metadata[0]["trans. amplitude"][:-2])
                    # Only runs without a pre-stimulus contribute to the FI curve.
                    if pre_duration == 0:
                        skip = False
                    else:
                        skip = True
                        continue

                if skip:
                    continue

                intensity = float(metadata[metadata_index]['intensity'][:-2])
                pre_intensity = float(metadata[metadata_index]['preintensity'][:-2])

                intensities.append(intensity)
                pre_durations.append(pre_duration)
                pre_intensities.append(pre_intensity)
                trans_amplitudes.append(trans_amplitude)
                spiketimes.append([])
                index += 1

            if skip:
                continue

            if data.shape[1] != 1:
                raise RuntimeError("DatParser:get_fi_curve_spiketimes():\n"
                                   "Read data has more than one dimension!")

            spike_time_data = data[:, 0] / 1000  # convert spike times from ms to s
            # Discard sparse or prematurely ending spike trains.
            if len(spike_time_data) < 10:
                continue
            if spike_time_data[-1] < 1:
                print("# ignoring spike-train that ends before one second.")
                continue

            spiketimes[index].append(spike_time_data)

        # TODO add merging for similar intensities? hf.merge_similar_intensities() +  trans_amplitudes

        return trans_amplitudes, intensities, spiketimes

    def __get_traces__(self, repro):
        time_traces = []
        v1_traces = []
        eod_traces = []
        local_eod_traces = []
        stimulus_traces = []

        nothing = True

        for info, key, time, x in Dl.iload_traces(self.base_path, repro=repro):
            nothing = False
            time_traces.append(time)
            v1_traces.append(x[0])         # membrane voltage (V-1)
            eod_traces.append(x[1])        # global EOD
            local_eod_traces.append(x[2])  # local EOD
            stimulus_traces.append(x[3])   # stimulus trace

        traces = [time_traces, v1_traces, eod_traces, local_eod_traces, stimulus_traces]

        if nothing:
            warn_msg = "pyrelacs: iload_traces found nothing for the " + str(repro) + " repro!"
            warn(warn_msg)

        return traces

    def __read_fi_recording_times__(self):

        delays = []
        stim_duration = []
        pause = []

        for metadata, key, data in Dl.iload(self.fi_file):
            if len(metadata) != 0:
                control_key = '----- Control --------------------------------------------------------'
                if control_key in metadata[0].keys():
                    # [:-2] strips the unit suffix ("ms"); /1000 converts ms to s.
                    delays.append(float(metadata[0][control_key]["delay"][:-2]) / 1000)
                    pause.append(float(metadata[0][control_key]["pause"][:-2]) / 1000)
                    stim_key = "----- Test-Intensities -----------------------------------------------"
                    stim_duration.append(float(metadata[0][stim_key]["duration"][:-2]) / 1000)

        for values in [delays, stim_duration, pause]:
            if len(values) == 0:
                raise RuntimeError("DatParser:__read_fi_recording_times__:\n" +
                                   "Couldn't find any delay, stimulus duration and/or pause in the metadata.\n" +
                                   "In directory: " + self.base_path)
            elif len(set(values)) != 1:
                raise RuntimeError("DatParser:__read_fi_recording_times__:\n" +
                                   "Found multiple different delays, stimulus durations and/or pauses in the metadata.\n" +
                                   "In directory: " + self.base_path)

        # Times relative to stimulus onset:
        # [recording start, stimulus start, stimulus end, recording end]
        self.fi_recording_times = [-delays[0], 0, stim_duration[0], pause[0] - delays[0]]

    def __read_sampling_interval__(self):
        stop = False
        sampling_intervals = []
        for metadata, key, data in Dl.iload(self.stimuli_file):
            for md in metadata:
                # Look for the consecutive keys "sample interval1" .. "sample interval4".
                for i in range(4):
                    interval_key = "sample interval" + str(i + 1)
                    if interval_key in md.keys():
                        # [:-2] strips the unit suffix ("ms"); /1000 converts ms to s.
                        sampling_intervals.append(float(md[interval_key][:-2]) / 1000)
                        stop = True
                    else:
                        break

            if stop:
                break

        if len(sampling_intervals) == 0:
            raise RuntimeError("DatParser:__read_sampling_interval__:\n" +
                               "Sampling intervals not found in stimuli.dat; this is not handled!\n" +
                               "In directory: " + self.base_path)

        if len(set(sampling_intervals)) != 1:
            raise RuntimeError("DatParser:__read_sampling_interval__:\n" +
                               "Sampling intervals are not the same for all traces; this is not handled!\n" +
                               "In directory: " + self.base_path)

        self.sampling_interval = sampling_intervals[0]

    def __test_data_file_existence__(self):
        if not exists(self.stimuli_file):
            raise RuntimeError(self.stimuli_file + " file doesn't exist!")
        if not exists(self.fi_file):
            raise RuntimeError(self.fi_file + " file doesn't exist!")


# MODEL PARSER: ------------------------------

class ModelParser(AbstractParser):
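    """Parser that generates data by simulating an AbstractModel.

    Availability of traces, spiketimes and frequencies depends on what the
    wrapped model simulates; check the *_available() methods before use.
    """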

    def __init__(self, model: AbstractModel):
        self.model = model

    def cell_get_metadata(self):
        raise NotImplementedError("NOT YET OVERRIDDEN FROM ABSTRACT CLASS")

    def get_baseline_traces(self):
        raise NotImplementedError("NOT YET OVERRIDDEN FROM ABSTRACT CLASS")

    def get_fi_curve_traces(self):
        if not self.model.simulates_voltage_trace():
            raise NotImplementedError("Model doesn't simulated voltage traces!")

        traces = []
        for stimulus in self.model.get_stimuli_for_fi_curve():
            self.model.simulate(stimulus, self.model.total_stimulation_time_fi_curve)
            traces.append(self.model.get_voltage_trace())

        return traces

    def get_fi_curve_spiketimes(self):
        if not self.model.simulates_spiketimes():
            raise NotImplementedError("Model doesn't simulated spiketimes!")

        all_spiketimes = []
        for stimulus in self.model.get_stimuli_for_fi_curve():
            self.model.simulate(stimulus, self.model.total_stimulation_time_fi_curve)
            all_spiketimes.append(self.model.get_spiketimes())

        return all_spiketimes

    def get_fi_frequency_traces(self):
        if not self.model.simulates_frequency():
            raise NotImplementedError("Model doesn't simulated frequency!")

        frequency_traces = []
        for stimulus in self.model.get_stimuli_for_fi_curve():
            self.model.simulate(stimulus, self.model.total_stimulation_time_fi_curve)
            frequency_traces.append(self.model.get_frequency())

        return frequency_traces

    def get_sampling_interval(self):
        return self.model.get_sampling_interval()

    def get_recording_times(self):
        raise NotImplementedError("NOT YET OVERRIDDEN FROM ABSTRACT CLASS")

    def traces_available(self) -> bool:
        return self.model.simulates_voltage_trace()

    def spiketimes_available(self) -> bool:
        return self.model.simulates_spiketimes()

    def frequencies_available(self) -> bool:
        return self.model.simulates_frequency()

# TODO ####################################

class NixParser(AbstractParser):
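    """Parser for NIX files. Not yet implemented; currently a stub."""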

    def __init__(self, nix_file_path):
        self.file_path = nix_file_path
        warn("NIX PARSER: NOT YET IMPLEMENTED!")
# TODO ####################################


def get_parser(data_path) -> AbstractParser:
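    """Create a parser matching the type of the given data source.

    Accepts a relacs .dat recording directory, a .nix file path, or an
    AbstractModel instance. Raises TypeError if the format cannot be
    determined.
    """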
    data_format = __test_for_format__(data_path)

    if data_format == DAT_FORMAT:
        return DatParser(data_path)
    elif data_format == NIX_FORMAT:
        return NixParser(data_path)
    elif data_format == MODEL:
        return ModelParser(data_path)
    elif data_format == UNKNOWN:
        raise TypeError("DataParserFactory:get_parser(data_path):\n"
                        "Cannot determine type of data for: " + str(data_path))


def __test_for_format__(data_path):
    if isinstance(data_path, AbstractModel):
        return MODEL

    if isdir(data_path):
        if exists(join(data_path, "fispikes1.dat")):
            return DAT_FORMAT
        # A directory without fispikes1.dat is not a usable .dat recording.
        return UNKNOWN

    if data_path.endswith(".nix"):
        return NIX_FORMAT

    return UNKNOWN
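

if __name__ == "__main__":
    # Minimal usage sketch (hypothetical path): pick the matching parser for a
    # relacs .dat recording directory and read basic properties from it.
    example_dir = "data/2015-06-12-ah"  # hypothetical recording directory
    parser = get_parser(example_dir)
    print("sampling interval [s]:", parser.get_sampling_interval())
    print("recording times [s]:", parser.get_recording_times())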