Jan Grewe 2019-09-24 21:49:41 +02:00
commit 00086a4545


@@ -1,12 +1,34 @@
 from fishbook.fishbook import Dataset, RePro
 import numpy as np
 import nixio as nix
+from scipy.stats import circstd
 import os
 import subprocess
 from IPython import embed

+def _zero_crossings(x, t, interpolate=False):
+    dt = t[1] - t[0]
+    x_shift = np.roll(x, 1)
+    x_shift[0] = 0.0
+    xings = np.where((x >= 0) & (x_shift < 0))[0]
+    crossings = np.zeros(len(xings))
+    if interpolate:
+        for i, tf in enumerate(xings):
+            if x[tf] > 0.001:
+                m = (x[tf] - x[tf-1])/dt
+                crossings[i] = t[tf] - x[tf]/m
+            elif x[tf] < -0.001:
+                m = (x[tf + 1] - x[tf]) / dt
+                crossings[i] = t[tf] - x[tf]/m
+            else:
+                crossings[i] = t[tf]
+    else:
+        crossings = t[xings]
+    return crossings
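
A minimal sketch of how _zero_crossings could be exercised on a synthetic signal (the sine parameters are made up for illustration, and it assumes the function is available in the current namespace, e.g. when run inside this module):

import numpy as np

# 1 s of a 700 Hz sine sampled at 20 kHz -- stands in for a recorded EOD trace
t = np.arange(0.0, 1.0, 1.0 / 20000)
x = np.sin(2 * np.pi * 700.0 * t)
crossings = _zero_crossings(x, t, interpolate=True)
# one rising zero crossing per period, so the crossing rate approximates the frequency
print(len(crossings) / (crossings[-1] - crossings[0]))  # ~700 Hz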

 def _unzip_if_needed(dataset, tracename='trace-1.raw'):
     file_name = os.path.join(dataset, tracename)
     if os.path.exists(file_name):

@@ -21,6 +43,7 @@ class BaselineData:
     def __init__(self, dataset:Dataset):
         self.__spike_data = []
         self.__eod_data = []
+        self.__eod_times = []
         self.__dataset = dataset
         self.__repros = None
         self.__cell = dataset.cells[0]  # Beware: Assumption that there is only a single cell

@@ -31,9 +54,17 @@
             return
         self.__repros = RePro.find("BaselineActivity", cell_id=self.__cell.id)
         for r in self.__repros:
-            self.__spike_data.append(self.__read_spike_data(r))
+            sd = self.__read_spike_data(r)
+            if sd is not None and len(sd) > 1:
+                self.__spike_data.append(sd)
+            else:
+                continue
             self.__eod_data.append(self.__read_eod_data(r, self.__spike_data[-1][-1]))

+    def valid(self):
+        # fixme implement me!
+        pass

     def __read_spike_data(self, r:RePro):
         if self.__dataset.has_nix:
             return self.__read_spike_data_from_nix(r)

@@ -46,6 +77,66 @@
         else:
             return self.__read_eod_data_from_directory(r, duration)

+    def __get_serial_correlation(self, times, max_lags=50):
+        if times is None or len(times) < max_lags:
+            return None
+        isis = np.diff(times)
+        unbiased = isis - np.mean(isis, 0)
+        norm = sum(unbiased ** 2)
+        a_corr = np.correlate(unbiased, unbiased, "same") / norm
+        a_corr = a_corr[int(len(a_corr) / 2):]
+        return a_corr[:max_lags]

+    def serial_correlation(self, max_lags=50):
+        """
+        Return the serial correlation for the spike train provided by spike_times.
+        @param max_lags: The number of lags to take into account
+        @return: the serial correlation as a function of the lag
+        """
+        scs = []
+        for sd in self.__spike_data:
+            if sd is None or len(sd) < 100:
+                continue
+            corr = self.__get_serial_correlation(sd, max_lags=max_lags)
+            if corr is not None:
+                scs.append(corr)
+        return scs
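
The serial correlation computed above is the normalized autocorrelation of the mean-subtracted inter-spike intervals; lag 0 is 1 by construction and, for a renewal (independent-interval) spike train, higher lags scatter around 0. A small standalone sketch of the same computation (the function name and test data are illustrative, not part of this commit):

import numpy as np

def isi_serial_correlation(times, max_lags=50):
    # autocorrelation of the inter-spike-interval sequence, normalized to 1 at lag 0
    isis = np.diff(times)
    unbiased = isis - np.mean(isis)
    norm = np.sum(unbiased ** 2)
    a_corr = np.correlate(unbiased, unbiased, "same") / norm
    return a_corr[len(a_corr) // 2:][:max_lags]

# Poisson-like train: intervals are independent, so correlations at lags >= 1 stay near 0
spike_times = np.cumsum(np.random.exponential(0.01, size=5000))
sc = isi_serial_correlation(spike_times)
print(sc[0], sc[1])  # 1.0 at lag 0, ~0 at lag 1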

+    def circular_std(self):
+        cstds = []
+        for i in range(self.size):
+            phases = self.__spike_phases(index=i)
+            cstds.append(circstd(phases))
+        return cstds

+    def eod_frequency(self):
+        eodfs = []
+        for i in range(self.size):
+            eod, time = self.eod(i)
+            xings = _zero_crossings(eod, time, interpolate=False)
+            eodfs.append(len(xings)/xings[-1])
+        return 0.0 if len(eodfs) < 1 else np.mean(eodfs)

+    def __spike_phases(self, index=0):  # fixme buffer this stuff
+        etimes = self.eod_times(index=index)
+        eod_period = np.mean(np.diff(etimes))
+        phases = np.zeros(len(self.spikes(index)))
+        for i, st in enumerate(self.spikes(index)):
+            last_eod_index = np.where(etimes <= st)[0]
+            if len(last_eod_index) == 0:
+                continue
+            phases[i] = (st - etimes[last_eod_index[-1]]) / eod_period * 2 * np.pi
+        return phases

+    def eod_times(self, index=0, interpolate=True):
+        if index >= self.size:
+            return None
+        if len(self.__eod_times) < len(self.__eod_data):
+            eod, time = self.eod(index)
+            etimes = _zero_crossings(eod, time, interpolate=interpolate)
+        else:
+            etimes = self.__eod_times[index]
+        return etimes

     @property
     def dataset(self):
         return self.__dataset

@@ -60,26 +151,46 @@
         subjects = self.__dataset.subjects
         return subjects if len(subjects) > 1 else subjects[0]

-    def spike_data(self, index:int=0):
+    def spikes(self, index:int=0):
         return self.__spike_data[index] if len(self.__spike_data) > index else None

-    def eod_data(self, index:int=0):
+    def eod(self, index:int=0):
         eod = self.__eod_data[index] if len(self.__eod_data) > index else None
         time = np.arange(len(eod)) / self.__dataset.samplerate
         return eod, time

+    @property
+    def burst_index(self):
+        bi = []
+        for i, sd in enumerate(self.__spike_data):
+            if len(sd) < 2:
+                continue
+            et = self.eod_times(index=i)
+            eod_period = np.mean(np.diff(et))
+            isis = np.diff(sd)
+            bi.append(np.sum(isis < (1.5 * eod_period))/len(isis))
+        return bi
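
The burst index above is the fraction of inter-spike intervals shorter than 1.5 EOD periods, i.e. spikes that follow their predecessor within roughly one EOD cycle. A worked toy example with made-up numbers:

import numpy as np

eod_period = 0.00125                                   # 800 Hz EOD, illustrative value
spike_times = np.array([0.000, 0.001, 0.002, 0.010, 0.020])
isis = np.diff(spike_times)                            # [0.001, 0.001, 0.008, 0.010]
burst_index = np.sum(isis < 1.5 * eod_period) / len(isis)
print(burst_index)                                     # 0.5 -- half the intervals fall within a burst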

     @property
     def coefficient_of_variation(self):
         cvs = []
         for d in self.__spike_data:
             isis = np.diff(d)
-            cvs.append(np.std(isis)/np.mean(d=isis))
+            cvs.append(np.std(isis)/np.mean(isis))
         return cvs

     @property
     def vector_strength(self):
         vss = []
-        return vss
+        spike_phases = []
+        for i, sd in enumerate(self.__spike_data):
+            phases = self.__spike_phases(i)
+            ms_sin_alpha = np.mean(np.sin(phases)) ** 2
+            ms_cos_alpha = np.mean(np.cos(phases)) ** 2
+            vs = np.sqrt(ms_cos_alpha + ms_sin_alpha)
+            vss.append(vs)
+            spike_phases.append(phases)
+        return vss, spike_phases
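
Vector strength treats every spike phase as a unit vector on the circle and takes the length of their mean: vs = sqrt(mean(cos(phi))**2 + mean(sin(phi))**2). It is 1 for perfect phase locking to the EOD and approaches 0 for uniformly scattered phases. A quick check with synthetic phases (names and numbers are illustrative only):

import numpy as np

def vector_strength_of(phases):
    # length of the mean unit vector formed by the spike phases
    return np.sqrt(np.mean(np.cos(phases)) ** 2 + np.mean(np.sin(phases)) ** 2)

locked = np.random.normal(np.pi, 0.1, size=1000)          # tightly clustered phases -> close to 1
scattered = np.random.uniform(0.0, 2 * np.pi, size=1000)  # uniform phases -> close to 0
print(vector_strength_of(locked), vector_strength_of(scattered))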

     @property
     def size(self):

@@ -98,7 +209,10 @@ class BaselineData:
         t = b.tags[r.id]
         if not t:
             print("Tag not found!")
-        data = t.retrieve_data("EOD")[:]
+        try:
+            data = t.retrieve_data("EOD")[:]
+        except:
+            data = np.array([])
         f.close()
         return data

@@ -119,8 +233,14 @@ class BaselineData:
         t = b.tags[r.id]
         if not t:
             print("Tag not found!")
-        data = t.retrieve_data("Spikes-1")[:]
+        try:
+            data = t.retrieve_data("Spikes-1")[:]
+        except:
+            data = None
         f.close()
+        if data is None or len(data) < 100:
+            data = None
         return data

@@ -139,7 +259,9 @@ class BaselineData:
                 data = self.__do_read(f)
                 break
             l = f.readline()
-        return data
+        if len(data) < 100:
+            return None
+        return np.asarray(data)

     def __do_read(self, f)->np.ndarray:
         data = []