from fishbook.fishbook import Dataset, RePro, Stimulus
import numpy as np
import nixio as nix
from scipy.stats import circstd
import os
import subprocess
from tqdm import tqdm
from IPython import embed


def _zero_crossings(x, t, interpolate=False):
    """Detect the rising zero-crossings of signal x; optionally refine them by linear interpolation."""
    dt = t[1] - t[0]
    x_shift = np.roll(x, 1)
    x_shift[0] = 0.0
    # rising crossings: the previous sample is below zero, the current one at or above
    xings = np.where((x >= 0) & (x_shift < 0))[0]
    crossings = np.zeros(len(xings))
    if interpolate:
        for i, tf in enumerate(xings):
            if x[tf] > 0.001:
                # interpolate between the preceding sample and the current one
                m = (x[tf] - x[tf - 1]) / dt
                crossings[i] = t[tf] - x[tf] / m
            else:
                # the sample is numerically at zero already
                crossings[i] = t[tf]
    else:
        crossings = t[xings]
    return crossings


def _unzip_if_needed(dataset, tracename='trace-1.raw'):
    """Gunzip the given raw trace in the dataset folder if only the zipped version exists."""
    file_name = os.path.join(dataset, tracename)
    if os.path.exists(file_name):
        return
    if os.path.exists(file_name + '.gz'):
        print("\tunzip: %s" % tracename)
        subprocess.check_call(["gunzip", os.path.join(dataset, tracename + ".gz")])


def gaussian_kernel(sigma, dt):
    """Normalized Gaussian kernel with standard deviation sigma, sampled at dt, spanning +-4 sigma."""
    x = np.arange(-4. * sigma, 4. * sigma, dt)
    y = np.exp(-0.5 * (x / sigma) ** 2) / np.sqrt(2. * np.pi) / sigma
    return y
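# Hypothetical usage sketch (kept as a comment so importing this module has no
# side effects): gaussian_kernel is meant to be convolved with a binary spike
# train to obtain a smooth firing rate, as done in FIData.rate below. The
# sampling rate and spike times here are made up for illustration.
#
#   dt = 1. / 20000.                                    # assumed 20 kHz sampling
#   kernel = gaussian_kernel(0.005, dt)                 # 5 ms standard deviation
#   binary = np.zeros(20000)                            # one second of samples
#   binary[(np.array([0.1, 0.25, 0.7]) / dt).astype(int)] = 1
#   rate = np.convolve(binary, kernel, mode='same')     # firing rate in Hz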
class BaselineData:
    def __init__(self, dataset: Dataset):
        self.__spike_data = []
        self.__eod_data = []
        self.__eod_times = []
        self.__dataset = dataset
        self.__repros = None
        self.__cell = dataset.cells[0]  # Beware: Assumption that there is only a single cell
        self._get_data()

    def _get_data(self):
        if not self.__dataset:
            return
        self.__repros = RePro.find("BaselineActivity", cell_id=self.__cell.id)
        for r in self.__repros:
            sd = self.__read_spike_data(r)
            if sd is None or len(sd) < 2:
                continue
            self.__spike_data.append(sd)
            self.__eod_data.append(self.__read_eod_data(r, self.__spike_data[-1][-1]))

    def valid(self):
        # fixme implement me!
        pass

    def __read_spike_data(self, r: RePro):
        if self.__dataset.has_nix:
            return self.__read_spike_data_from_nix(r)
        return self.__read_spike_data_from_directory(r)

    def __read_eod_data(self, r: RePro, duration):
        if self.__dataset.has_nix:
            return self.__read_eod_data_from_nix(r, duration)
        return self.__read_eod_data_from_directory(r, duration)

    def __get_serial_correlation(self, times, max_lags=50):
        if times is None or len(times) < max_lags:
            return None
        isis = np.diff(times)
        unbiased = isis - np.mean(isis)
        norm = np.sum(unbiased ** 2)
        a_corr = np.correlate(unbiased, unbiased, "same") / norm
        a_corr = a_corr[int(len(a_corr) / 2):]
        return a_corr[:max_lags]

    def serial_correlation(self, max_lags=50):
        """
        Return the serial correlation of the interspike intervals for each stored spike train.

        @param max_lags: The number of lags to take into account
        @return: the serial correlation as a function of the lag
        """
        scs = []
        for sd in self.__spike_data:
            if sd is None or len(sd) < 100:
                continue
            corr = self.__get_serial_correlation(sd, max_lags=max_lags)
            if corr is not None:
                scs.append(corr)
        return scs

    def circular_std(self):
        cstds = []
        for i in range(self.size):
            phases = self.__spike_phases(index=i)
            cstds.append(circstd(phases))
        return cstds

    def eod_frequency(self):
        eodfs = []
        for i in range(self.size):
            eod, time = self.eod(i)
            xings = _zero_crossings(eod, time, interpolate=False)
            eodfs.append(len(xings) / xings[-1])
        return 0.0 if len(eodfs) < 1 else np.mean(eodfs)

    def __spike_phases(self, index=0):
        # fixme buffer this stuff
        etimes = self.eod_times(index=index)
        eod_period = np.mean(np.diff(etimes))
        phases = np.zeros(len(self.spikes(index)))
        for i, st in enumerate(self.spikes(index)):
            last_eod_index = np.where(etimes <= st)[0]
            if len(last_eod_index) == 0:
                continue
            phases[i] = (st - etimes[last_eod_index[-1]]) / eod_period * 2 * np.pi
        return phases

    def eod_times(self, index=0, interpolate=True):
        if index >= self.size:
            return None
        if len(self.__eod_times) < len(self.__eod_data):
            # no cached times for this trace, detect the crossings in the EOD trace
            eod, time = self.eod(index)
            etimes = _zero_crossings(eod, time, interpolate=interpolate)
        else:
            etimes = self.__eod_times[index]
        return etimes

    @property
    def dataset(self):
        return self.__dataset

    @property
    def cell(self):
        cells = self.__dataset.cells
        return cells if len(cells) > 1 else cells[0]

    @property
    def subject(self):
        subjects = self.__dataset.subjects
        return subjects if len(subjects) > 1 else subjects[0]

    def spikes(self, index: int = 0):
        return self.__spike_data[index] if index < len(self.__spike_data) else None

    def eod(self, index: int = 0):
        if index >= len(self.__eod_data):
            return None, None
        eod = self.__eod_data[index]
        time = np.arange(len(eod)) / self.__dataset.samplerate
        return eod, time

    @property
    def burst_index(self):
        bi = []
        for i, sd in enumerate(self.__spike_data):
            if len(sd) < 2:
                continue
            et = self.eod_times(index=i)
            eod_period = np.mean(np.diff(et))
            isis = np.diff(sd)
            # fraction of intervals shorter than 1.5 EOD periods
            bi.append(np.sum(isis < (1.5 * eod_period)) / len(isis))
        return bi

    @property
    def coefficient_of_variation(self):
        cvs = []
        for d in self.__spike_data:
            isis = np.diff(d)
            cvs.append(np.std(isis) / np.mean(isis))
        return cvs

    @property
    def vector_strength(self):
        vss = []
        spike_phases = []
        for i, sd in enumerate(self.__spike_data):
            phases = self.__spike_phases(i)
            ms_sin_alpha = np.mean(np.sin(phases)) ** 2
            ms_cos_alpha = np.mean(np.cos(phases)) ** 2
            vs = np.sqrt(ms_cos_alpha + ms_sin_alpha)
            vss.append(vs)
            spike_phases.append(phases)
        return vss, spike_phases

    @property
    def size(self):
        return len(self.__spike_data)

    def __str__(self):
        return "Baseline data of cell %s" % self.__cell.id

    def __read_eod_data_from_nix(self, r: RePro, duration) -> np.ndarray:
        data_source = os.path.join(self.__dataset.data_source, self.__dataset.id + ".nix")
        if not os.path.exists(data_source):
            print("Data not found! Trying from directory")
            return self.__read_eod_data_from_directory(r, duration)
        f = nix.File.open(data_source, nix.FileMode.ReadOnly)
        b = f.blocks[0]
        t = b.tags[r.id]
        if not t:
            print("Tag not found!")
        try:
            data = t.retrieve_data("EOD")[:]
        except Exception:
            data = np.empty(0)
        f.close()
        return data

    def __read_eod_data_from_directory(self, r: RePro, duration) -> np.ndarray:
        sr = self.__dataset.samplerate
        _unzip_if_needed(self.__dataset.data_source, "trace-2.raw")
        eod = np.fromfile(self.__dataset.data_source + "/trace-2.raw", np.float32)
        eod = eod[:int(duration * sr)]
        return eod

    def __read_spike_data_from_nix(self, r: RePro) -> np.ndarray:
        data_source = os.path.join(self.__dataset.data_source, self.__dataset.id + ".nix")
        if not os.path.exists(data_source):
            print("Data not found! Trying from directory")
            return self.__read_spike_data_from_directory(r)
        f = nix.File.open(data_source, nix.FileMode.ReadOnly)
        b = f.blocks[0]
        t = b.tags[r.id]
        if not t:
            print("Tag not found!")
        try:
            data = t.retrieve_data("Spikes-1")[:]
        except Exception:
            data = None
        f.close()
        if data is None or len(data) < 100:
            data = None
        return data

    def __read_spike_data_from_directory(self, r) -> np.ndarray:
        data = []
        data_source = os.path.join(self.__dataset.data_source, "basespikes1.dat")
        if os.path.exists(data_source):
            found_run = False
            with open(data_source, 'r') as f:
                l = f.readline()
                while l:
                    if "index" in l:
                        index = int(l.strip("#").strip().split(":")[-1])
                        found_run = index == r.run
                    if l.startswith("#Key") and found_run:
                        data = self.__do_read(f)
                        break
                    l = f.readline()
        if len(data) < 100:
            return None
        return np.asarray(data)

    @staticmethod
    def __do_read(f) -> np.ndarray:
        data = []
        f.readline()
        unit = f.readline().strip("#").strip()
        scale = 0.001 if unit == "ms" else 1
        l = f.readline()
        while l and "#" not in l and len(l.strip()) > 0:
            data.append(float(l.strip()) * scale)
            l = f.readline()
        return np.asarray(data)
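# A minimal usage sketch for BaselineData, assuming a dataset that contains at
# least one BaselineActivity run (the dataset id is taken from the demo at the
# bottom of this file):
#
#   baseline = BaselineData(Dataset(dataset_id="2018-01-19-ac-invivo-1"))
#   cvs = baseline.coefficient_of_variation         # one CV per baseline run
#   vs, phases = baseline.vector_strength           # phase locking to the EOD
#   scs = baseline.serial_correlation(max_lags=50)  # ISI serial correlations
#   eodf = baseline.eod_frequency()                 # mean EOD frequency in Hz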
class FIData:
    def __init__(self, dataset: Dataset):
        self.__spike_data = []
        self.__contrasts = []
        self.__eod_data = []
        self.__eod_times = []
        self.__dataset = dataset
        self.__repros = None
        self.__cell = dataset.cells[0]  # Beware: Assumption that there is only a single cell
        self._get_data()

    def _get_data(self):
        if not self.__dataset:
            return
        self.__repros = RePro.find("FICurve", cell_id=self.__cell.id)
        for r in self.__repros:
            sd, c, eods, time = self.__read_spike_data(r)
            if sd is not None and len(sd) > 1:
                self.__spike_data.extend(sd)
                self.__eod_data.extend(eods)
                self.__contrasts.extend(c)
                self.__eod_times.extend(time)

    def __read_spike_data(self, repro: RePro):
        """
        :param repro: the RePro entry of the FICurve run
        :return: spike times, contrasts, local EOD traces, and time axes of all trials
        """
        if self.__dataset.has_nix:
            return self.__read_spikes_from_nix(repro)
        print("Sorry, so far only from nix!!!")
        return None, None, None, None

    def __do_read_spike_data_from_nix(self, mtag: nix.pycore.MultiTag, stimulus: Stimulus, repro: RePro):
        r_settings = repro.settings.split("\n")
        s_settings = stimulus.settings.split("\n")
        delay = 0.0
        contrast = 0.0
        for s in r_settings:
            if "delay:" in s:
                delay = float(s.split(":")[-1])
                break
        for s in s_settings:
            if "Contrast:" in s and "PreContrast" not in s and "\t\t" not in s and "+-" not in s:
                contrast = float(s.split(":")[-1])
                break
        start_time = stimulus.start_time - delay
        end_time = stimulus.start_time + stimulus.duration
        spikes_da = mtag.references["Spikes-1"]
        eod_da = mtag.references["LocalEOD-1"]
        start_index_spikes = spikes_da.dimensions[0].index_of(start_time)
        end_index_spikes = spikes_da.dimensions[0].index_of(end_time)
        start_index_eod = eod_da.dimensions[0].index_of(start_time)
        end_index_eod = eod_da.dimensions[0].index_of(end_time)
        local_eod = eod_da[start_index_eod:end_index_eod]
        spikes = spikes_da[start_index_spikes:end_index_spikes] - start_time
        time = np.asarray(eod_da.dimensions[0].axis(len(local_eod))) - delay
        return spikes, local_eod, time, contrast

    def __read_spikes_from_nix(self, repro: RePro):
        spikes = []
        eods = []
        time = []
        contrasts = []
        stimuli = Stimulus.find(cell_id=repro.cell_id, repro_id=repro.id)
        if len(stimuli) == 0:
            return spikes, contrasts, eods, time
        data_source = os.path.join(self.__dataset.data_source, self.__dataset.id + ".nix")
        if not os.path.exists(data_source):
            # note: there is no directory fallback for FI data, yet
            print("Data not found!")
            return spikes, contrasts, eods, time
        f = nix.File.open(data_source, nix.FileMode.ReadOnly)
        b = f.blocks[0]
        mt = None
        for i in tqdm(range(len(stimuli))):
            s = stimuli[i]
            if not mt or mt.id != s.multi_tag_id:
                mt = b.multi_tags[s.multi_tag_id]
            sp, eod, t, c = self.__do_read_spike_data_from_nix(mt, s, repro)
            spikes.append(sp)
            eods.append(eod)
            time.append(t)
            contrasts.append(c)
        f.close()
        return spikes, contrasts, eods, time

    @property
    def size(self):
        return len(self.__spike_data)

    def spikes(self, index=-1):
        if 0 <= index < self.size:
            return self.__spike_data[index]
        return self.__spike_data

    def eod(self, index=-1):
        if 0 <= index < self.size:
            return self.__eod_times[index], self.__eod_data[index]
        return self.__eod_times, self.__eod_data

    def contrast(self, index=-1):
        if 0 <= index < self.size:
            return self.__contrasts[index]
        return self.__contrasts

    def time_axis(self, index=-1):
        if 0 <= index < self.size:
            return self.__eod_times[index]
        return self.__eod_times

    def rate(self, index=0, kernel_width=0.005):
        """
        Returns the firing rate for a single trial.

        :param index: The index of the trial. 0 <= index < size
        :param kernel_width: The width of the Gaussian kernel in seconds
        :return: tuple of time and rate
        """
        t = self.time_axis(index)
        dt = np.mean(np.diff(t))
        sp = self.spikes(index)
        binary = np.zeros(t.shape)
        spike_indices = (sp / dt).astype(int)
        binary[spike_indices[(spike_indices >= 0) & (spike_indices < len(binary))]] = 1
        g = gaussian_kernel(kernel_width, dt)
        rate = np.convolve(binary, g, mode='same')
        return t, rate

    def boltzmann_fit(self, start_time=0.01, end_time=0.05, kernel_width=0.005):
        """
        Fit a Boltzmann function to the contrast/firing-rate relation.

        :param start_time: begin of the analysis window, relative to stimulus onset, in seconds
        :param end_time: end of the analysis window in seconds
        :param kernel_width: width of the Gaussian kernel used for the rate estimation, in seconds
        :return: object of type BoltzmannFit
        """
        contrasts = np.zeros(self.size)
        rates = np.zeros(self.size)
        for i in range(self.size):
            contrasts[i] = np.round(self.contrast(i))
            t, r = self.rate(i, kernel_width)
            rates[i] = np.mean(r[(t >= start_time) & (t < end_time)])
        return BoltzmannFit(contrasts, rates)
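# A minimal usage sketch for FIData. BoltzmannFit is not defined in this file
# and is assumed to live elsewhere in the package, taking the contrasts and the
# corresponding onset rates:
#
#   fi = FIData(dataset)
#   t, r = fi.rate(index=0, kernel_width=0.005)      # smoothed single-trial rate
#   fit = fi.boltzmann_fit(start_time=0.01, end_time=0.05)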
if __name__ == "__main__":
    # dataset = Dataset(dataset_id='2011-06-14-ag')
    dataset = Dataset(dataset_id="2018-01-19-ac-invivo-1")
    # dataset = Dataset(dataset_id='2018-11-09-aa-invivo-1')
    baseline = BaselineData(dataset)
    embed()