This commit is contained in:
Jan Grewe 2019-09-24 21:49:41 +02:00
commit 00086a4545

View File

@ -1,12 +1,34 @@
from fishbook.fishbook import Dataset, RePro
import numpy as np
import nixio as nix
from scipy.stats import circstd
import os
import subprocess
from IPython import embed
def _zero_crossings(x, t, interpolate=False):
dt = t[1] - t[0]
x_shift = np.roll(x, 1)
x_shift[0] = 0.0
xings = np.where((x >= 0 ) & (x_shift < 0))[0]
crossings = np.zeros(len(xings))
if interpolate:
for i, tf in enumerate(xings):
if x[tf] > 0.001:
m = (x[tf] - x[tf-1])/dt
crossings[i] = t[tf] - x[tf]/m
elif x[tf] < -0.001:
m = (x[tf + 1] - x[tf]) / dt
crossings[i] = t[tf] - x[tf]/m
else:
crossings[i] = t[tf]
else:
crossings = t[xings]
return crossings
def _unzip_if_needed(dataset, tracename='trace-1.raw'):
file_name = os.path.join(dataset, tracename)
if os.path.exists(file_name):
@ -21,6 +43,7 @@ class BaselineData:
def __init__(self, dataset:Dataset):
self.__spike_data = []
self.__eod_data = []
self.__eod_times = []
self.__dataset = dataset
self.__repros = None
self.__cell = dataset.cells[0] # Beware: Assumption that there is only a single cell
@ -31,9 +54,17 @@ class BaselineData:
return
self.__repros = RePro.find("BaselineActivity", cell_id=self.__cell.id)
for r in self.__repros:
self.__spike_data.append(self.__read_spike_data(r))
sd = self.__read_spike_data(r)
if sd is not None and len(sd) > 1:
self.__spike_data.append(sd)
else:
continue
self.__eod_data.append(self.__read_eod_data(r, self.__spike_data[-1][-1]))
def valid(self):
# fixme implement me!
pass
def __read_spike_data(self, r:RePro):
if self.__dataset.has_nix:
return self.__read_spike_data_from_nix(r)
@ -46,6 +77,66 @@ class BaselineData:
else:
return self.__read_eod_data_from_directory(r, duration)
def __get_serial_correlation(self, times, max_lags=50):
if times is None or len(times) < max_lags:
return None
isis = np.diff(times)
unbiased = isis - np.mean(isis, 0)
norm = sum(unbiased ** 2)
a_corr = np.correlate(unbiased, unbiased, "same") / norm
a_corr = a_corr[int(len(a_corr) / 2):]
return a_corr[:max_lags]
def serial_correlation(self, max_lags=50):
"""
return the serial correlation for the the spike train provided by spike_times.
@param max_lags: The number of lags to take into account
@return: the serial correlation as a function of the lag
"""
scs = []
for sd in self.__spike_data:
if sd is None or len(sd) < 100:
continue
corr = self.__get_serial_correlation(sd, max_lags=max_lags)
if corr is not None:
scs.append(corr)
return scs
def circular_std(self):
cstds = []
for i in range(self.size):
phases = self.__spike_phases(index=i)
cstds.append(circstd(phases))
def eod_frequency(self):
eodfs = []
for i in range(self.size):
eod, time = self.eod(i)
xings = _zero_crossings(eod, time, interpolate=False)
eodfs.append(len(xings)/xings[-1])
return 0.0 if len(eodfs) < 1 else np.mean(eodfs)
def __spike_phases(self, index=0): # fixme buffer this stuff
etimes = self.eod_times(index=index)
eod_period = np.mean(np.diff(etimes))
phases = np.zeros(len(self.spikes(index)))
for i, st in enumerate(self.spikes(index)):
last_eod_index = np.where(etimes <= st)[0]
if len(last_eod_index) == 0:
continue
phases[i] = (st - etimes[last_eod_index[-1]]) / eod_period * 2 * np.pi
return phases
def eod_times(self, index=0, interpolate=True):
if index >= self.size:
return None
if len(self.__eod_times) < len(self.__eod_data):
eod, time = self.eod(index)
etimes = _zero_crossings(eod, time, interpolate=interpolate)
else:
etimes = self.__eod_times[index]
return etimes
@property
def dataset(self):
return self.__dataset
@ -60,26 +151,46 @@ class BaselineData:
subjects = self.__dataset.subjects
return subjects if len(subjects) > 1 else subjects[0]
def spike_data(self, index:int=0):
def spikes(self, index:int=0):
return self.__spike_data[index] if len(self.__spike_data) >= index else None
def eod_data(self, index:int=0):
def eod(self, index:int=0):
eod = self.__eod_data[index] if len(self.__eod_data) >= index else None
time = np.arange(len(eod)) / self.__dataset.samplerate
return eod, time
@property
def burst_index(self):
bi = []
for i, sd in enumerate(self.__spike_data):
if len(sd) < 2:
continue
et = self.eod_times(index=i)
eod_period = np.mean(np.diff(et))
isis = np.diff(sd)
bi.append(np.sum(isis < (1.5 * eod_period))/len(isis))
return bi
@property
def coefficient_of_variation(self):
cvs = []
for d in self.__spike_data:
isis = np.diff(d)
cvs.append(np.std(isis)/np.mean(d=isis))
cvs.append(np.std(isis)/np.mean(isis))
return cvs
@property
def vector_strength(self):
vss = []
return vss
spike_phases = []
for i, sd in enumerate(self.__spike_data):
phases = self.__spike_phases(i)
ms_sin_alpha = np.mean(np.sin(phases)) ** 2
ms_cos_alpha = np.mean(np.cos(phases)) ** 2
vs = np.sqrt(ms_cos_alpha + ms_sin_alpha)
vss.append(vs)
spike_phases.append(phases)
return vss, spike_phases
@property
def size(self):
@ -98,7 +209,10 @@ class BaselineData:
t = b.tags[r.id]
if not t:
print("Tag not found!")
data = t.retrieve_data("EOD")[:]
try:
data = t.retrieve_data("EOD")[:]
except:
data = np.empty();
f.close()
return data
@ -119,8 +233,14 @@ class BaselineData:
t = b.tags[r.id]
if not t:
print("Tag not found!")
data = t.retrieve_data("Spikes-1")[:]
try:
data = t.retrieve_data("Spikes-1")[:]
except:
data = None
f.close()
if len(data) < 100:
data = None
return data
@ -139,7 +259,9 @@ class BaselineData:
data = self.__do_read(f)
break
l = f.readline()
return data
if len(data) < 100:
return None
return np.asarray(data)
def __do_read(self, f)->np.ndarray:
data = []