from .frontend_classes import Dataset, RePro, Stimulus
from .util import BoltzmannFit, unzip_if_needed, gaussian_kernel, zero_crossings, spike_times_to_rate, StimSpikesFile
import numpy as np
import nixio as nix
from scipy.stats import circstd
# from scipy.optimize import curve_fit
import os
import subprocess
from tqdm import tqdm
import yaml

from IPython import embed


class BaselineData:
    """
    Class representing the Baseline data that has been recorded within a given Dataset.

    This class provides access to basic measures estimated from the baseline activity.
    Only runs of the BaselineActivity RePro that yielded usable spike data are kept;
    every ``index`` argument of the public methods refers to these valid runs.
    """

    def __init__(self, dataset=None, dataset_id=None):
        """Load the baseline data of a dataset.

        Args:
            dataset (Dataset, optional): the dataset entity. Used directly when no
                ``dataset_id`` is given.
            dataset_id (str, optional): the id of the dataset. When given, the dataset
                is looked up via ``Dataset.find`` (this takes precedence over ``dataset``).

        Raises:
            ValueError: if the lookup does not yield exactly one dataset.
        """
        if dataset is None or dataset_id is not None:
            found, _ = Dataset.find(dataset_id=dataset_id)
            if len(found) != 1:
                raise ValueError("Dataset id not found or not unique")
            dataset = found[0]
        self.__spike_data = []
        self.__eod_data = []
        self.__eod_times = []
        # repro runs that delivered usable spikes; kept aligned with __spike_data
        self.__valid_repros = []
        self.__dataset = dataset
        self.__repros = None
        self.__cell = dataset.cells[0]  # Beware: Assumption that there is only a single cell
        self._get_data()

    def _get_data(self):
        """Find all BaselineActivity runs of the cell and read spikes and EOD for each usable run."""
        if not self.__dataset:
            return
        self.__repros, _ = RePro.find("BaselineActivity", cell_id=self.__cell.id)
        for r in tqdm(self.__repros, desc="loading data"):
            sd = self.__read_spike_data(r)
            if sd is None or len(sd) <= 1:
                continue
            self.__spike_data.append(sd)
            self.__valid_repros.append(r)
            # the time of the last spike defines how much of the EOD trace is read
            self.__eod_data.append(self.__read_eod_data(r, sd[-1]))

    def valid(self):
        """Not implemented yet; currently returns None."""
        # fixme implement me!

    def __nix_path(self):
        """Return the path of the dataset's nix file (it is not checked for existence)."""
        return os.path.join(self.__dataset.data_source, self.__dataset.id + ".nix")

    @staticmethod
    def __find_tag(block, tag_id):
        """Return the tag ``tag_id`` of ``block`` or None (with a message) if it does not exist."""
        try:
            return block.tags[tag_id]
        except (KeyError, IndexError):
            print("Tag not found!")
            return None

    def __read_spike_data(self, r: "RePro"):
        """Read the spike times of repro run ``r`` from the nix file or the relacs directory."""
        if self.__dataset.has_nix:
            return self.__read_spike_data_from_nix(r)
        return self.__read_spike_data_from_directory(r)

    def __read_eod_data(self, r: "RePro", duration):
        """Read the EOD trace of repro run ``r`` from the nix file or the relacs directory."""
        if self.__dataset.has_nix:
            return self.__read_eod_data_from_nix(r, duration)
        return self.__read_eod_data_from_directory(r, duration)

    def __get_serial_correlation(self, times, max_lags=50):
        """Return the normalized autocorrelation of the interspike intervals for lags 0..max_lags-1.

        Returns None if ``times`` is None or contains fewer than ``max_lags`` entries.
        """
        if times is None or len(times) < max_lags:
            return None
        isis = np.diff(times)
        unbiased = isis - np.mean(isis, 0)
        norm = np.sum(unbiased ** 2)
        a_corr = np.correlate(unbiased, unbiased, "same") / norm
        # "same" mode is centered on lag zero, keep the non-negative lags only
        a_corr = a_corr[int(len(a_corr) / 2):]
        return a_corr[:max_lags]

    @property
    def baseline_rate(self):
        """The average baseline firing rate for each run of the baseline repro

        The rate is estimated from the spikes that fall between the first and the
        last full second covered by the spike train.

        Returns:
            list of float: the average firing rate
        """
        rates = []
        for i in range(self.size):
            spikes = self.spikes(i)
            max_time = np.floor(spikes)[-1]
            min_time = np.ceil(spikes)[0]
            in_window = (spikes >= min_time) & (spikes < max_time)
            rates.append(len(spikes[in_window]) / (max_time - min_time))
        return rates

    def serial_correlation(self, max_lags=50):
        """
            Returns the serial correlation of the interspike intervals.

        Args
            max_lags (int, optional): The number of lags to take into account
        Returns
            list of numpy.ndarray: the serial correlation as a function of the lag, one entry
                for each run with at least 100 spikes.
        """
        scs = []
        for sd in self.__spike_data:
            if sd is None or len(sd) < 100:
                continue
            corr = self.__get_serial_correlation(sd, max_lags=max_lags)
            if corr is not None:
                scs.append(corr)
        return scs

    def circular_std(self):
        """The circular standard deviation of the baseline spikes. The circ. std. is given in radians.

        Returns:
            list of float: for each run of the baseline RePro there will be one entry.
        """
        return [circstd(self.__spike_phases(index=i)) for i in range(self.size)]

    @property
    def eod_frequency(self):
        """The average baseline EOD frequency in Hz.

        Runs in which no zero crossing could be detected are ignored.

        Returns:
            float: the EOD frequency averaged across runs. Given in Hz. 0.0 if no run could be analysed.
        """
        eod_frequencies = []
        for i in range(self.size):
            eod, time = self.eod(i)
            xings = zero_crossings(eod, time, interpolate=False)
            if len(xings) == 0 or xings[-1] <= 0:
                continue  # nothing detected, avoid indexing/dividing by zero
            eod_frequencies.append(len(xings) / xings[-1])
        return 0.0 if len(eod_frequencies) < 1 else np.mean(eod_frequencies)

    def __spike_phases(self, index=0):  # fixme buffer this stuff
        """Return the phase (in radians) of each spike of run ``index`` relative to the preceding EOD time.

        Spikes that occur before the first detected EOD time get a phase of zero.
        """
        e_times = np.asarray(self.eod_times(index=index))
        spikes = np.asarray(self.spikes(index))
        eod_period = np.mean(np.diff(e_times))
        phases = np.zeros(len(spikes))
        # index of the last EOD time <= spike time; relies on the EOD times being
        # sorted in ascending order (they are detected along an increasing time axis)
        last_eod_index = np.searchsorted(e_times, spikes, side="right") - 1
        has_eod = last_eod_index >= 0
        phases[has_eod] = (spikes[has_eod] - e_times[last_eod_index[has_eod]]) / eod_period * 2 * np.pi
        return phases

    def eod_times(self, index=0, interpolate=True):
        """The times of the detected EODs.

        Args:
            index (int, optional): The run of the BaselineActivity RePro. Defaults to 0.
            interpolate (bool, optional): Defines whether a simple threshold mechanism is used or times are interpolated. Defaults to True.

        Returns:
            numpy.ndarray: the eod times, None if index is out of bounds.
        """
        if index >= self.size:
            return None
        if len(self.__eod_times) >= len(self.__eod_data):
            # buffered times are available for every run
            return self.__eod_times[index]
        eod, time = self.eod(index)
        return zero_crossings(eod, time, interpolate=interpolate)

    @property
    def dataset(self):
        """The dataset entity the baseline data belongs to."""
        return self.__dataset

    @property
    def cell(self):
        """The cell of the dataset, or the list of cells if there is more than one."""
        cells = self.__dataset.cells
        return cells if len(cells) > 1 else cells[0]

    @property
    def subject(self):
        """The subject of the dataset, or the list of subjects if there is more than one."""
        subjects = self.__dataset.subjects
        return subjects if len(subjects) > 1 else subjects[0]

    def spikes(self, index: int = 0):
        """Get the spike times of the spikes recorded in the given baseline recording.

        Args:
            index (int, optional): If the baseline activity has been recorded several times, the index can be given. Defaults to 0.

        Returns:
            numpy.ndarray: the spike times, None if index is out of bounds.
        """
        count = len(self.__spike_data)
        return self.__spike_data[index] if -count <= index < count else None

    def membrane_voltage(self, index: int = 0):
        """Read the membrane voltage ("V-1") recorded during the given baseline run.

        Only supported for datasets that have a nix file.

        Args:
            index (int, optional): The run index. Defaults to 0.

        Raises:
            IndexError: if index is not smaller than the number of valid runs.

        Returns:
            numpy.ndarray: the time axis (empty if the data could not be read, None for non-nix datasets).
            numpy.ndarray: the membrane voltage (empty if the data could not be read, None for non-nix datasets).
        """
        if index >= self.size:
            raise IndexError("Index %i out of bounds for size %i!" % (index, self.size))
        if not self.__dataset.has_nix:
            print('Sorry, this is not supported for non-nixed datasets. Implement it at '
                  'fishbook.reproclasses.BaselineData.membrane_voltage and send a pull request!')
            return None, None

        # index counts valid runs only, which may differ from the position in __repros
        rp = self.__valid_repros[index]
        data = np.empty(0)
        time = np.empty(0)
        f = nix.File.open(self.__nix_path(), nix.FileMode.ReadOnly)
        try:
            t = self.__find_tag(f.blocks[0], rp.id)
            if t is not None:
                try:
                    data = t.retrieve_data("V-1")[:]
                    time = np.asarray(t.references["V-1"].dimensions[0].axis(len(data)))
                except Exception:
                    # best effort: missing or unreadable trace yields empty arrays
                    data = np.empty(0)
                    time = np.empty(0)
        finally:
            f.close()
        return time, data

    def eod(self, index: int = 0):
        """Returns the EOD data for a given run of the BaselineActivity RePro.

        Args:
            index (int, optional): The run index. Defaults to 0.

        Returns:
            numpy.ndarray: The eod trace, None if index is out of bounds.
            numpy.ndarray: A matching time axis starting at time zero, None if index is out of bounds.

        """
        count = len(self.__eod_data)
        if not -count <= index < count:
            return None, None
        eod = self.__eod_data[index]
        time = np.arange(len(eod)) / self.__dataset.samplerate
        return eod, time

    @property
    def burst_index(self):
        """Fraction of spikes that occur in intervals of less than 1.5 times the EOD period.

        Returns:
            list of float: burst indices for each repro run.
        """
        bi = []
        for i, sd in enumerate(self.__spike_data):
            if len(sd) < 2:
                continue
            eod_period = np.mean(np.diff(self.eod_times(index=i)))
            isis = np.diff(sd)
            bi.append(np.sum(isis < (1.5 * eod_period)) / len(isis))
        return bi

    @property
    def coefficient_of_variation(self):
        """Coefficient of variation of the interspike intervals.

        Returns:
            list of float: for each baseline repro run a single value of the CV.
        """
        cvs = []
        for d in self.__spike_data:
            isis = np.diff(d)
            cvs.append(np.std(isis) / np.mean(isis))
        return cvs

    @property
    def vector_strength(self):
        """The vector strength with which the spikes lock to the fish's own EOD

        Returns:
            list of float: the vector strength calculated separately for each repro run.
            list of numpy.ndarray: the spike phases within the EOD period (in radians).
        """
        vss = []
        spike_phases = []
        for i in range(self.size):
            phases = self.__spike_phases(i)
            ms_sin_alpha = np.mean(np.sin(phases)) ** 2
            ms_cos_alpha = np.mean(np.cos(phases)) ** 2
            vss.append(np.sqrt(ms_cos_alpha + ms_sin_alpha))
            spike_phases.append(phases)
        return vss, spike_phases

    @property
    def size(self):
        """The number of times the BaselineActivity RePro was run.

        Returns:
            int: the number of baseline repro runs with usable spike data
        """
        return len(self.__spike_data)

    def __str__(self):
        return "Baseline data of cell %s " % self.__cell.id

    def __read_eod_data_from_nix(self, r: "RePro", duration) -> np.ndarray:
        """Read the "EOD" trace tagged by repro run ``r`` from the nix file.

        Falls back to the relacs directory if the nix file does not exist. Returns an
        empty array if the tag or the trace cannot be read.
        """
        data_source = self.__nix_path()
        if not os.path.exists(data_source):
            print("Data not found! Trying from directory")
            return self.__read_eod_data_from_directory(r, duration)
        data = np.empty(0)
        f = nix.File.open(data_source, nix.FileMode.ReadOnly)
        try:
            t = self.__find_tag(f.blocks[0], r.id)
            if t is not None:
                try:
                    data = t.retrieve_data("EOD")[:]
                except Exception:
                    data = np.empty(0)
        finally:
            f.close()
        return data

    def __read_eod_data_from_directory(self, r: "RePro", duration) -> np.ndarray:
        """Read the first ``duration`` seconds of the raw trace "trace-2.raw" of the dataset directory."""
        sr = self.__dataset.samplerate
        unzip_if_needed(self.__dataset.data_source, "trace-2.raw")
        eod = np.fromfile(os.path.join(self.__dataset.data_source, "trace-2.raw"), np.float32)
        return eod[:int(duration * sr)]

    def __read_spike_data_from_nix(self, r: "RePro") -> np.ndarray:
        """Read the "Spikes-1" data tagged by repro run ``r`` from the nix file.

        Falls back to the relacs directory if the nix file does not exist.

        Returns:
            numpy.ndarray: the spike times, None if they could not be read or if
                fewer than 100 spikes were recorded.
        """
        data_source = self.__nix_path()
        if not os.path.exists(data_source):
            print("Data not found! Trying from directory")
            return self.__read_spike_data_from_directory(r)
        data = None
        f = nix.File.open(data_source, nix.FileMode.ReadOnly)
        try:
            t = self.__find_tag(f.blocks[0], r.id)
            if t is not None:
                try:
                    data = t.retrieve_data("Spikes-1")[:]
                    if len(data) > 0 and data[0] < 0:
                        data = data[1:]  # this is related to a nix::RangeDimension bug, should be fixed beyond 1.4.9
                except Exception:
                    data = None
        finally:
            f.close()
        if data is None or len(data) < 100:
            return None
        return data

    def __read_spike_data_from_directory(self, r) -> np.ndarray:
        """Read the spike times of run ``r.run`` from "basespikes1.dat" of the dataset directory.

        Returns:
            numpy.ndarray: the spike times, None if fewer than 100 spikes were found.
        """
        data = []
        data_source = os.path.join(self.__dataset.data_source, "basespikes1.dat")
        if os.path.exists(data_source):
            found_run = False
            with open(data_source, 'r') as f:
                # readline-based iteration, __do_read continues reading from the same handle
                for line in iter(f.readline, ""):
                    if "index" in line:
                        index = int(line.strip("#").strip().split(":")[-1])
                        found_run = index == r.run
                    if line.startswith("#Key") and found_run:
                        data = self.__do_read(f)
                        break
        if len(data) < 100:
            return None
        return np.asarray(data)

    @staticmethod
    def __do_read(f) -> np.ndarray:
        """Read one block of spike times from the open file ``f`` positioned right after a "#Key" line.

        The first line after the key is skipped, the second one is taken as the unit
        ("ms" values are converted to seconds). Reading stops at the first empty or
        comment line.
        """
        data = []
        f.readline()
        unit = f.readline().strip("#").strip()
        scale = 0.001 if unit == "ms" else 1
        line = f.readline()
        while line and "#" not in line and len(line.strip()) > 0:
            data.append(float(line.strip()) * scale)
            line = f.readline()
        return np.asarray(data)


class FIData:
    """Class representing the data recorded with the relacs FI-Curve repro. The instance will load the data upon
        construction which may take a while.
        FI Data offers convenient access to the spike and local EOD data as well as convenience methods to get the
        firing rate and also to fit a Boltzmann function to the FI curve.
    """
    def __init__(self, dataset: "Dataset"):
        """Constructor.

        Args:
            fishbook.Dataset: The dataset entity for which the fi curve repro data should be loaded.
        """
        self.__spike_data = []
        self.__contrasts = []
        self.__eod_data = []
        self.__eod_times = []
        self.__all_spikes = None  # all spike times of the recording, set while reading from nix
        self.__dataset = dataset
        self.__repros = None
        self.__cell = dataset.cells[0]  # Beware: Assumption that there is only a single cell
        self._get_data()
        if self.size < 1:
            print("No FICurve data was recorded in dataset %s" % self.__dataset.id)

    def _get_data(self):
        """Find all FICurve repro runs of the cell and collect spikes, contrasts, local EODs and time axes."""
        if not self.__dataset:
            return
        self.__repros, _ = RePro.find("FICurve", cell_id=self.__cell.id)
        for r in self.__repros:
            sd, c, eods, time = self.__read_spike_data(r)
            if sd is None or len(sd) <= 1:
                continue
            self.__spike_data.extend(sd)
            if eods is not None and len(eods) > 0:  # not available for old-style relacs data
                self.__eod_data.extend(eods)
            self.__contrasts.extend(c)
            if time is not None and len(time) > 0:
                self.__eod_times.extend(time)

    def __read_spike_data(self, repro: "RePro"):
        """Read the trials of one FICurve run, either from the nix file or from the relacs directory.

        :param repro: the repro run
        :return: spike data, the respective contrasts, the local eods (or None) and the time axes
        """
        if self.__dataset.has_nix:
            return self.__read_spikes_from_nix(repro)
        return self.__read_spikes_from_directory(repro)

    def __do_read_spike_data_from_nix(self, mtag: "nix.MultiTag", stimulus: "Stimulus", repro: "RePro"):
        """Cut spikes and local EOD of a single stimulus presentation out of the recording.

        The delay is parsed from the repro settings, the contrast from the stimulus
        settings. Spike times and time axis are given relative to stimulus onset.

        :return: spike times, local eod, time axis and contrast of the trial
        """
        delay = 0.0
        contrast = 0.0
        for s in repro.settings.split("\n"):
            if "delay:" in s:
                delay = float(s.split(":")[-1])
                break
        for s in stimulus.settings.split("\n"):
            # skip the PreContrast entry as well as nested and "+-" (uncertainty) entries
            if "Contrast:" in s and "PreContrast" not in s and "\t\t" not in s and "+-" not in s:
                contrast = float(s.split(":")[-1])
                break

        start_time = stimulus.start_time - delay
        end_time = stimulus.start_time + stimulus.duration
        eod_da = mtag.references["LocalEOD-1"]
        start_index_eod = eod_da.dimensions[0].index_of(start_time)
        end_index_eod = eod_da.dimensions[0].index_of(end_time)

        local_eod = eod_da[start_index_eod:end_index_eod]
        in_trial = (self.__all_spikes >= start_time) & (self.__all_spikes < end_time)
        spikes = self.__all_spikes[in_trial] - start_time - delay
        time = np.asarray(eod_da.dimensions[0].axis(end_index_eod - start_index_eod)) - delay
        return spikes, local_eod, time, contrast

    def __read_spikes_from_nix(self, repro: "RePro"):
        """Read all trials of a FICurve run from the nix file of the dataset.

        Falls back to the relacs directory if the nix file does not exist.

        :return: lists of spike trains, contrasts, local eods, and time axes (one entry per stimulus)
        """
        spikes = []
        eods = []
        time = []
        contrasts = []
        stimuli, count = Stimulus.find(cell_id=repro.cell_id, repro_id=repro.id)
        if count == 0:
            return spikes, contrasts, eods, time
        data_source = os.path.join(self.__dataset.data_source, self.__dataset.id + ".nix")
        if not os.path.exists(data_source):
            print("Data not found! Trying from directory")
            return self.__read_spikes_from_directory(repro)
        f = nix.File.open(data_source, nix.FileMode.ReadOnly)
        try:
            b = f.blocks[0]
            self.__all_spikes = b.data_arrays["Spikes-1"][:]
            mt = None
            for i in tqdm(range(count), desc="Loading data"):
                s = stimuli[i]
                # consecutive stimuli usually share a multi tag, look it up only on change
                if not mt or mt.id != s.multi_tag_id:
                    mt = b.multi_tags[s.multi_tag_id]
                sp, eod, t, c = self.__do_read_spike_data_from_nix(mt, s, repro)

                spikes.append(sp)
                eods.append(eod)
                time.append(t)
                contrasts.append(c)
        finally:
            f.close()
        return spikes, contrasts, eods, time

    def __do_read_data_block(self, f, l):
        """Read consecutive data lines (spike times in ms) starting with line ``l`` from file ``f``.

        :return: the spike times in seconds and the first line that is not part of the block
        """
        spikes = []
        while len(l.strip()) > 0 and "#" not in l:
            spikes.append(float(l.strip())/1000)
            l = f.readline()
        return spikes, l

    @staticmethod
    def __parse_seconds(setting):
        """Parse the value of a "key: value" setting as seconds; the value must end with "ms" or a one-character unit."""
        value = setting.split(":")[-1].strip()
        return float(value[:-2])/1000 if "ms" in value else float(value[:-1])

    @staticmethod
    def __contrast_from_intensities(stim_intensity, fish_intensity):
        """Contrast in percent; NaN if one of the intensities is unknown or the fish intensity is zero."""
        if stim_intensity is None or not fish_intensity:
            return np.nan
        return (stim_intensity/fish_intensity - 1) * 100

    def __read_spikes_from_directory(self, repro: "RePro"):
        """Read the trials from the "fispikes1.dat" file of an old-style relacs dataset.

        The contrast of a trial is estimated from the "intensity" and "true intensity"
        header entries; it is NaN if they are missing. Local EOD traces are not available.

        :return: lists of spike trains and contrasts, None (no local eod), and the list of time axes
        """
        spikes = []
        contrasts = []
        times = []
        print("Warning! Exact reconstruction of stimulus order not possible for old relacs files!")
        data_source = os.path.join(self.__dataset.data_source, "fispikes1.dat")
        delay = 0.0
        pause = 0.0
        for setting in repro.settings.split(", "):
            if "pause" in setting:
                pause = self.__parse_seconds(setting)
            elif "delay" in setting:
                delay = self.__parse_seconds(setting)
        time = np.arange(-delay, pause + delay, 1./repro.dataset.samplerate)
        if not os.path.exists(data_source):
            return spikes, contrasts, None, times

        with open(data_source, 'r') as f:
            line = f.readline()
            fish_intensity = None
            stim_intensity = None
            while line:
                line = line.strip().lower()
                if "index" in line:
                    # a new trial header starts, forget the previous intensities
                    fish_intensity = None
                    stim_intensity = None
                if "intensity = " in line:
                    if "true intensity = " in line:
                        fish_intensity = float(line.split("=")[-1].strip()[:-2])
                    elif "pre" not in line:
                        stim_intensity = float(line.split("=")[-1].strip()[:-2])
                if len(line) > 0 and "#" not in line:  # data line
                    sp, line = self.__do_read_data_block(f, line)
                    spikes.append(sp)
                    times.append(time)
                    contrasts.append(self.__contrast_from_intensities(stim_intensity, fish_intensity))
                    continue  # line already holds the first line after the block
                line = f.readline()
        return spikes, contrasts, None, times

    @property
    def size(self) -> int:
        """
        The number of recorded trials

        returns
             int: the number of trials.
        """
        return len(self.__spike_data)

    def spikes(self, index=-1):
        """The spike times recorded in the specified trial(s)

        Args:
            int, optional: the index of the trial. Default of -1 (or any other invalid index) indicates that all data should be returned.

        Returns:
            numpy.ndarray or list of numpy.ndarray: the spike train(s).
        """
        if 0 <= index < self.size:
            return self.__spike_data[index]
        return self.__spike_data

    def eod(self, index=-1):
        """ The local eod (including the stimulus) measurement of the selected trial(s).

        Args:
            int, optional: the index of the trial. Default of -1 (or any other invalid index) indicates that all data should be returned.

        Returns:
            Either two vectors representing time and the local eod or two lists of such vectors.
            (None, None) if no eod data is available.
        """
        if len(self.__eod_data) == 0:
            print("EOD data not available for old-style relacs data.")
            return None, None
        if 0 <= index < self.size:
            return self.__eod_times[index], self.__eod_data[index]
        return self.__eod_times, self.__eod_data

    def contrast(self, index=-1):
        """ The stimulus contrast used in the respective trial(s).

        Args:
            int, optional: the index of the trial. Default of -1 (or any other invalid index) indicates that all data should be returned.

        Returns:
            Either a single scalar representing the contrast, or a list of such scalars, one entry for each trial.
        """
        if 0 <= index < self.size:
            return self.__contrasts[index]
        return self.__contrasts

    def time_axis(self, index=-1):
        """
        Get the time axis of a single trial or a list of time-vectors for all trials.

        :param index: the index of the trial. Default of -1 (or any other invalid index) indicates that all data should be returned.
        :return: Either a single vector representing time, or a list of such vectors, one for each trial.
        """
        if 0 <= index < self.size:
            return self.__eod_times[index]
        return self.__eod_times

    def rate(self, index=0, kernel_width=0.005):
        """ Returns the firing rate for a single trial. Firing rate estimation using the kernel convolution method.

        Args:
            int, optional: The index of the trial. 0 <= index < size
            float, optional: kernel_width: The width of the gaussian kernel in seconds. Defaults to 0.005 s

        Returns:
            numpy.ndarray: a vector representing time
            numpy.ndarray: a vector containing the firing rate.
        """
        t = self.time_axis(index)
        r = spike_times_to_rate(self.spikes(index), t, kernel_width)
        return t, r

    def boltzmann_fit(self, start_time=0.01, end_time=0.05, kernel_width=0.005):
        """
        Extracts, for each trial, the average firing rate within a time window and fits a Boltzmann
        function to the rates as a function of the (rounded) stimulus contrast.
        The analysis time window is specified by the start_time and end_time parameters. Firing rate is estimated by
        convolution with a Gaussian kernel of a given width. All parameters are given in 's'.

        :param start_time: the start of the analysis window.
        :param end_time: the end of the analysis window.
        :param kernel_width: standard deviation of the Gaussian kernel used for firing rate estimation.
        :return: object of type BoltzmannFit, None if no data is available.
        """
        if self.size < 1:
            print("No FICurve data recorded in dataset %s" % self.__dataset.id)
            return None
        contrasts = np.zeros(self.size)
        rates = np.zeros(self.size)
        for i in range(self.size):
            contrasts[i] = np.round(self.contrast(i))
            t, r = self.rate(i, kernel_width)
            rates[i] = np.mean(r[(t >= start_time) & (t < end_time)])

        return BoltzmannFit(contrasts, rates)


class FileStimulusData:
    """The FileStimulus class provides access to the data recorded and the stimulus presented (if accessible)
    during runs of the FileStimulus repro. Since the FileStimulus repro can put out any stimulus this class does not
    provide any further analyses.

    As any other relacs class it is instantiated with a Dataset entity.
    """
    def __init__(self, dataset: "Dataset"):
        """
        Constructor.

        Args
            fishbook.Dataset: The dataset entity for which the filestimulus repro data should be loaded.
        """
        self.__spike_data = []
        self.__contrasts = []
        self.__stimulus_files = []
        self.__stimulus_settings = []
        self.__delays = []
        self.__durations = []
        self.__dataset = dataset
        self.__repros = None
        self.__cell = dataset.cells[0]  # Beware: Assumption that there is only a single cell
        self.__all_spikes = None
        self.__stimspikes = None
        self._get_data()

    @property
    def dataset(self):
        """The dataset entity the data belongs to."""
        return self.__dataset

    @property
    def cell(self):
        """The (first) cell of the dataset."""
        return self.__cell

    def _get_data(self):
        """Find all FileStimulus repro runs of the cell and collect the data of all trials."""
        if not self.__dataset:
            return
        print("Find FileStimulus repro runs in this dataset!")
        self.__repros, _ = RePro.find("FileStimulus", cell_id=self.__cell.id)
        if not self.__dataset.has_nix:
            self.__stimspikes = StimSpikesFile(self.__dataset.data_source)
        for r in self.__repros:
            if self.__dataset.has_nix:
                trials = self.__read_spike_data_from_nix(r)
            else:
                trials = self.__read_spike_data_from_directory(r)
            spikes, contrasts, stims, delays, durations, stim_settings = trials
            if spikes is None or len(spikes) == 0:
                continue
            self.__spike_data.extend(spikes)
            self.__contrasts.extend(contrasts)
            self.__stimulus_files.extend(stims)
            self.__delays.extend(delays)
            self.__durations.extend(durations)
            self.__stimulus_settings.extend(stim_settings)

    def __find_contrast(self, repro_settings, stimulus_settings, has_nix=True):
        """Extract the stimulus contrast (in percent) from the settings.

        The repro settings are checked first. If they do not provide a contrast larger
        than zero, the stimulus settings are used instead. Settings may be nested in a
        "project" (or "Project") section.

        Args:
            repro_settings (dict): the parsed repro settings.
            stimulus_settings (dict): the parsed stimulus settings.
            has_nix (bool, optional): True if the settings originate from a nix file, in
                which case values look like "0.2+-0.0" (fraction). Otherwise the value is
                expected to carry a one-character unit, e.g. "20%". Defaults to True.

        Returns:
            float: the contrast in percent, 0.0 if none was found.
        """
        def read_contrast(value, has_nix):
            if has_nix:
                return float(value.split("+")[0]) * 100
            return float(value[:-1])

        def unwrap_project(settings):
            if "project" in settings.keys():
                return settings["project"]
            if "Project" in settings.keys():
                return settings["Project"]
            return settings

        contrast = 0.0
        repro_settings = unwrap_project(repro_settings)
        for k in repro_settings.keys():
            if k.lower() == "contrast":
                contrast = read_contrast(repro_settings[k], has_nix)

        # fall back to the stimulus settings only for those when the contrast is zero in repro settings, it was probably mutable in relacs
        if contrast < 0.0000001:
            stimulus_settings = unwrap_project(stimulus_settings)
            for k in stimulus_settings.keys():
                if k.lower() == "contrast":
                    contrast = read_contrast(stimulus_settings[k], has_nix)

        return contrast

    def __do_read_spike_data_from_nix(self, mt: "nix.MultiTag", stimulus: "Stimulus", repro: "RePro"):
        """Cut the spikes of a single stimulus presentation out of the recording.

        Returns:
            the spike times relative to stimulus onset, the contrast, the stimulus file,
            the delay and the duration of the trial.
        """
        r_settings = yaml.safe_load(repro.settings.replace("\t", ""))
        s_settings = yaml.safe_load(stimulus.settings.replace("\t", ""))
        stim_file = r_settings["file"]
        delay = 0.0
        # NOTE(review): parsed yaml keys usually do not end with a colon, so this test
        # presumably never matches and the delay stays zero - confirm the intended key.
        if "delay:" in map(str.lower, r_settings.keys()):
            delay = float(r_settings["delay"].split(":")[-1])
        extent = mt.extents[stimulus.index]
        start_time = stimulus.start_time - delay
        end_time = stimulus.start_time + extent
        duration = float(extent)
        contrast = self.__find_contrast(r_settings, s_settings, True)

        in_trial = (self.__all_spikes >= start_time) & (self.__all_spikes < end_time)
        spikes = self.__all_spikes[in_trial] - start_time - delay

        return spikes, contrast, stim_file, delay, duration

    def __read_spike_data_from_nix(self, repro: "RePro"):
        """Read all trials of a FileStimulus run from the nix file of the dataset.

        Falls back to the relacs directory if the nix file does not exist. Trials with
        five or fewer spikes are ignored.

        Returns:
            lists of spike trains, contrasts, stimulus files, delays, durations, and settings.
        """
        spikes = []
        contrasts = []
        stim_files = []
        delays = []
        durations = []
        settings = []
        repro_settings = repro.to_dict
        stimuli, _ = Stimulus.find(cell_id=repro.cell_id, repro_id=repro.id)
        if len(stimuli) == 0:
            return spikes, contrasts, stim_files, delays, durations, settings
        data_source = os.path.join(self.__dataset.data_source, self.__dataset.id + ".nix")
        if not os.path.exists(data_source):
            print("Data not found! Trying from directory")
            return self.__read_spike_data_from_directory(repro)
        f = nix.File.open(data_source, nix.FileMode.ReadOnly)
        try:
            b = f.blocks[0]
            self.__all_spikes = b.data_arrays["Spikes-1"][:]
            mt = None
            for s in tqdm(stimuli, desc="Loading data"):
                # consecutive stimuli usually share a multi tag, look it up only on change
                if not mt or mt.id != s.multi_tag_id:
                    mt = b.multi_tags[s.multi_tag_id]

                sp, c, stim, delay, duration = self.__do_read_spike_data_from_nix(mt, s, repro)
                if len(sp) > 5:
                    spikes.append(sp)
                    contrasts.append(c)
                    stim_files.append(stim)
                    delays.append(delay)
                    durations.append(duration)
                    settings.append({"stimulus": s.to_dict, "repro": repro_settings})
        finally:
            f.close()
        return spikes, contrasts, stim_files, delays, durations, settings

    def __read_spike_data_from_directory(self, repro: "RePro"):
        """Read all trials of a FileStimulus run from the stimspikes file of the relacs directory.

        Trials with fewer than five spikes are ignored.

        Returns:
            lists of spike trains, contrasts, stimulus files, delays, durations, and settings.
        """
        if self.__stimspikes is None:
            # reached when a dataset is flagged as nix but the nix file is missing
            self.__stimspikes = StimSpikesFile(self.__dataset.data_source)
        stimuli, _ = Stimulus.find(cell_id=repro.cell_id, repro_id=repro.id)
        spikes = []
        contrasts = []
        stim_files = []
        delays = []
        durations = []
        settings = []
        r_settings = yaml.safe_load(repro.settings.replace("\t", ""))
        r_settings = r_settings["project"] if "project" in r_settings.keys() else r_settings
        repro_settings = repro.to_dict
        for s in stimuli:
            s_settings = yaml.safe_load(s.settings.replace("\t", ""))
            s_settings = s_settings["project"] if "project" in s_settings.keys() else s_settings
            contrast = self.__find_contrast(r_settings, s_settings, False)
            dur, sp = self.__stimspikes.get(s.run, s.index)
            if sp is None or len(sp) < 5:
                continue

            if "duration" in s_settings.keys():
                # the value carries a two-character unit that is taken to be "ms"
                duration = float(s_settings["duration"][:-2]) / 1000
            else:
                duration = dur
            contrasts.append(contrast)
            delays.append(float(r_settings["before"][:-2]) / 1000)
            durations.append(duration)
            stim_files.append(s_settings["file"])
            spikes.append(sp)
            settings.append({"stimulus": s.to_dict, "repro": repro_settings})

        return spikes, contrasts, stim_files, delays, durations, settings

    def read_stimulus(self, index=0):
        """Not implemented yet; currently returns None."""

    @property
    def size(self):
        """The number of trials for which spike data is available."""
        return len(self.__spike_data)

    def __select(self, values, index, what):
        """Return all ``values`` for index -1, the entry at ``index`` if valid, raise IndexError otherwise."""
        if index == -1:
            return values
        if 0 <= index < self.size:
            return values[index]
        raise IndexError("FileStimulusData: index %i out of bounds for %s of size %i" % (index, what, self.size))

    def spikes(self, index=-1):
        """The spike times of the selected trial, or the list of all spike trains for index -1.

        Raises:
            IndexError: if index is neither -1 nor a valid trial index.
        """
        return self.__select(self.__spike_data, index, "spike data")

    def stimulus_settings(self, index=0):
        """The stimulus and repro settings (dict with keys "stimulus" and "repro") of the selected trial.

        Raises:
            IndexError: if index is not smaller than the number of trials.
        """
        if index >= self.size:
            raise IndexError("FileStimulusData: index %i is out of bounds for spike data of size %i" %(index, self.size))
        return self.__stimulus_settings[index]

    def contrast(self, index=-1):
        """The contrast of the selected trial, or the list of all contrasts for index -1.

        Raises:
            IndexError: if index is neither -1 nor a valid trial index.
        """
        return self.__select(self.__contrasts, index, "contrasts data")

    def stimulus_files(self, index=-1):
        """The stimulus file of the selected trial, or the list of all stimulus files for index -1.

        Raises:
            IndexError: if index is neither -1 nor a valid trial index.
        """
        return self.__select(self.__stimulus_files, index, "stimulus files")

    def trial_duration(self, index=-1):
        """The duration of the selected trial, or the list of all durations for index -1.

        Raises:
            IndexError: if index is neither -1 nor a valid trial index.
        """
        return self.__select(self.__durations, index, "durations")

    def __trial_time(self, index):
        """Build the time axis of trial ``index`` from its delay, its duration and the samplerate."""
        # NOTE(review): the axis starts at +delay; confirm whether -delay is intended.
        return np.arange(self.__delays[index], self.__durations[index], 1./self.__dataset.samplerate)

    def time_axis(self, index=-1):
        """
        Get the time axis of a single trial or a list of time-vectors for all trials.

        :param index: the index of the trial. Default of -1 indicates that all data should be returned.
        :return: Either a single vector representing time, or a list of such vectors, one for each trial.
        :raises IndexError: if index is neither -1 nor a valid trial index.
        """
        if 0 <= index < self.size:
            return self.__trial_time(index)
        if index == -1:
            return [self.__trial_time(i) for i in range(self.size)]
        raise IndexError("FileStimulusData: index %i out of bounds for time_axes of size %i" % (index, self.size))

    def rate(self, index=-1, kernel_width=0.005):
        """Firing rate of a single trial or of all trials, estimated by kernel convolution.

        Args:
            index (int, optional): the trial index, -1 to get the rates of all trials. Defaults to -1.
            kernel_width (float, optional): the width of the kernel passed to spike_times_to_rate. Defaults to 0.005.

        Raises:
            IndexError: if index is neither -1 nor a valid trial index.

        Returns:
            the time axis and the firing rate of the trial, or two lists of such vectors if index is -1.
        """
        if index == -1:
            time_axes = []
            rates = []
            for i in range(self.size):
                t = self.time_axis(i)
                time_axes.append(t)
                rates.append(spike_times_to_rate(self.spikes(i), t, kernel_width))
            return time_axes, rates
        if 0 <= index < self.size:
            t = self.time_axis(index)
            return t, spike_times_to_rate(self.spikes(index), t, kernel_width)
        raise IndexError("FileStimulusData: index %i out of bounds for time_axes of size %i" % (index, self.size))


if __name__ == "__main__":
    # Manual smoke test: load the FileStimulus data of one example dataset and
    # drop into an interactive IPython shell for inspection.
    # Swap in one of the commented-out dataset ids below to test another recording.

    # dataset = Dataset(dataset_id='2011-06-14-ag')
    # dataset = Dataset(dataset_id="2018-09-13-ac-invivo-1")
    dataset = Dataset(dataset_id='2013-04-18-ac-invivo-1')
    # NOTE: despite its name, fi_curve holds a FileStimulusData instance.
    fi_curve = FileStimulusData(dataset)
    embed()