From 7ab965fe377572918313597cbdafe68e9263b79e Mon Sep 17 00:00:00 2001 From: Jan Grewe Date: Mon, 7 Oct 2019 14:55:13 +0200 Subject: [PATCH] [reproclasses] new Boltzmann fit class --- fishbook/reproclasses.py | 92 +++++++++++++++++++++++++++++++--------- 1 file changed, 72 insertions(+), 20 deletions(-) diff --git a/fishbook/reproclasses.py b/fishbook/reproclasses.py index 3d155ae..83f0cb4 100644 --- a/fishbook/reproclasses.py +++ b/fishbook/reproclasses.py @@ -4,6 +4,7 @@ import nixio as nix from scipy.stats import circstd import os import subprocess +from tqdm import tqdm from IPython import embed @@ -38,9 +39,15 @@ def _unzip_if_needed(dataset, tracename='trace-1.raw'): subprocess.check_call(["gunzip", os.path.join(dataset, tracename + ".gz")]) +def gaussian_kernel(sigma, dt): + x = np.arange(-4. * sigma, 4. * sigma, dt) + y = np.exp(-0.5 * (x / sigma) ** 2) / np.sqrt(2. * np.pi) / sigma + return y + + class BaselineData: - def __init__(self, dataset:Dataset): + def __init__(self, dataset: Dataset): self.__spike_data = [] self.__eod_data = [] self.__eod_times = [] @@ -65,13 +72,13 @@ class BaselineData: # fixme implement me! pass - def __read_spike_data(self, r:RePro): + def __read_spike_data(self, r: RePro): if self.__dataset.has_nix: return self.__read_spike_data_from_nix(r) else: return self.__read_spike_data_from_directory(r) - def __read_eod_data(self, r:RePro, duration): + def __read_eod_data(self, r: RePro, duration): if self.__dataset.has_nix: return self.__read_eod_data_from_nix(r, duration) else: @@ -152,10 +159,10 @@ class BaselineData: subjects = self.__dataset.subjects return subjects if len(subjects) > 1 else subjects[0] - def spikes(self, index:int=0): + def spikes(self, index: int=0): return self.__spike_data[index] if len(self.__spike_data) >= index else None - def eod(self, index:int=0): + def eod(self, index: int=0): eod = self.__eod_data[index] if len(self.__eod_data) >= index else None time = np.arange(len(eod)) / self.__dataset.samplerate return eod, time @@ -200,7 +207,7 @@ class BaselineData: def __str__(self): str = "Baseline data of cell %s " % self.__cell.id - def __read_eod_data_from_nix(self, r:RePro, duration)->np.ndarray: + def __read_eod_data_from_nix(self, r: RePro, duration) -> np.ndarray: data_source = os.path.join(self.__dataset.data_source, self.__dataset.id + ".nix") if not os.path.exists(data_source): print("Data not found! Trying from directory") @@ -217,14 +224,14 @@ class BaselineData: f.close() return data - def __read_eod_data_from_directory(self, r:RePro, duration)->np.ndarray: + def __read_eod_data_from_directory(self, r: RePro, duration) -> np.ndarray: sr = self.__dataset.samplerate _unzip_if_needed(self.__dataset.data_source, "trace-2.raw") eod = np.fromfile(self.__dataset.data_source + "/trace-2.raw", np.float32) eod = eod[:int(duration * sr)] return eod - def __read_spike_data_from_nix(self, r:RePro)->np.ndarray: + def __read_spike_data_from_nix(self, r: RePro) -> np.ndarray: data_source = os.path.join(self.__dataset.data_source, self.__dataset.id + ".nix") if not os.path.exists(data_source): print("Data not found! Trying from directory") @@ -244,8 +251,7 @@ class BaselineData: data = None return data - - def __read_spike_data_from_directory(self, r)->np.ndarray: + def __read_spike_data_from_directory(self, r) -> np.ndarray: data = [] data_source = os.path.join(self.__dataset.data_source, "basespikes1.dat") if os.path.exists(data_source): @@ -264,7 +270,8 @@ class BaselineData: return None return np.asarray(data) - def __do_read(self, f)->np.ndarray: + @staticmethod + def __do_read(f) -> np.ndarray: data = [] f.readline() unit = f.readline().strip("#").strip() @@ -277,7 +284,7 @@ class BaselineData: class FIData: - def __init__(self, dataset:Dataset): + def __init__(self, dataset: Dataset): self.__spike_data = [] self.__contrasts = [] self.__eod_data = [] @@ -302,7 +309,7 @@ class FIData: else: continue - def __read_spike_data(self, repro:RePro): + def __read_spike_data(self, repro: RePro): """ :param repro: @@ -315,7 +322,7 @@ class FIData: pass return None, None, None, None - def __do_read_spike_data_from_nix(self, mtag:nix.pycore.MultiTag, stimulus:Stimulus, repro: RePro): + def __do_read_spike_data_from_nix(self, mtag: nix.pycore.MultiTag, stimulus: Stimulus, repro: RePro): r_settings = repro.settings.split("\n") s_settings = stimulus.settings.split("\n") delay = 0.0 @@ -344,7 +351,7 @@ class FIData: return spikes, local_eod, time, contrast - def __read_spikes_from_nix(self, repro:RePro): + def __read_spikes_from_nix(self, repro: RePro): spikes = [] eods = [] time = [] @@ -360,7 +367,8 @@ class FIData: b = f.blocks[0] mt = None - for s in stimuli: + for i in tqdm(range(len(stimuli))): + s = stimuli[i] if not mt or mt.id != s.multi_tag_id: mt = b.multi_tags[s.multi_tag_id] sp, eod, t, c = self.__do_read_spike_data_from_nix(mt, s, repro) @@ -370,30 +378,74 @@ class FIData: time.append(t) contrasts.append(c) f.close() - return spikes, contrasts, eods, time + return spikes, contrasts, eods, time @property def size(self): return len(self.__spike_data) def spikes(self, index=-1): - if 0 < index < self.size: + if 0 <= index < self.size: return self.__spike_data[index] else: return self.__spike_data def eod(self, index=-1): - if 0 < index < self.size: + if 0 <= index < self.size: return self.__eod_times[index], self.__eod_data[index] else: return self.__eod_times, self.__eod_data def contrast(self, index=-1): - if 0 < index < self.size: + if 0 <= index < self.size: return self.__contrasts[index] else: return self.__contrasts + def time_axis(self, index=-1): + if 0 <= index < self.size: + return self.__eod_times[index] + else: + return self.__eod_times + + def rate(self, index=0, kernel_width=0.005): + """ + Returns the firing rate for a single trial. + + :param index: The index of the trial. 0 <= index < size + :param kernel_width: The width of the gaussian kernel in seconds + :return: tuple of time and rate + """ + t = self.time_axis(index) + dt = np.mean(np.diff(t)) + sp = self.spikes(index) + binary = np.zeros(t.shape) + spike_indices = (sp / dt).astype(int) + binary[spike_indices[(spike_indices >= 0) & (spike_indices < len(binary))]] = 1 + g = gaussian_kernel(kernel_width, dt) + rate = np.convolve(binary, g, mode='same') + return t, rate + + def boltzmann_fit(self, start_time=0.01, end_time=0.05, kernel_width=0.005): + """ + + :param start_time: + :param end_time: + :param kernel_width: + :return: object of type BoltzmannFit + """ + contrasts = np.zeros(self.size) + rates = np.zeros(self.size) + for i in range(self.size): + contrasts[i] = np.round(self.contrast(i)) + t, r = self.rate(i, kernel_width) + rates[i] = np.mean(r[(t >= start_time) & (t < end_time)]) + + boltzmann_fit = BoltzmannFit(contrasts, rates) + return boltzmann_fit + + + if __name__ == "__main__": #dataset = Dataset(dataset_id='2011-06-14-ag') dataset = Dataset(dataset_id="2018-01-19-ac-invivo-1")