From 4f2a114c9e2a599a1ecc752e00320e639a0184d7 Mon Sep 17 00:00:00 2001 From: Jan Grewe Date: Tue, 24 Sep 2019 21:44:49 +0200 Subject: [PATCH] many new function on BaselineData class --- fishbook/reproclasses.py | 138 ++++++++++++++++++++++++++++++++++++--- 1 file changed, 130 insertions(+), 8 deletions(-) diff --git a/fishbook/reproclasses.py b/fishbook/reproclasses.py index f6a685c..a763778 100644 --- a/fishbook/reproclasses.py +++ b/fishbook/reproclasses.py @@ -1,12 +1,34 @@ from fishbook.fishbook import Dataset, RePro import numpy as np import nixio as nix +from scipy.stats import circstd import os import subprocess from IPython import embed +def _zero_crossings(x, t, interpolate=False): + dt = t[1] - t[0] + x_shift = np.roll(x, 1) + x_shift[0] = 0.0 + xings = np.where((x >= 0 ) & (x_shift < 0))[0] + crossings = np.zeros(len(xings)) + if interpolate: + for i, tf in enumerate(xings): + if x[tf] > 0.001: + m = (x[tf] - x[tf-1])/dt + crossings[i] = t[tf] - x[tf]/m + elif x[tf] < -0.001: + m = (x[tf + 1] - x[tf]) / dt + crossings[i] = t[tf] - x[tf]/m + else: + crossings[i] = t[tf] + else: + crossings = t[xings] + return crossings + + def _unzip_if_needed(dataset, tracename='trace-1.raw'): file_name = os.path.join(dataset, tracename) if os.path.exists(file_name): @@ -21,6 +43,7 @@ class BaselineData: def __init__(self, dataset:Dataset): self.__spike_data = [] self.__eod_data = [] + self.__eod_times = [] self.__dataset = dataset self.__repros = None self.__cell = dataset.cells[0] # Beware: Assumption that there is only a single cell @@ -31,9 +54,17 @@ class BaselineData: return self.__repros = RePro.find("BaselineActivity", cell_id=self.__cell.id) for r in self.__repros: - self.__spike_data.append(self.__read_spike_data(r)) + sd = self.__read_spike_data(r) + if sd is not None and len(sd) > 1: + self.__spike_data.append(sd) + else: + continue self.__eod_data.append(self.__read_eod_data(r, self.__spike_data[-1][-1])) + def valid(self): + # fixme implement me! + pass + def __read_spike_data(self, r:RePro): if self.__dataset.has_nix: return self.__read_spike_data_from_nix(r) @@ -46,6 +77,66 @@ class BaselineData: else: return self.__read_eod_data_from_directory(r, duration) + def __get_serial_correlation(self, times, max_lags=50): + if times is None or len(times) < max_lags: + return None + isis = np.diff(times) + unbiased = isis - np.mean(isis, 0) + norm = sum(unbiased ** 2) + a_corr = np.correlate(unbiased, unbiased, "same") / norm + a_corr = a_corr[int(len(a_corr) / 2):] + return a_corr[:max_lags] + + def serial_correlation(self, max_lags=50): + """ + return the serial correlation for the the spike train provided by spike_times. + @param max_lags: The number of lags to take into account + @return: the serial correlation as a function of the lag + """ + scs = [] + for sd in self.__spike_data: + if sd is None or len(sd) < 100: + continue + corr = self.__get_serial_correlation(sd, max_lags=max_lags) + if corr is not None: + scs.append(corr) + return scs + + def circular_std(self): + cstds = [] + for i in range(self.size): + phases = self.__spike_phases(index=i) + cstds.append(circstd(phases)) + + def eod_frequency(self): + eodfs = [] + for i in range(self.size): + eod, time = self.eod(i) + xings = _zero_crossings(eod, time, interpolate=False) + eodfs.append(len(xings)/xings[-1]) + return 0.0 if len(eodfs) < 1 else np.mean(eodfs) + + def __spike_phases(self, index=0): # fixme buffer this stuff + etimes = self.eod_times(index=index) + eod_period = np.mean(np.diff(etimes)) + phases = np.zeros(len(self.spikes(index))) + for i, st in enumerate(self.spikes(index)): + last_eod_index = np.where(etimes <= st)[0] + if len(last_eod_index) == 0: + continue + phases[i] = (st - etimes[last_eod_index[-1]]) / eod_period * 2 * np.pi + return phases + + def eod_times(self, index=0, interpolate=True): + if index >= self.size: + return None + if len(self.__eod_times) < len(self.__eod_data): + eod, time = self.eod(index) + etimes = _zero_crossings(eod, time, interpolate=interpolate) + else: + etimes = self.__eod_times[index] + return etimes + @property def dataset(self): return self.__dataset @@ -60,26 +151,46 @@ class BaselineData: subjects = self.__dataset.subjects return subjects if len(subjects) > 1 else subjects[0] - def spike_data(self, index:int=0): + def spikes(self, index:int=0): return self.__spike_data[index] if len(self.__spike_data) >= index else None - def eod_data(self, index:int=0): + def eod(self, index:int=0): eod = self.__eod_data[index] if len(self.__eod_data) >= index else None time = np.arange(len(eod)) / self.__dataset.samplerate return eod, time + @property + def burst_index(self): + bi = [] + for i, sd in enumerate(self.__spike_data): + if len(sd) < 2: + continue + et = self.eod_times(index=i) + eod_period = np.mean(np.diff(et)) + isis = np.diff(sd) + bi.append(np.sum(isis < (1.5 * eod_period))/len(isis)) + return bi + @property def coefficient_of_variation(self): cvs = [] for d in self.__spike_data: isis = np.diff(d) - cvs.append(np.std(isis)/np.mean(d=isis)) + cvs.append(np.std(isis)/np.mean(isis)) return cvs @property def vector_strength(self): vss = [] - return vss + spike_phases = [] + for i, sd in enumerate(self.__spike_data): + phases = self.__spike_phases(i) + ms_sin_alpha = np.mean(np.sin(phases)) ** 2 + ms_cos_alpha = np.mean(np.cos(phases)) ** 2 + vs = np.sqrt(ms_cos_alpha + ms_sin_alpha) + vss.append(vs) + spike_phases.append(phases) + return vss, spike_phases @property def size(self): @@ -98,7 +209,10 @@ class BaselineData: t = b.tags[r.id] if not t: print("Tag not found!") - data = t.retrieve_data("EOD")[:] + try: + data = t.retrieve_data("EOD")[:] + except: + data = np.empty(); f.close() return data @@ -119,8 +233,14 @@ class BaselineData: t = b.tags[r.id] if not t: print("Tag not found!") - data = t.retrieve_data("Spikes-1")[:] + try: + data = t.retrieve_data("Spikes-1")[:] + except: + data = None + f.close() + if len(data) < 100: + data = None return data @@ -139,7 +259,9 @@ class BaselineData: data = self.__do_read(f) break l = f.readline() - return data + if len(data) < 100: + return None + return np.asarray(data) def __do_read(self, f)->np.ndarray: data = []