from os.path import isdir, exists
from warnings import warn
import pyrelacs.DataLoader as Dl
from models.AbstractModel import AbstractModel
import numpy as np

# Data-format identifiers produced by __test_for_format__ and dispatched on by
# get_parser: UNKNOWN (undetectable), DAT_FORMAT (relacs .dat directory),
# NIX_FORMAT (.nix file), MODEL (an AbstractModel instance).
UNKNOWN, DAT_FORMAT, NIX_FORMAT, MODEL = -1, 0, 1, 2


class AbstractParser:
    """Common interface shared by all recording-data parsers.

    Every accessor below raises ``NotImplementedError`` until a concrete
    parser subclass overrides it.
    """

    @staticmethod
    def _missing_override():
        """Build the error raised by every accessor a subclass did not override."""
        return NotImplementedError("NOT YET OVERRIDDEN FROM ABSTRACT CLASS")

    def get_baseline_length(self):
        raise self._missing_override()

    def has_sam_recordings(self):
        raise self._missing_override()

    def get_fi_curve_contrasts(self):
        raise self._missing_override()

    def get_baseline_traces(self):
        raise self._missing_override()

    def get_baseline_spiketimes(self):
        raise self._missing_override()

    def get_fi_curve_traces(self):
        raise self._missing_override()

    def get_fi_curve_spiketimes(self):
        raise self._missing_override()

    def get_fi_frequency_traces(self):
        raise self._missing_override()

    def get_sam_info(self):
        raise self._missing_override()

    def get_sampling_interval(self):
        raise self._missing_override()

    def get_recording_times(self):
        raise self._missing_override()

    def traces_available(self) -> bool:
        raise self._missing_override()

    def spiketimes_available(self) -> bool:
        raise self._missing_override()

    def frequencies_available(self) -> bool:
        raise self._missing_override()


class DatParser(AbstractParser):
    """Parser for a recording directory in the relacs ``.dat`` format.

    The directory must contain ``stimuli.dat``, ``fispikes1.dat`` and
    ``basespikes1.dat`` (checked in the constructor); ``info.dat`` is read by
    the metadata accessors and ``samallspikes1.dat`` is optional
    (see :meth:`has_sam_recordings`).
    """

    def __init__(self, dir_path):
        self.base_path = dir_path
        self.info_file = self.base_path + "/info.dat"
        self.fi_file = self.base_path + "/fispikes1.dat"
        self.baseline_file = self.base_path + "/basespikes1.dat"
        self.sam_file = self.base_path + "/samallspikes1.dat"
        self.stimuli_file = self.base_path + "/stimuli.dat"
        self.__test_data_file_existence__()

        # Lazily filled caches, see get_recording_times / get_sampling_interval.
        self.fi_recording_times = []
        self.sampling_interval = -1

    def has_sam_recordings(self):
        """Return whether the directory contains ``samallspikes1.dat``."""
        return exists(self.sam_file)

    def get_baseline_length(self):
        """Return the ``duration`` of every header in ``basespikes1.dat``.

        The last three characters of each value are dropped before parsing
        (presumably a unit suffix -- TODO confirm the unit).

        :return: list of floats, one per baseline header
        """
        lengths = []
        for metadata, key, data in Dl.iload(self.baseline_file):
            if len(metadata) != 0:
                lengths.append(float(metadata[0]["duration"][:-3]))

        return lengths

    def get_species(self):
        """Return the species from ``info.dat``.

        Reads a top-level ``Species`` entry or ``Species`` inside a ``Subject``
        section; later sections override earlier ones.

        :return: the species, or "" if none was found
        """
        species = ""
        for metadata in Dl.load(self.info_file):
            if "Species" in metadata.keys():
                species = metadata["Species"]
            elif "Subject" in metadata.keys():
                subject = metadata["Subject"]
                if isinstance(subject, dict) and "Species" in subject.keys():
                    species = subject["Species"]

        return species

    def get_gender(self):
        """Return the gender from ``info.dat``.

        Reads a top-level ``Gender`` entry or ``Gender`` inside a ``Subject``
        section; later sections override earlier ones.

        :return: the gender, or "not found" if none was found
        """
        gender = "not found"
        for metadata in Dl.load(self.info_file):
            # Test for the key that is actually read (the old check for
            # "Species" raised KeyError when "Gender" was missing).
            if "Gender" in metadata.keys():
                gender = metadata["Gender"]
            elif "Subject" in metadata.keys():
                subject = metadata["Subject"]
                if isinstance(subject, dict) and "Gender" in subject.keys():
                    gender = subject["Gender"]

        return gender

    def get_quality(self):
        """Return the recording quality from ``info.dat``.

        Reads a top-level ``Recording quality`` entry or the same key inside a
        ``Recording`` section; later sections override earlier ones.

        :return: the quality, or "" if none was found
        """
        quality = ""
        for metadata in Dl.load(self.info_file):
            if "Recording quality" in metadata.keys():
                quality = metadata["Recording quality"]
            elif "Recording" in metadata.keys():
                recording = metadata["Recording"]
                if isinstance(recording, dict) and "Recording quality" in recording.keys():
                    quality = recording["Recording quality"]
        return quality

    def get_cell_type(self):
        """Return the cell type from ``info.dat``.

        Reads a top-level ``CellType`` entry or ``CellType`` inside a ``Cell``
        section.

        NOTE(review): returns "" as soon as any metadata section has fewer
        than three keys, even if a cell type was already found; preserved
        as-is because the intent is unclear.
        """
        cell_type = ""
        for metadata in Dl.load(self.info_file):
            if len(metadata.keys()) < 3:
                return ""
            if "CellType" in metadata.keys():
                cell_type = metadata["CellType"]
            elif "Cell" in metadata.keys():
                cell = metadata["Cell"]
                if isinstance(cell, dict) and "CellType" in cell.keys():
                    cell_type = cell["CellType"]
        return cell_type

    def get_fish_size(self):
        """Return the fish size from ``info.dat`` with its last two characters removed.

        Reads a top-level ``Size`` entry or ``Size`` inside a ``Subject``
        section; later sections override earlier ones. The two stripped
        characters are presumably a unit suffix (TODO confirm).

        :return: the size string, or "" if none was found
        """
        size = ""
        for metadata in Dl.load(self.info_file):
            # Test for the key that is actually read (the old check for
            # "Species" raised KeyError when "Size" was missing).
            if "Size" in metadata.keys():
                size = metadata["Size"]
            elif "Subject" in metadata.keys():
                subject = metadata["Subject"]
                if isinstance(subject, dict) and "Size" in subject.keys():
                    size = subject["Size"]
        return size[:-2]

    def get_fi_curve_contrasts(self):
        """Count the trials recorded per intensity in ``fispikes1.dat``.

        Every metadata header starts a new intensity group; each following
        data block without a header adds one trial to it.

        :return: numpy array with one row ``[intensity, number_of_trials]`` per group
        """
        contrasts = []
        current = None  # [intensity, trial_count] of the group being counted
        for metadata, key, data in Dl.iload(self.fi_file):
            if len(metadata) != 0:
                if current is not None:
                    contrasts.append(current)
                current = [float(metadata[-1]["intensity"][:-2]), 1]
            elif current is not None:
                current[1] += 1

        # The last group is not followed by another header, so flush it here.
        if current is not None:
            contrasts.append(current)

        return np.array(contrasts)

    def traces_available(self) -> bool:
        return True

    def frequencies_available(self) -> bool:
        return False

    def spiketimes_available(self) -> bool:
        return True

    def get_sampling_interval(self):
        """Return the sampling interval from ``stimuli.dat``, read once and cached."""
        if self.sampling_interval == -1:
            self.__read_sampling_interval__()

        return self.sampling_interval

    def get_recording_times(self):
        """Return ``[-delay, 0, stimulus_duration, pause - delay]`` of the FI protocol.

        The values are read from ``fispikes1.dat`` on first use and cached.
        """
        if len(self.fi_recording_times) == 0:
            self.__read_fi_recording_times__()
        return self.fi_recording_times

    def get_baseline_traces(self):
        """Return the traces of the ``BaselineActivity`` repro (see ``__get_traces__``)."""
        return self.__get_traces__("BaselineActivity")

    def get_baseline_spiketimes(self):
        """Return one spike-time array per block of ``basespikes1.dat``.

        Times are divided by 1000 (stored in ms, returned in s).
        """
        # TODO change: reading from file -> detect from v1 trace
        warn("Spiketimes don't fit time-wise to the baseline traces. Causes different vector strength angle per recording.")

        return [np.array(data[:, 0]) / 1000 for metadata, key, data in Dl.iload(self.baseline_file)]

    def get_fi_curve_traces(self):
        """Return the traces of the ``FICurve`` repro (see ``__get_traces__``)."""
        return self.__get_traces__("FICurve")

    def get_fi_frequency_traces(self):
        raise NotImplementedError("Not possible in .dat data type.\n"
                                  "Please check availability with the x_available functions.")

    def get_fi_curve_spiketimes(self):
        """Read the FI-curve spike trains from ``fispikes1.dat``.

        Recordings with a non-zero pre-duration are skipped. Spike trains
        with fewer than 10 spikes, or whose last spike is before 1 s, are
        dropped. Intensities with fewer than 3 remaining trains are dropped
        as well.

        :return: ``(trans_amplitudes, intensities, spiketimes)`` sorted by
            ascending intensity (ties keep file order); ``spiketimes[i]`` is a
            list of spike-time arrays (divided by 1000, i.e. ms -> s) for
            ``intensities[i]``. All three are empty if nothing is usable.
        :raises RuntimeError: if a data block has more than one column
        """
        control_key = '----- Control --------------------------------------------------------'
        pre_intensities_key = "----- Pre-Intensities ------------------------------------------------"

        spiketimes = []
        intensities = []
        trans_amplitudes = []
        index = -1
        skip = False
        trans_amplitude = float('nan')
        for metadata, key, data in Dl.iload(self.fi_file):
            if len(metadata) != 0:
                metadata_index = 0

                if control_key in metadata[0].keys():
                    # Sectioned header: settings live in metadata[0]; the
                    # stimulus values are looked up in metadata[1] first.
                    metadata_index = 1
                    pre_duration = float(metadata[0][pre_intensities_key]["preduration"][:-2])
                    trans_amplitude = float(metadata[0]["trans. amplitude"][:-2])
                    skip = pre_duration != 0
                elif "preduration" in metadata[0].keys():
                    pre_duration = float(metadata[0]["preduration"][:-2])
                    trans_amplitude = float(metadata[0]["trans. amplitude"][:-2])
                    skip = pre_duration != 0
                # Otherwise `skip` keeps the value of the previous header.

                if skip:
                    continue

                stimulus = metadata[metadata_index]
                if 'intensity' not in stimulus.keys():
                    stimulus = metadata[1 - metadata_index]
                intensities.append(float(stimulus['intensity'][:-2]))
                trans_amplitudes.append(trans_amplitude)
                spiketimes.append([])
                index += 1

            if skip:
                continue

            if data.shape[1] != 1:
                raise RuntimeError("DatParser:get_fi_curve_spiketimes():\n read data has more than one dimension!")

            spike_time_data = data[:, 0]/1000
            if len(spike_time_data) < 10:
                print("# ignoring spike-train that contains less than 10 spikes.")
                continue
            if spike_time_data[-1] < 1:
                print("# ignoring spike-train that ends before one second.")
                continue

            spiketimes[index].append(spike_time_data)

        # Sort by intensity (ties broken by file order), then keep only
        # intensities with at least 3 usable spike trains. Works for the
        # empty case, where the previous zip(*sorted(...)) unpacking crashed.
        order = sorted(range(len(intensities)), key=lambda i: (intensities[i], i))
        kept = [i for i in order if len(spiketimes[i]) >= 3]

        return ([trans_amplitudes[i] for i in kept],
                [intensities[i] for i in kept],
                [spiketimes[i] for i in kept])

    def get_sam_info(self):
        """Read the SAM stimulus parameters and spike trains from ``samallspikes1.dat``.

        The time factor is taken from the data key (``ms`` -> 1/1000,
        ``s`` -> 1) and applied to the spike times and the stimulus duration.
        Spike trains with fewer than 10 spikes, or ending before 0.1 time
        units after scaling, are dropped.

        NOTE(review): the duration is scaled with the spike-time column's
        unit, not its own; ``trans_amplitudes`` are returned as strings
        (suffix stripped), unlike the floats of get_fi_curve_spiketimes.
        Confirm both are intended.

        :return: ``(spiketimes, contrasts, delta_fs, eod_freqs, durations,
            trans_amplitudes)`` with one entry per stimulus header;
            ``spiketimes[i]`` is a list of spike-time arrays
        :raises RuntimeError: if a data block has more than one column
        """
        stimulus_key = "----- Stimulus -------------------------------------------------------"
        contrasts = []
        delta_fs = []
        spiketimes = []
        durations = []
        eod_freqs = []
        trans_amplitudes = []
        index = -1
        for metadata, key, data in Dl.iload(self.sam_file):
            factor = 1
            if key[0][0] == 'time':
                if key[1][0] == 'ms':
                    factor = 1/1000
                elif key[1][0] == 's':
                    factor = 1
                else:
                    print("DataParser Dat: Unknown time notation:", key[1][0])
            if len(metadata) != 0:
                header = metadata[0]
                # Sectioned headers keep the stimulus settings in a sub-dict;
                # flat headers hold them at the top level.
                stimulus = header[stimulus_key] if stimulus_key in header.keys() else header

                eod_freq = float(header["EOD rate"][:-2])  # in Hz
                trans_amplitude = header["trans. amplitude"][:-2]  # in mV
                duration = float(stimulus["duration"][:-2]) * factor
                contrast = float(stimulus["contrast"][:-1])  # in percent
                delta_f = float(stimulus["deltaf"][:-2])

                contrasts.append(contrast)
                delta_fs.append(delta_f)
                durations.append(duration)
                eod_freqs.append(eod_freq)
                trans_amplitudes.append(trans_amplitude)
                spiketimes.append([])
                index += 1

            if data.shape[1] != 1:
                raise RuntimeError("DatParser:get_sam_spiketimes():\n read data has more than one dimension!")

            spike_time_data = data[:, 0] * factor
            if len(spike_time_data) < 10:
                continue
            if spike_time_data[-1] < 0.1:
                print("# ignoring spike-train that ends before one tenth of a second.")
                continue
            spiketimes[index].append(spike_time_data)

        return spiketimes, contrasts, delta_fs, eod_freqs, durations, trans_amplitudes

    def __get_traces__(self, repro):
        """Load all traces recorded by *repro* via ``Dl.iload_traces``.

        :return: ``[time_traces, v1_traces, eod_traces, local_eod_traces,
            stimulus_traces]``, each a list with one entry per recording
            (channels ``x[0]``..``x[3]`` of every trace set). Warns if
            nothing was found.
        """
        time_traces = []
        v1_traces = []
        eod_traces = []
        local_eod_traces = []
        stimulus_traces = []

        for info, key, time, x in Dl.iload_traces(self.base_path, repro=repro):
            time_traces.append(time)
            v1_traces.append(x[0])
            eod_traces.append(x[1])
            local_eod_traces.append(x[2])
            stimulus_traces.append(x[3])

        if not time_traces:
            warn("pyrelacs: iload_traces found nothing for the " + str(repro) + " repro!")

        return [time_traces, v1_traces, eod_traces, local_eod_traces, stimulus_traces]

    def __iget_traces__(self, repro):
        """Lazily yield ``(time, v1, eod, local_eod, stimulus)`` per recording of *repro*."""
        for info, key, time, x in Dl.iload_traces(self.base_path, repro=repro):
            yield time, x[0], x[1], x[2], x[3]

    def __read_fi_recording_times__(self):
        """Read delay, stimulus duration and pause of the FI protocol into the cache.

        Values are divided by 1000 (presumably ms -> s; TODO confirm) and must
        be identical in every header. On success stores
        ``[-delay, 0, stimulus_duration, pause - delay]`` in
        ``self.fi_recording_times``.

        :raises RuntimeError: if a value is missing or differs between
            headers; the cache is left untouched in that case
        """
        control_key = '----- Control --------------------------------------------------------'
        stim_key = "----- Test-Intensities -----------------------------------------------"
        delays = []
        stim_duration = []
        pause = []

        for metadata, key, data in Dl.iload(self.fi_file):
            if len(metadata) == 0:
                continue
            header = metadata[0]
            if control_key in header.keys():
                # Sectioned header: delay/pause under Control, duration under Test-Intensities.
                delays.append(float(header[control_key]["delay"][:-2])/1000)
                pause.append(float(header[control_key]["pause"][:-2])/1000)
                stim_duration.append(float(header[stim_key]["duration"][:-2])/1000)

            if "pause" in header.keys():
                delays.append(float(header["delay"][:-2]) / 1000)
                pause.append(float(header["pause"][:-2]) / 1000)
                stim_duration.append(float(header["duration"][:-2]) / 1000)

        # Validate every list before storing anything, so a failure cannot
        # leave a partially validated result in the cache.
        for values in (delays, stim_duration, pause):
            if len(values) == 0:
                raise RuntimeError("DatParser:__read_fi_recording_times__:\n" +
                                   "Couldn't find any delay, stimulus duration and or pause in the metadata.\n" +
                                   "In file:" + self.base_path)
            if len(set(values)) != 1:
                raise RuntimeError("DatParser:__read_fi_recording_times__:\n" +
                                   "Found multiple different delay, stimulus duration and or pause in the metadata.\n" +
                                   "In file:" + self.base_path)

        self.fi_recording_times = [-delays[0], 0, stim_duration[0], pause[0] - delays[0]]

    def __read_sampling_interval__(self):
        """Read the sampling interval from ``stimuli.dat`` into ``self.sampling_interval``.

        Collects consecutive ``sample interval1``..``sample interval4`` entries
        (divided by 1000) from the first metadata block that has any, and
        requires all of them to be equal.

        :raises RuntimeError: if none are found or they differ
        """
        sampling_intervals = []
        for metadata, key, data in Dl.iload(self.stimuli_file):
            for md in metadata:
                # Intervals are numbered from 1; stop at the first missing number.
                for number in range(1, 5):
                    interval_key = "sample interval" + str(number)
                    if interval_key not in md.keys():
                        break
                    sampling_intervals.append(float(md[interval_key][:-2]) / 1000)

            # The first block that provided any interval is authoritative.
            if sampling_intervals:
                break

        if len(sampling_intervals) == 0:
            raise RuntimeError("DatParser:__read_sampling_interval__:\n" +
                               "Sampling intervals not found in stimuli.dat this is not handled!\n" +
                               "with File:" + self.base_path)

        if len(set(sampling_intervals)) != 1:
            raise RuntimeError("DatParser:__read_sampling_interval__:\n" +
                               "Sampling intervals not the same for all traces this is not handled!\n" +
                               "with File:" + self.base_path)

        self.sampling_interval = sampling_intervals[0]

    def __test_data_file_existence__(self):
        """Raise ``FileNotFoundError`` if a mandatory ``.dat`` file is missing.

        The SAM file is optional and therefore not checked.
        """
        if not exists(self.stimuli_file):
            raise FileNotFoundError(self.stimuli_file + " file doesn't exist!")
        if not exists(self.fi_file):
            raise FileNotFoundError(self.fi_file + " file doesn't exist!")
        if not exists(self.baseline_file):
            raise FileNotFoundError(self.baseline_file + " file doesn't exist!")


def get_parser(data_path) -> AbstractParser:
    """Return the parser matching the data format detected for *data_path*.

    :param data_path: recording directory, ``.nix`` file path, or model instance
    :return: a ``DatParser`` for ``.dat`` recording directories
    :raises NotImplementedError: for nix files and models (no parser yet)
    :raises TypeError: if the format cannot be determined
    """
    data_format = __test_for_format__(data_path)

    if data_format == DAT_FORMAT:
        return DatParser(data_path)
    if data_format == NIX_FORMAT:
        raise NotImplementedError("DataParserFactory:get_parser(data_path): nix format doesn't have a parser yet")
    if data_format == MODEL:
        raise NotImplementedError("DataParserFactory:get_parser(data_path): Model doesn't have a parser yet")

    # UNKNOWN or any unexpected result: raise instead of silently returning None.
    # str() keeps the message working for non-string inputs.
    raise TypeError("DataParserFactory:get_parser(data_path):\nCannot determine type of data for:" + str(data_path))


def __test_for_format__(data_path):
    """Classify *data_path* as one of the module's data-format constants.

    :return: MODEL for an ``AbstractModel`` instance, DAT_FORMAT for a
        directory containing ``fispikes1.dat``, NIX_FORMAT for a non-directory
        path ending in ``.nix``, and UNKNOWN otherwise (including a directory
        without ``fispikes1.dat``, which previously fell through to None)
    """
    if isinstance(data_path, AbstractModel):
        return MODEL

    if isdir(data_path):
        if exists(data_path + "/fispikes1.dat"):
            return DAT_FORMAT
        return UNKNOWN

    if data_path.endswith(".nix"):
        return NIX_FORMAT
    return UNKNOWN