import DataParserFactory as dpf
from warnings import warn
from os import listdir
import helperFunctions as hf
import numpy as np


def icelldata_of_dir(base_path):
    """Yield a CellData object for every entry of ``base_path`` that has baseline V1 traces.

    Entries are visited in sorted order. An entry is skipped with a printed
    message when its baseline V1 trace list is empty. It is skipped with a
    ``warnings.warn`` when constructing it, reading its baseline traces, or
    taking their length raises ``TypeError``.

    :param base_path: directory containing one entry per recorded cell; a
        trailing path separator is optional.
    :return: generator of ``CellData`` instances.
    """
    from os.path import join

    for item in sorted(listdir(base_path)):
        # join() instead of string concatenation so that base_path works both
        # with and without a trailing separator.
        item_path = join(base_path, item)

        try:
            data = CellData(item_path)
            trace = data.get_base_traces(trace_type=data.V1)
            has_v1_trace = len(trace) != 0
        except TypeError as e:
            warn(str(e))
            continue

        if not has_v1_trace:
            print("NO V1 TRACE FOUND: ", item_path)
            continue

        yield data


class CellData:
    """All data of a single recorded cell across experiments (baseline, FI curve, SAM).

    Should be abstract from the way the data is saved in the background
    (.dat vs .nix): every read goes through the parser obtained from
    ``DataParserFactory.get_parser``. Traces, spike times and derived values
    are read/computed lazily on first access and cached on the instance
    (a cache attribute of ``None`` means "not loaded yet").
    """

    # Indices into the baseline trace list:
    # traces list of lists with traces: [[time], [voltage (v1)], [EOD], [local eod], [stimulus]]
    TIME = 0
    V1 = 1
    EOD = 2
    LOCAL_EOD = 3
    STIMULUS = 4

    def __init__(self, data_path):
        self.data_path = data_path
        self.parser = dpf.get_parser(data_path)

        # Lazily filled caches; None means "not read/computed yet".
        self.base_traces = None
        self.base_spikes = None
        # self.fi_traces = None
        self.fi_intensities = None
        self.fi_spiketimes = None
        self.fi_trans_amplitudes = None
        self.mean_isi_frequencies = None
        self.time_axes = None
        # self.metadata = None

        self.sam_spiketimes = None
        self.sam_contrasts = None
        self.sam_delta_fs = None
        self.sam_eod_freqs = None
        self.sam_durations = None
        self.sam_trans_amplitudes = None

        # Read eagerly at construction time.
        self.sampling_interval = self.parser.get_sampling_interval()
        # Indexed by the time getters below as:
        # [time start, stimulus start, stimulus duration, after-stimulus duration]
        self.recording_times = self.parser.get_recording_times()

    def get_data_path(self):
        """Return the path this cell was loaded from."""
        return self.data_path

    def get_base_traces(self, trace_type=None):
        """Return the baseline traces, reading them from the parser on first use.

        :param trace_type: optional index (``TIME``, ``V1``, ``EOD``,
            ``LOCAL_EOD``, ``STIMULUS``) selecting a single trace kind;
            ``None`` returns the full list of trace kinds.
        """
        if self.base_traces is None:
            self.base_traces = self.parser.get_baseline_traces()

        if trace_type is None:
            return self.base_traces
        return self.base_traces[trace_type]

    def get_base_spikes(self):
        """Return the detected spike times of every baseline V1 trace (cached)."""
        if self.base_spikes is None:
            times = self.get_base_traces(self.TIME)
            v1_traces = self.get_base_traces(self.V1)
            self.base_spikes = [hf.detect_spiketimes(time, v1)
                                for time, v1 in zip(times, v1_traces)]
        return self.base_spikes

    def get_base_isis(self):
        """Return the inter-spike intervals of all baseline trials, concatenated into one list."""
        isis = []
        for spikes in self.get_base_spikes():
            isis.extend(np.diff(spikes))
        return isis

    def get_fi_traces(self):
        raise NotImplementedError("CellData:get_fi_traces():\n" +
                                  "Getting the Fi-Traces currently overflows the RAM and causes swapping! Reimplement if really needed!")
        # if self.fi_traces is None:
        #    self.fi_traces = self.parser.get_fi_curve_traces()
        # return self.fi_traces

    def get_fi_spiketimes(self):
        """Return the FI-curve spike times, grouped by (merged) intensity."""
        self.__read_fi_spiketimes_info__()
        return self.fi_spiketimes

    def get_fi_intensities(self):
        """Return the (merged) FI-curve stimulus intensities."""
        self.__read_fi_spiketimes_info__()
        return self.fi_intensities

    def get_fi_contrasts(self):
        """Return ``(intensity - trans_amplitude) / trans_amplitude`` per FI-curve intensity."""
        self.__read_fi_spiketimes_info__()
        return [(intensity - trans_amplitude) / trans_amplitude
                for intensity, trans_amplitude in zip(self.fi_intensities, self.fi_trans_amplitudes)]

    def get_sam_spiketimes(self):
        self.__read_sam_info__()
        return self.sam_spiketimes

    def get_sam_contrasts(self):
        self.__read_sam_info__()
        return self.sam_contrasts

    def get_sam_delta_frequencies(self):
        self.__read_sam_info__()
        return self.sam_delta_fs

    def get_sam_durations(self):
        self.__read_sam_info__()
        return self.sam_durations

    def get_sam_eod_frequencies(self):
        self.__read_sam_info__()
        return self.sam_eod_freqs

    def get_sam_trans_amplitudes(self):
        self.__read_sam_info__()
        return self.sam_trans_amplitudes

    def _calculate_mean_fi_curve_isi_frequencies(self):
        """Compute the FI-curve time axes and mean ISI frequencies and store both on the instance."""
        self.time_axes, self.mean_isi_frequencies = hf.all_calculate_mean_isi_frequency_traces(
            self.get_fi_spiketimes(), self.get_sampling_interval())

    def get_mean_fi_curve_isi_frequencies(self):
        """Return the mean ISI-frequency trace of each FI-curve intensity (cached)."""
        if self.mean_isi_frequencies is None:
            self._calculate_mean_fi_curve_isi_frequencies()
        return self.mean_isi_frequencies

    def get_time_axes_fi_curve_mean_frequencies(self):
        """Return the time axes belonging to ``get_mean_fi_curve_isi_frequencies()`` (cached)."""
        if self.time_axes is None:
            self._calculate_mean_fi_curve_isi_frequencies()
        return self.time_axes

    def get_base_frequency(self):
        """Return the median over FI-curve traces of the mean frequency before ``get_delay()``.

        For each mean ISI-frequency trace, the samples from 25 ms after the
        trace start up to 25 ms before ``get_delay()`` seconds are averaged.
        The median of these per-trace means is returned.
        """
        frequencies = self.get_mean_fi_curve_isi_frequencies()

        # Both values are properties of the cell, not of a single trace,
        # so they are computed (and the warning issued) once.
        delay = self.get_delay()
        sampling_interval = self.get_sampling_interval()
        if delay < 0.1:
            warn("FICurve:__calculate_f_baseline__(): Quite short delay at the start.")

        # Drop 25 ms at both edges of the pre-delay window.
        idx_start = int(0.025 / sampling_interval)
        idx_end = int((delay - 0.025) / sampling_interval)

        base_freqs = [np.mean(freq[idx_start:idx_end]) for freq in frequencies]
        return np.median(base_freqs)

    def get_sampling_interval(self) -> float:
        return self.sampling_interval

    def get_recording_times(self) -> list:
        return self.recording_times

    def get_time_start(self) -> float:
        return self.recording_times[0]

    def get_delay(self) -> float:
        """Return the absolute value of the recording start time."""
        return abs(self.recording_times[0])

    def get_time_end(self) -> float:
        # NOTE(review): this is stimulus duration + after-stimulus duration and
        # ignores the stimulus start (recording_times[1]). That only equals
        # get_stimulus_end() + get_after_stimulus_duration() when the stimulus
        # starts at 0 -- confirm this assumption.
        return self.recording_times[2] + self.recording_times[3]

    def get_stimulus_start(self) -> float:
        return self.recording_times[1]

    def get_stimulus_duration(self) -> float:
        return self.recording_times[2]

    def get_stimulus_end(self) -> float:
        return self.get_stimulus_start() + self.get_stimulus_duration()

    def get_after_stimulus_duration(self) -> float:
        return self.recording_times[3]

    @staticmethod
    def _time_axis(n_samples, sampling_interval):
        """Return ``n_samples`` sample times starting at 0, spaced by ``sampling_interval``.

        ``np.arange(0, n * dt, dt)`` with a float step can return ``n + 1``
        values because of rounding (e.g. n=3, dt=0.1), which misaligns the
        time axis with the trace. Scaling integer indices always yields
        exactly ``n_samples`` values.
        """
        return np.arange(n_samples) * sampling_interval

    def get_eod_frequency(self):
        """Return the mean EOD frequency over all baseline EOD traces."""
        sampling_interval = self.get_sampling_interval()
        frequencies = []
        for eod in self.get_base_traces(self.EOD):
            time = self._time_axis(len(eod), sampling_interval)
            frequencies.append(hf.calculate_eod_frequency(time, eod))

        return np.mean(frequencies)

    def __read_fi_spiketimes_info__(self):
        """Load the FI-curve data from the parser once, merging similar intensities."""
        if self.fi_spiketimes is None:
            trans_amplitudes, intensities, spiketimes = self.parser.get_fi_curve_spiketimes()

            self.fi_intensities, self.fi_spiketimes, self.fi_trans_amplitudes = hf.merge_similar_intensities(
                intensities, spiketimes, trans_amplitudes)

    def __read_sam_info__(self):
        """Load all SAM stimulus information from the parser once."""
        if self.sam_spiketimes is None:
            (self.sam_spiketimes, self.sam_contrasts, self.sam_delta_fs,
             self.sam_eod_freqs, self.sam_durations, self.sam_trans_amplitudes) = self.parser.get_sam_info()

    # TODO: metadata access is not implemented yet; sketch kept for reference.
    # def get_metadata(self):
    #     self.__read_metadata__()
    #     return self.metadata
    #
    # def get_metadata_item(self, item):
    #     self.__read_metadata__()
    #     if item in self.metadata.keys():
    #         return self.metadata[item]
    #     else:
    #         raise KeyError("CellData:get_metadata_item: Item not found in metadata! - " + str(item))
    #
    # def __read_metadata__(self):
    #     if self.metadata is None:
    #         # TODO!!
    #         pass