from fishbook.fishbook import Dataset, RePro
import numpy as np
import nixio as nix
from scipy.stats import circstd
import os
import subprocess
from IPython import embed


def _zero_crossings(x, t, interpolate=False):
    """Find the upward zero crossings in x; optionally interpolate the crossing times between samples."""
    dt = t[1] - t[0]
    x_shift = np.roll(x, 1)
    x_shift[0] = 0.0
    xings = np.where((x >= 0) & (x_shift < 0))[0]
    crossings = np.zeros(len(xings))
    if interpolate:
        for i, tf in enumerate(xings):
            if x[tf] > 0.001:
                # linearly interpolate backwards to the actual crossing time
                m = (x[tf] - x[tf - 1]) / dt
                crossings[i] = t[tf] - x[tf] / m
            else:
                # sample is close enough to zero, take its time directly;
                # x[tf] >= 0 by construction, so no forward interpolation is needed
                crossings[i] = t[tf]
    else:
        crossings = t[xings]
    return crossings


def _unzip_if_needed(dataset, tracename='trace-1.raw'):
    """Gunzip the given raw trace in the dataset folder if only the .gz version exists."""
    file_name = os.path.join(dataset, tracename)
    if os.path.exists(file_name):
        return
    if os.path.exists(file_name + '.gz'):
        print("\tunzip: %s" % tracename)
        subprocess.check_call(["gunzip", os.path.join(dataset, tracename + ".gz")])


class BaselineData:

    def __init__(self, dataset: Dataset):
        self.__spike_data = []
        self.__eod_data = []
        self.__eod_times = []
        self.__dataset = dataset
        self.__repros = None
        self.__cell = dataset.cells[0]  # beware: assumes that there is only a single cell
        self._get_data()

    def _get_data(self):
        if not self.__dataset:
            return
        self.__repros = RePro.find("BaselineActivity", cell_id=self.__cell.id)
        for r in self.__repros:
            sd = self.__read_spike_data(r)
            if sd is not None and len(sd) > 1:
                self.__spike_data.append(sd)
            else:
                continue
            # the last spike time defines the duration of the EOD trace to read
            self.__eod_data.append(self.__read_eod_data(r, self.__spike_data[-1][-1]))

    def valid(self):
        # fixme implement me!
        pass

    def __read_spike_data(self, r: RePro):
        if self.__dataset.has_nix:
            return self.__read_spike_data_from_nix(r)
        else:
            return self.__read_spike_data_from_directory(r)

    def __read_eod_data(self, r: RePro, duration):
        if self.__dataset.has_nix:
            return self.__read_eod_data_from_nix(r, duration)
        else:
            return self.__read_eod_data_from_directory(r, duration)

    def __get_serial_correlation(self, times, max_lags=50):
        if times is None or len(times) < max_lags:
            return None
        isis = np.diff(times)
        unbiased = isis - np.mean(isis)
        norm = np.sum(unbiased ** 2)
        a_corr = np.correlate(unbiased, unbiased, "same") / norm
        a_corr = a_corr[int(len(a_corr) / 2):]
        return a_corr[:max_lags]
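    # Note on __get_serial_correlation: the ISI sequence is mean-subtracted and
    # autocorrelated with itself ("same" mode keeps the central lags); dividing by
    # the summed squared deviations normalizes the lag-0 value to 1.0, so the
    # entries at lags 1..max_lags-1 quantify how strongly successive inter-spike
    # intervals depend on each other.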
    def serial_correlation(self, max_lags=50):
        """Return the serial correlation of the inter-spike intervals for each stored baseline spike train.

        @param max_lags: The number of lags to take into account
        @return: the serial correlation as a function of the lag
        """
        scs = []
        for sd in self.__spike_data:
            if sd is None or len(sd) < 100:
                continue
            corr = self.__get_serial_correlation(sd, max_lags=max_lags)
            if corr is not None:
                scs.append(corr)
        return scs

    def circular_std(self):
        cstds = []
        for i in range(self.size):
            phases = self.__spike_phases(index=i)
            cstds.append(circstd(phases))
        return cstds

    def eod_frequency(self):
        eodfs = []
        for i in range(self.size):
            eod, time = self.eod(i)
            xings = _zero_crossings(eod, time, interpolate=False)
            if len(xings) > 0 and xings[-1] > 0:
                # number of EOD cycles divided by the time of the last crossing
                eodfs.append(len(xings) / xings[-1])
        return 0.0 if len(eodfs) < 1 else np.mean(eodfs)

    def __spike_phases(self, index=0):  # fixme buffer this stuff
        etimes = self.eod_times(index=index)
        eod_period = np.mean(np.diff(etimes))
        phases = np.zeros(len(self.spikes(index)))
        for i, st in enumerate(self.spikes(index)):
            last_eod_index = np.where(etimes <= st)[0]
            if len(last_eod_index) == 0:
                continue
            phases[i] = (st - etimes[last_eod_index[-1]]) / eod_period * 2 * np.pi
        return phases

    def eod_times(self, index=0, interpolate=True):
        if index >= self.size:
            return None
        if len(self.__eod_times) < len(self.__eod_data):
            eod, time = self.eod(index)
            etimes = _zero_crossings(eod, time, interpolate=interpolate)
        else:
            etimes = self.__eod_times[index]
        return etimes

    @property
    def dataset(self):
        return self.__dataset

    @property
    def cell(self):
        cells = self.__dataset.cells
        return cells if len(cells) > 1 else cells[0]

    @property
    def subject(self):
        subjects = self.__dataset.subjects
        return subjects if len(subjects) > 1 else subjects[0]

    def spikes(self, index: int = 0):
        return self.__spike_data[index] if index < len(self.__spike_data) else None

    def eod(self, index: int = 0):
        if index >= len(self.__eod_data):
            return None, None
        eod = self.__eod_data[index]
        time = np.arange(len(eod)) / self.__dataset.samplerate
        return eod, time

    @property
    def burst_index(self):
        bi = []
        for i, sd in enumerate(self.__spike_data):
            if len(sd) < 2:
                continue
            et = self.eod_times(index=i)
            eod_period = np.mean(np.diff(et))
            isis = np.diff(sd)
            # fraction of intervals shorter than 1.5 EOD periods
            bi.append(np.sum(isis < (1.5 * eod_period)) / len(isis))
        return bi

    @property
    def coefficient_of_variation(self):
        cvs = []
        for d in self.__spike_data:
            isis = np.diff(d)
            cvs.append(np.std(isis) / np.mean(isis))
        return cvs

    @property
    def vector_strength(self):
        vss = []
        spike_phases = []
        for i, sd in enumerate(self.__spike_data):
            phases = self.__spike_phases(i)
            ms_sin_alpha = np.mean(np.sin(phases)) ** 2
            ms_cos_alpha = np.mean(np.cos(phases)) ** 2
            vs = np.sqrt(ms_cos_alpha + ms_sin_alpha)
            vss.append(vs)
            spike_phases.append(phases)
        return vss, spike_phases

    @property
    def size(self):
        return len(self.__spike_data)

    def __str__(self):
        return "Baseline data of cell %s" % self.__cell.id
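    # The private readers below prefer the dataset's NIX file and fall back to
    # the plain relacs files (trace-2.raw for the EOD, basespikes1.dat for the
    # spike times) whenever the NIX file is missing.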
Trying from directory") return self.__read_eod_data_from_directory(r, duration) f = nix.File.open(data_source, nix.FileMode.ReadOnly) b = f.blocks[0] t = b.tags[r.id] if not t: print("Tag not found!") try: data = t.retrieve_data("EOD")[:] except: data = np.empty(); f.close() return data def __read_eod_data_from_directory(self, r:RePro, duration)->np.ndarray: sr = self.__dataset.samplerate _unzip_if_needed(self.__dataset.data_source, "trace-2.raw") eod = np.fromfile(self.__dataset.data_source + "/trace-2.raw", np.float32) eod = eod[:int(duration * sr)] return eod def __read_spike_data_from_nix(self, r:RePro)->np.ndarray: data_source = os.path.join(self.__dataset.data_source, self.__dataset.id + ".nix") if not os.path.exists(data_source): print("Data not found! Trying from directory") return self.__read_spike_data_from_directory(r) f = nix.File.open(data_source, nix.FileMode.ReadOnly) b = f.blocks[0] t = b.tags[r.id] if not t: print("Tag not found!") try: data = t.retrieve_data("Spikes-1")[:] except: data = None f.close() if len(data) < 100: data = None return data def __read_spike_data_from_directory(self, r)->np.ndarray: data = [] data_source = os.path.join(self.__dataset.data_source, "basespikes1.dat") if os.path.exists(data_source): found_run = False with open(data_source, 'r') as f: l = f.readline() while l: if "index" in l: index = int(l.strip("#").strip().split(":")[-1]) found_run = index == r.run if l.startswith("#Key") and found_run: data = self.__do_read(f) break l = f.readline() if len(data) < 100: return None return np.asarray(data) def __do_read(self, f)->np.ndarray: data = [] f.readline() unit = f.readline().strip("#").strip() scale = 0.001 if unit == "ms" else 1 l = f.readline() while l and "#" not in l and len(l.strip()) > 0: data.append(float(l.strip())*scale) l = f.readline() return np.asarray(data) if __name__ == "__main__": dataset = Dataset(dataset_id='2011-06-14-ag') # dataset = Dataset(dataset_id='2018-11-09-aa-invivo-1') baseline = BaselineData(dataset) embed()