from os.path import isdir, exists
from warnings import warn
import pyrelacs.DataLoader as Dl

# Recording-format codes: produced by __test_for_format__ and dispatched on in
# get_parser. UNKNOWN marks a path whose format could not be recognised.
UNKNOWN, DAT_FORMAT, NIX_FORMAT = -1, 0, 1


class AbstractParser:
    """Interface shared by all recording parsers.

    Every accessor raises NotImplementedError until a subclass overrides it.
    """

    # Message carried by the NotImplementedError of every un-overridden accessor.
    _NOT_OVERRIDDEN_MSG = "NOT YET OVERRIDDEN FROM ABSTRACT CLASS"

    def cell_get_metadata(self):
        """Return metadata about the recorded cell (subclass responsibility)."""
        raise NotImplementedError(AbstractParser._NOT_OVERRIDDEN_MSG)

    def get_baseline_traces(self):
        """Return the traces of the baseline-activity recording (subclass responsibility)."""
        raise NotImplementedError(AbstractParser._NOT_OVERRIDDEN_MSG)

    def get_fi_curve_traces(self):
        """Return the traces of the FI-curve recording (subclass responsibility)."""
        raise NotImplementedError(AbstractParser._NOT_OVERRIDDEN_MSG)

    def get_fi_curve_spiketimes(self):
        """Return the FI-curve spike times (subclass responsibility)."""
        raise NotImplementedError(AbstractParser._NOT_OVERRIDDEN_MSG)

    def get_sampling_interval(self):
        """Return the sampling interval of the recording (subclass responsibility)."""
        raise NotImplementedError(AbstractParser._NOT_OVERRIDDEN_MSG)

    def get_recording_times(self):
        """Return the recording time markers (subclass responsibility)."""
        raise NotImplementedError(AbstractParser._NOT_OVERRIDDEN_MSG)


class DatParser(AbstractParser):
    """Parser for a relacs recording stored as a directory of ``.dat`` files.

    The directory must contain ``fispikes1.dat`` (FI-curve spike times) and
    ``stimuli.dat`` (read for the sampling interval); both are checked on
    construction. Files are read through ``pyrelacs.DataLoader`` (``Dl``).
    """

    # Metadata section headers used in fispikes1.dat; they are matched exactly.
    _CONTROL_KEY = '----- Control --------------------------------------------------------'
    _PRE_INTENSITIES_KEY = "----- Pre-Intensities ------------------------------------------------"
    _TEST_INTENSITIES_KEY = "----- Test-Intensities -----------------------------------------------"

    def __init__(self, dir_path):
        """Build the data-file paths below *dir_path* and verify they exist.

        Raises:
            RuntimeError: if ``stimuli.dat`` or ``fispikes1.dat`` is missing.
        """
        self.base_path = dir_path
        self.fi_file = self.base_path + "/fispikes1.dat"
        self.stimuli_file = self.base_path + "/stimuli.dat"
        self.__test_data_file_existence__()

        # Lazily filled caches: [] / -1 mean "not read yet".
        self.fi_recording_times = []
        self.sampling_interval = -1

    def cell_get_metadata(self):
        """Not implemented for .dat recordings; always returns None."""
        return None

    def get_sampling_interval(self):
        """Return the sampling interval, reading stimuli.dat on first use.

        The value is the stored interval divided by 1000 (presumably ms -> s).
        """
        if self.sampling_interval == -1:
            self.__read_sampling_interval__()

        return self.sampling_interval

    def get_recording_times(self):
        """Return ``[-delay, 0, stim_duration, pause - delay]`` for the FI curve.

        The values come from the Control/Test-Intensities metadata of
        fispikes1.dat, divided by 1000 (presumably ms -> s). They are read on
        first use and then cached.
        """
        if len(self.fi_recording_times) == 0:
            self.__read_fi_recording_times__()
        return self.fi_recording_times

    def get_baseline_traces(self):
        """Return the traces recorded by the ``BaselineActivity`` repro."""
        return self.__get_traces__("BaselineActivity")

    def get_fi_curve_traces(self):
        """Return the traces recorded by the ``FICurve`` repro."""
        return self.__get_traces__("FICurve")

    def get_fi_curve_spiketimes(self):
        """Read the FI-curve spike trains from fispikes1.dat.

        Only stimulus blocks whose Control metadata has a pre-duration of 0 are
        used. A trial is dropped if it has fewer than 10 entries or its last
        spike is before 1 (after dividing by 1000, i.e. one second).

        Returns:
            ``(trans_amplitudes, intensities, spiketimes)``: parallel lists with
            one entry per intensity block. ``spiketimes[i]`` is a list of 1-D
            arrays of spike times, one per kept trial.

        Raises:
            RuntimeError: if a data block has more than one column.
        """
        spiketimes = []
        intensities = []
        trans_amplitudes = []
        skip = False
        trans_amplitude = float('nan')
        for metadata, key, data in Dl.iload(self.fi_file):
            if len(metadata) != 0:
                metadata_index = 0

                if self._CONTROL_KEY in metadata[0].keys():
                    # A Control section precedes the intensity section.
                    metadata_index = 1
                    # [:-2] strips a two-character unit suffix (assumed, e.g. "ms").
                    pre_duration = float(metadata[0][self._PRE_INTENSITIES_KEY]["preduration"][:-2])
                    trans_amplitude = float(metadata[0]["trans. amplitude"][:-2])
                    # Only stimulus protocols without a pre-stimulus are used.
                    skip = pre_duration != 0

                if skip:
                    continue

                intensities.append(float(metadata[metadata_index]['intensity'][:-2]))
                trans_amplitudes.append(trans_amplitude)
                spiketimes.append([])

            # Also ignore data that arrives before the first intensity block;
            # there is no entry to attach it to.
            if skip or not spiketimes:
                continue

            if data.shape[1] != 1:
                raise RuntimeError("DatParser:get_fi_curve_spiketimes():\n read data has more than one dimension!")

            spike_time_data = data[:, 0]/1000
            if len(spike_time_data) < 10:
                continue
            if spike_time_data[-1] < 1:
                print("# ignoring spike-train that ends before one second.")
                continue

            spiketimes[-1].append(spike_time_data)

        # TODO add merging for similar intensities? hf.merge_similar_intensities() +  trans_amplitudes

        return trans_amplitudes, intensities, spiketimes

    def __get_traces__(self, repro):
        """Load all trace sets that the relacs *repro* recorded.

        Returns:
            ``[time_traces, v1_traces, eod_traces, local_eod_traces,
            stimulus_traces]``: five parallel lists with one entry per set
            yielded by ``Dl.iload_traces``. Channels 0-3 of each set are stored
            under these names (channel meaning assumed from naming; confirm).
            If nothing is found, a warning is issued and the lists are empty.
        """
        time_traces = []
        v1_traces = []
        eod_traces = []
        local_eod_traces = []
        stimulus_traces = []

        for info, key, time, x in Dl.iload_traces(self.base_path, repro=repro):
            time_traces.append(time)
            v1_traces.append(x[0])
            eod_traces.append(x[1])
            local_eod_traces.append(x[2])
            stimulus_traces.append(x[3])

        if not time_traces:
            warn("pyrelacs: iload_traces found nothing for the " + str(repro) + " repro!")

        return [time_traces, v1_traces, eod_traces, local_eod_traces, stimulus_traces]

    def __read_fi_recording_times__(self):
        """Read delay, stimulus duration and pause from fispikes1.dat and cache them.

        All three must be present and identical across every Control block.
        The cache is only written after every check has passed.

        Raises:
            RuntimeError: if a value is missing or inconsistent.
        """
        delays = []
        stim_duration = []
        pause = []

        for metadata, key, data in Dl.iload(self.fi_file):
            if len(metadata) != 0 and self._CONTROL_KEY in metadata[0].keys():
                control = metadata[0][self._CONTROL_KEY]
                delays.append(float(control["delay"][:-2])/1000)
                pause.append(float(control["pause"][:-2])/1000)
                stim_duration.append(float(metadata[0][self._TEST_INTENSITIES_KEY]["duration"][:-2])/1000)

        for values in (delays, stim_duration, pause):
            if len(values) == 0:
                raise RuntimeError("DatParser:__read_fi_recording_times__:\n" +
                                   "Couldn't find any delay, stimulus duration and or pause in the metadata.\n" +
                                   "In file:" + self.base_path)
            if len(set(values)) != 1:
                raise RuntimeError("DatParser:__read_fi_recording_times__:\n" +
                                   "Found multiple different delay, stimulus duration and or pause in the metadata.\n" +
                                   "In file:" + self.base_path)

        self.fi_recording_times = [-delays[0], 0, stim_duration[0], pause[0] - delays[0]]

    def __read_sampling_interval__(self):
        """Read and cache the sampling interval from stimuli.dat.

        Collects ``sample interval1`` .. ``sample interval4`` from the metadata
        dicts of the first iload record that contains any of them. Within each
        dict, collection stops at the first missing key.

        Raises:
            RuntimeError: if no interval is found or the intervals differ.
        """
        stop = False
        sampling_intervals = []
        for metadata, key, data in Dl.iload(self.stimuli_file):
            for md in metadata:
                for i in range(4):
                    interval_key = "sample interval" + str(i+1)
                    if interval_key not in md.keys():
                        break
                    sampling_intervals.append(float(md[interval_key][:-2]) / 1000)
                    stop = True

            if stop:
                break

        if len(sampling_intervals) == 0:
            raise RuntimeError("DatParser:__read_sampling_interval__:\n" +
                               "Sampling intervals not found in stimuli.dat this is not handled!\n" +
                               "with File:" + self.base_path)

        if len(set(sampling_intervals)) != 1:
            raise RuntimeError("DatParser:__read_sampling_interval__:\n" +
                               "Sampling intervals not the same for all traces this is not handled!\n" +
                               "with File:" + self.base_path)

        self.sampling_interval = sampling_intervals[0]

    def __test_data_file_existence__(self):
        """Raise RuntimeError if stimuli.dat or fispikes1.dat does not exist."""
        if not exists(self.stimuli_file):
            raise RuntimeError(self.stimuli_file + " file doesn't exist!")
        if not exists(self.fi_file):
            raise RuntimeError(self.fi_file + " file doesn't exist!")


# TODO ####################################
class NixParser(AbstractParser):
    """Placeholder parser for NIX (``.nix``) recordings.

    It only remembers the file path. Every data accessor is inherited
    unchanged from AbstractParser and so raises NotImplementedError.
    """

    def __init__(self, nix_file_path):
        """Store *nix_file_path* and warn that NIX parsing is not implemented."""
        self.file_path = nix_file_path
        not_ready = "NIX PARSER: NOT YET IMPLEMENTED!"
        warn(not_ready)
# TODO ####################################


def get_parser(data_path: str) -> AbstractParser:
    """Return a parser suited to the recording at *data_path*.

    A directory containing ``fispikes1.dat`` yields a DatParser. A path ending
    in ``.nix`` yields a (not yet functional) NixParser.

    Raises:
        TypeError: if the data format cannot be determined.
    """
    data_format = __test_for_format__(data_path)

    if data_format == DAT_FORMAT:
        return DatParser(data_path)
    if data_format == NIX_FORMAT:
        return NixParser(data_path)
    # Every other result (UNKNOWN) is unsupported. The old `elif UNKNOWN:` tested
    # the constant itself, not data_format, and only worked because -1 is truthy.
    raise TypeError("DataParserFactory:get_parser(data_path):\nCannot determine type of data for:" + data_path)


def __test_for_format__(data_path):
    """Detect the recording format stored at *data_path*.

    Returns:
        DAT_FORMAT for a directory that contains ``fispikes1.dat``;
        NIX_FORMAT for a non-directory path ending in ``.nix``;
        UNKNOWN otherwise, including a directory without ``fispikes1.dat``
        (that case must not fall through and return None).
    """
    if isdir(data_path):
        if exists(data_path + "/fispikes1.dat"):
            return DAT_FORMAT
        return UNKNOWN

    if data_path.endswith(".nix"):
        return NIX_FORMAT
    return UNKNOWN