diff --git a/fishbook/frontend/relacs_classes.py b/fishbook/frontend/relacs_classes.py index 505b7c5..2c2dc6f 100644 --- a/fishbook/frontend/relacs_classes.py +++ b/fishbook/frontend/relacs_classes.py @@ -1,5 +1,5 @@ from .frontend_classes import Dataset, RePro, Stimulus -from .util import BoltzmannFit, unzip_if_needed, gaussian_kernel, zero_crossings, spike_times_to_rate +from .util import BoltzmannFit, unzip_if_needed, gaussian_kernel, zero_crossings, spike_times_to_rate, StimSpikesFile import numpy as np import nixio as nix from scipy.stats import circstd @@ -638,18 +638,23 @@ class FileStimulusData: self.__repros = None self.__cell = dataset.cells[0] # Beware: Assumption that there is only a single cell self.__all_spikes = None + self.__stimspikes = None self._get_data() + def _get_data(self): if not self.__dataset: return + print("Find FileStimulus repro runs in this dataset!") self.__repros, _ = RePro.find("FileStimulus", cell_id=self.__cell.id) + if not self.__dataset.has_nix: + self.__stimspikes = StimSpikesFile(self.__dataset.data_source) for r in self.__repros: if self.__dataset.has_nix: spikes, contrasts, stims, delays, durations = self.__read_spike_data_from_nix(r) else: - spikes, contrasts, stims, delays, durations = self.__read_spike_data_from_directory(r) # TODO - if spikes is not None and len(spikes) > 1: + spikes, contrasts, stims, delays, durations = self.__read_spike_data_from_directory(r) + if spikes is not None and len(spikes) > 0: self.__spike_data.extend(spikes) self.__contrasts.extend(contrasts) self.__stimuli.extend(stims) @@ -659,19 +664,35 @@ class FileStimulusData: continue def __find_contrast(self, repro_settings, stimulus_settings, has_nix=True): + def read_contrast(str, has_nix): + if has_nix: + return float(str.split("+")[0]) * 100 + else: + return float(str[:-1]) + contrast = 0.0 + if "project" in repro_settings.keys(): + repro_settings = repro_settings["project"] + elif "Project" in repro_settings.keys(): + repro_settings = repro_settings["Project"] for k in repro_settings.keys(): if k.lower() == "contrast": - contrast = float(repro_settings[k].split("+")[0]) * (100 if has_nix else 1) - - if contrast < 0.00001: + contrast = read_contrast(repro_settings[k], has_nix) + + # fall back to the stimulus settings only for those when the contrast is zero in repro settings, it was probably mutable in relacs + if contrast < 0.0000001: + if "project" in repro_settings.keys(): + stimulus_settings = stimulus_settings["project"] + elif "Project" in stimulus_settings.keys(): + stimulus_settings = stimulus_settings["Project"] for k in stimulus_settings.keys(): if k.lower() == "contrast": - contrast = float(stimulus_settings[k].split("+")[0]) * (100 if has_nix else 1) + contrast = read_contrast(stimulus_settings[k], has_nix) + return contrast def __do_read_spike_data_from_nix(self, mt: nix.pycore.MultiTag, stimulus: Stimulus, repro: RePro): - spikes = None + spikes = np.empty(0) contrast = 0.0 r_settings = yaml.safe_load(repro.settings.replace("\t", "")) @@ -719,22 +740,38 @@ class FileStimulusData: mt = b.multi_tags[s.multi_tag_id] sp, c, stim, delay, duration = self.__do_read_spike_data_from_nix(mt, s, repro) - spikes.append(sp) - contrasts.append(c) - stim_files.append(stim) - delays.append(delay) - durations.append(duration) + if len(sp) > 0: + spikes.append(sp) + contrasts.append(c) + stim_files.append(stim) + delays.append(delay) + durations.append(duration) f.close() return spikes, contrasts, stim_files, delays, contrasts def __read_spike_data_from_directory(self, repro: RePro): - print("not yet, my friend!") + stimuli, _ = Stimulus.find(cell_id=repro.cell_id, repro_id=repro.id) spikes = [] contrasts = [] stim_files = [] delays = [] durations = [] - + r_settings = yaml.safe_load(repro.settings.replace("\t", "")) + r_settings = r_settings["project"] if "project" in r_settings.keys() else r_settings + for s in stimuli: + s_settings = yaml.safe_load(s.settings.replace("\t", "")) + s_settings = s_settings["project"] if "project" in s_settings.keys() else s_settings + contrast = self.__find_contrast(r_settings, s_settings, False) + duration = float(s_settings["duration"][:-2]) / 1000 + sp = self.__stimspikes.get(s.run, s.index) + if not sp or len(sp) < 1: + continue + contrasts.append(contrast) + delays.append(float(r_settings["before"][:-2]) / 1000) + durations.append(duration) + stim_files.append(s_settings["file"]) + spikes.append(sp) + return spikes, contrasts, stim_files, delays, durations def read_stimulus(self, index=0): @@ -768,6 +805,14 @@ class FileStimulusData: else: raise IndexError("FileStimulusData: index %i out of bounds for contrasts data of size %i" % (index, self.size)) + def trial_duration(self, index=-1): + if index == -1: + return self.__durations + elif index >=0 and index < self.size: + return self.__durations[index] + else: + raise IndexError("FileStimulusData: index %i out of bounds for contrasts data of size %i" % (index, self.size)) + def time_axis(self, index=-1): """ Get the time axis of a single trial or a list of time-vectors for all trials. @@ -790,6 +835,18 @@ class FileStimulusData: raise IndexError("FileStimulusData: index %i out of bounds for time_axes of size %i" % (index, self.size)) def rate(self, index=-1, kernel_width=0.005): + """[summary] + + Args: + index (int, optional): [description]. Defaults to -1. + kernel_width (float, optional): [description]. Defaults to 0.005. + + Raises: + IndexError: [description] + + Returns: + [type]: [description] + """ if index == -1: time_axes = [] rates = [] @@ -808,11 +865,11 @@ class FileStimulusData: else: raise IndexError("FileStimulusData: index %i out of bounds for time_axes of size %i" % (index, self.size)) - + if __name__ == "__main__": # dataset = Dataset(dataset_id='2011-06-14-ag') - dataset = Dataset(dataset_id="2018-09-13-ac-invivo-1") - # dataset = Dataset(dataset_id='2013-04-18-ac') + # dataset = Dataset(dataset_id="2018-09-13-ac-invivo-1") + dataset = Dataset(dataset_id='2013-04-18-ac-invivo-1') fi_curve = FileStimulusData(dataset) embed() diff --git a/fishbook/frontend/util.py b/fishbook/frontend/util.py index bd2fa36..8e33e52 100644 --- a/fishbook/frontend/util.py +++ b/fishbook/frontend/util.py @@ -1,5 +1,8 @@ import numpy as np +import os +import subprocess from scipy.optimize import curve_fit +from IPython import embed def spike_times_to_rate(spike_times, time_axis, kernel_width=0.005): """Convert spike times to a rate by means of kernel convolution. A Gaussian kernel of the desired width is used. @@ -121,7 +124,7 @@ class BoltzmannFit: """ The underlying Boltzmann function. .. math:: - f(x) = y_max / \exp{-slope*(x-inflection} + f(x) = y_max / exp{-slope*(x-inflection} :param x: The x values. :param y_max: The maximum value. @@ -181,3 +184,44 @@ class BoltzmannFit: xvalues = self.__x_sorted return self.boltzmann(xvalues, *self.__fit_params) +class StimSpikesFile: + + def __init__(self, filename): + if "stimspikes-1.dat" not in filename: + filename += os.path.join(os.path.sep, "stimspikes1.dat") + if not os.path.exists(filename): + raise ValueError("StimSpikesFile: the given file %s does not exist!" % filename) + self._filename = filename + self._data_map = self.__parse_file(filename) + + def __parse_file(self, filename): + with open(filename, 'r') as f: + lines = f.readlines() + + index_map = {} + trial_data = [] + index = 0 + trial = 0 + + for l in lines: + l = l.strip() + if "index:" in l: + if len(trial_data) > 0: + index_map[(index, trial)] = trial_data + trial_data = [] + index = int(l[1:].strip().split(":")[-1]) + if "trial:" in l: + if len(trial_data) > 0: + index_map[(index, trial)] = trial_data + trial_data = [] + trial = int(l[1:].strip().split(":")[-1]) + if len(l) > 0 and "#" not in l: + trial_data.append(float(l)/1000) + + return index_map + + def get(self, run_index, trial_index): + if tuple([run_index, trial_index]) not in self._data_map.keys(): + print("Data not found for run %i and trial %i:" % (run_index, trial_index)) + return None + return self._data_map[(run_index, trial_index)] \ No newline at end of file diff --git a/setup.py b/setup.py index c36518d..96945ec 100644 --- a/setup.py +++ b/setup.py @@ -6,7 +6,7 @@ assert(sys.version_info >= (3, 0)) __version__ = 0.1 # exec(open('fishbook/version.py').read()) -requires = ['datajoint', 'nixio', 'numpy', 'PyYAML', 'scipy', 'backports-datetime-fromisoformat'] +requires = ['datajoint', 'nixio', 'numpy', 'PyYAML', 'scipy', 'tqdm', 'yaml', 'backports-datetime-fromisoformat'] print(find_packages(exclude=['contrib', 'doc', 'tests*'])) setup(name='fishbook',