# fishBook/fishbook/reproclasses.py
from fishbook.fishbook import Dataset, RePro, Stimulus
import numpy as np
import nixio as nix
from scipy.stats import circstd
import os
import subprocess
from tqdm import tqdm
from IPython import embed
def _zero_crossings(x, t, interpolate=False):
dt = t[1] - t[0]
x_shift = np.roll(x, 1)
x_shift[0] = 0.0
xings = np.where((x >= 0 ) & (x_shift < 0))[0]
crossings = np.zeros(len(xings))
if interpolate:
for i, tf in enumerate(xings):
if x[tf] > 0.001:
m = (x[tf] - x[tf-1])/dt
crossings[i] = t[tf] - x[tf]/m
elif x[tf] < -0.001:
m = (x[tf + 1] - x[tf]) / dt
crossings[i] = t[tf] - x[tf]/m
else:
crossings[i] = t[tf]
else:
crossings = t[xings]
return crossings
def _unzip_if_needed(dataset, tracename='trace-1.raw'):
file_name = os.path.join(dataset, tracename)
if os.path.exists(file_name):
return
if os.path.exists(file_name + '.gz'):
print("\tunzip: %s" % tracename)
subprocess.check_call(["gunzip", os.path.join(dataset, tracename + ".gz")])
def gaussian_kernel(sigma, dt):
    """Return a Gaussian kernel sampled at resolution dt.

    The support spans +/- 4 sigma, so the kernel integrates to ~1 and can be
    used to convolve a binary spike train into a firing rate.

    :param sigma: standard deviation of the Gaussian, in the same unit as dt.
    :param dt: sampling interval.
    :return: the kernel values as a 1-d array.
    """
    support = np.arange(-4. * sigma, 4. * sigma, dt)
    kernel = np.exp(-0.5 * (support / sigma) ** 2) / np.sqrt(2. * np.pi) / sigma
    return kernel
class BaselineData:
    """Spike and EOD data recorded with the BaselineActivity RePro of one cell.

    Loads all baseline runs found for the dataset's first cell and provides
    basic analyses such as serial correlation, vector strength, burst index,
    coefficient of variation and the EOD frequency.
    """

    def __init__(self, dataset: Dataset):
        self.__spike_data = []
        self.__eod_data = []
        self.__eod_times = []  # lazily filled cache, one entry per run
        self.__dataset = dataset
        self.__repros = None
        self.__cell = dataset.cells[0]  # Beware: Assumption that there is only a single cell
        self._get_data()

    def _get_data(self):
        """Read spike and EOD traces of all BaselineActivity runs of the cell."""
        if not self.__dataset:
            return
        self.__repros = RePro.find("BaselineActivity", cell_id=self.__cell.id)
        for r in self.__repros:
            sd = self.__read_spike_data(r)
            if sd is None or len(sd) < 2:
                continue
            self.__spike_data.append(sd)
            # the EOD trace is cut at the time of the run's last spike
            self.__eod_data.append(self.__read_eod_data(r, self.__spike_data[-1][-1]))

    def valid(self):
        # fixme implement me!
        pass

    def __read_spike_data(self, r: RePro):
        """Dispatch reading of spike times to the nix file or the raw directory."""
        if self.__dataset.has_nix:
            return self.__read_spike_data_from_nix(r)
        return self.__read_spike_data_from_directory(r)

    def __read_eod_data(self, r: RePro, duration):
        """Dispatch reading of the EOD trace to the nix file or the raw directory."""
        if self.__dataset.has_nix:
            return self.__read_eod_data_from_nix(r, duration)
        return self.__read_eod_data_from_directory(r, duration)

    def __get_serial_correlation(self, times, max_lags=50):
        """Serial correlation of the inter-spike intervals up to max_lags lags."""
        if times is None or len(times) < max_lags:
            return None
        isis = np.diff(times)
        unbiased = isis - np.mean(isis, 0)
        norm = np.sum(unbiased ** 2)
        a_corr = np.correlate(unbiased, unbiased, "same") / norm
        # np.correlate "same" is symmetric; keep only non-negative lags
        a_corr = a_corr[int(len(a_corr) / 2):]
        return a_corr[:max_lags]

    def serial_correlation(self, max_lags=50):
        """
        Return the serial correlation of the spike train of each run.

        @param max_lags: The number of lags to take into account
        @return: list of serial correlations as a function of the lag, one per run
        """
        scs = []
        for sd in self.__spike_data:
            if sd is None or len(sd) < 100:
                continue
            corr = self.__get_serial_correlation(sd, max_lags=max_lags)
            if corr is not None:
                scs.append(corr)
        return scs

    def circular_std(self):
        """Circular standard deviation of the spike phases, one value per run."""
        cstds = []
        for i in range(self.size):
            phases = self.__spike_phases(index=i)
            cstds.append(circstd(phases))
        return cstds

    def eod_frequency(self):
        """Mean EOD frequency across runs; 0.0 if no EOD data is available."""
        eodfs = []
        for i in range(self.size):
            eod, time = self.eod(i)
            if eod is None:
                continue
            xings = _zero_crossings(eod, time, interpolate=False)
            if len(xings) == 0:
                # fixed: an empty crossing list crashed on xings[-1]
                continue
            eodfs.append(len(xings) / xings[-1])
        return 0.0 if len(eodfs) < 1 else np.mean(eodfs)

    def __spike_phases(self, index=0):  # fixme buffer this stuff
        """Phase of each spike within its EOD cycle, in radians."""
        etimes = self.eod_times(index=index)
        eod_period = np.mean(np.diff(etimes))
        phases = np.zeros(len(self.spikes(index)))
        for i, st in enumerate(self.spikes(index)):
            last_eod_index = np.where(etimes <= st)[0]
            if len(last_eod_index) == 0:
                continue
            phases[i] = (st - etimes[last_eod_index[-1]]) / eod_period * 2 * np.pi
        return phases

    def eod_times(self, index=0, interpolate=True):
        """EOD zero-crossing times of run `index`, cached after the first call.

        NOTE: the cache stores the result of the first call per index, so the
        `interpolate` flag is only honored on that first computation.
        """
        if index >= self.size:
            return None
        # fixed: the crossing times were recomputed on every call because the
        # cache was never filled; grow it lazily and store the computed times
        while len(self.__eod_times) < len(self.__eod_data):
            self.__eod_times.append(None)
        if self.__eod_times[index] is None:
            eod, time = self.eod(index)
            self.__eod_times[index] = _zero_crossings(eod, time, interpolate=interpolate)
        return self.__eod_times[index]

    @property
    def dataset(self):
        """The Dataset this baseline data belongs to."""
        return self.__dataset

    @property
    def cell(self):
        """The recorded cell; a list if the dataset contains more than one."""
        cells = self.__dataset.cells
        return cells if len(cells) > 1 else cells[0]

    @property
    def subject(self):
        """The recorded subject; a list if the dataset contains more than one."""
        subjects = self.__dataset.subjects
        return subjects if len(subjects) > 1 else subjects[0]

    def spikes(self, index: int = 0):
        """Spike times of run `index`, or None if the index is out of range."""
        # fixed: the bounds check was `len(...) >= index`, which crashed with
        # an IndexError for index == size
        return self.__spike_data[index] if len(self.__spike_data) > index else None

    def eod(self, index: int = 0):
        """Tuple of (eod trace, time axis) of run `index`; (None, None) if out of range."""
        # fixed: same off-by-one bounds check as in spikes(); additionally a
        # missing trace crashed on len(None)
        eod = self.__eod_data[index] if len(self.__eod_data) > index else None
        if eod is None:
            return None, None
        time = np.arange(len(eod)) / self.__dataset.samplerate
        return eod, time

    @property
    def burst_index(self):
        """Fraction of inter-spike intervals shorter than 1.5 EOD periods, per run."""
        bi = []
        for i, sd in enumerate(self.__spike_data):
            if len(sd) < 2:
                continue
            et = self.eod_times(index=i)
            eod_period = np.mean(np.diff(et))
            isis = np.diff(sd)
            bi.append(np.sum(isis < (1.5 * eod_period)) / len(isis))
        return bi

    @property
    def coefficient_of_variation(self):
        """CV (std/mean) of the inter-spike intervals, one value per run."""
        cvs = []
        for d in self.__spike_data:
            isis = np.diff(d)
            cvs.append(np.std(isis) / np.mean(isis))
        return cvs

    @property
    def vector_strength(self):
        """Vector strength of the phase locking to the EOD.

        :return: tuple of (vector strengths, spike phases), one entry per run.
        """
        vss = []
        spike_phases = []
        for i, sd in enumerate(self.__spike_data):
            phases = self.__spike_phases(i)
            ms_sin_alpha = np.mean(np.sin(phases)) ** 2
            ms_cos_alpha = np.mean(np.cos(phases)) ** 2
            vs = np.sqrt(ms_cos_alpha + ms_sin_alpha)
            vss.append(vs)
            spike_phases.append(phases)
        return vss, spike_phases

    @property
    def size(self):
        """Number of baseline runs stored."""
        return len(self.__spike_data)

    def __str__(self):
        # fixed: the string was built but never returned (and shadowed builtin str)
        return "Baseline data of cell %s " % self.__cell.id

    def __read_eod_data_from_nix(self, r: RePro, duration) -> np.ndarray:
        """Read the EOD trace of a run from the dataset's nix file."""
        data_source = os.path.join(self.__dataset.data_source, self.__dataset.id + ".nix")
        if not os.path.exists(data_source):
            print("Data not found! Trying from directory")
            return self.__read_eod_data_from_directory(r, duration)
        f = nix.File.open(data_source, nix.FileMode.ReadOnly)
        try:  # fixed: close the nix file even if reading fails
            b = f.blocks[0]
            t = b.tags[r.id]
            if not t:
                print("Tag not found!")
            try:
                data = t.retrieve_data("EOD")[:]
            except Exception:
                data = np.empty(0)  # fixed: np.empty() without a shape raises TypeError
        finally:
            f.close()
        return data

    def __read_eod_data_from_directory(self, r: RePro, duration) -> np.ndarray:
        """Read the EOD trace (trace-2.raw) from the raw data directory."""
        sr = self.__dataset.samplerate
        _unzip_if_needed(self.__dataset.data_source, "trace-2.raw")
        eod = np.fromfile(self.__dataset.data_source + "/trace-2.raw", np.float32)
        return eod[:int(duration * sr)]

    def __read_spike_data_from_nix(self, r: RePro) -> np.ndarray:
        """Read the spike times of a run from the dataset's nix file.

        :return: the spike times, or None if fewer than 100 spikes were found.
        """
        data_source = os.path.join(self.__dataset.data_source, self.__dataset.id + ".nix")
        if not os.path.exists(data_source):
            print("Data not found! Trying from directory")
            return self.__read_spike_data_from_directory(r)
        f = nix.File.open(data_source, nix.FileMode.ReadOnly)
        try:  # fixed: close the nix file even if reading fails
            b = f.blocks[0]
            t = b.tags[r.id]
            if not t:
                print("Tag not found!")
            try:
                data = t.retrieve_data("Spikes-1")[:]
            except Exception:
                data = None
        finally:
            f.close()
        # fixed: a failed read left data = None, which then crashed on len(data)
        if data is None or len(data) < 100:
            return None
        return data

    def __read_spike_data_from_directory(self, r) -> np.ndarray:
        """Read the spike times of run `r.run` from basespikes1.dat.

        :return: the spike times, or None if fewer than 100 spikes were found.
        """
        data = []
        data_source = os.path.join(self.__dataset.data_source, "basespikes1.dat")
        if os.path.exists(data_source):
            found_run = False
            with open(data_source, 'r') as f:
                l = f.readline()
                while l:
                    if "index" in l:
                        # header lines like "# index: 2" mark the runs
                        index = int(l.strip("#").strip().split(":")[-1])
                        found_run = index == r.run
                    if l.startswith("#Key") and found_run:
                        data = self.__do_read(f)
                        break
                    l = f.readline()
        if len(data) < 100:
            return None
        return np.asarray(data)

    @staticmethod
    def __do_read(f) -> np.ndarray:
        """Read the numeric data section following a #Key header.

        Converts millisecond values to seconds when the unit line says "ms".
        """
        data = []
        f.readline()
        unit = f.readline().strip("#").strip()
        scale = 0.001 if unit == "ms" else 1
        l = f.readline()
        while l and "#" not in l and len(l.strip()) > 0:
            data.append(float(l.strip()) * scale)
            l = f.readline()
        return np.asarray(data)
class FIData:
    """Spike and EOD responses recorded with the FICurve RePro of one cell.

    Stores, per stimulus presentation, the spike times, the local EOD trace,
    the time axis and the stimulus contrast, and offers the firing rate and a
    Boltzmann fit of the f-I curve.
    """

    def __init__(self, dataset: Dataset):
        self.__spike_data = []
        self.__contrasts = []
        self.__eod_data = []
        self.__eod_times = []
        self.__dataset = dataset
        self.__repros = None
        self.__cell = dataset.cells[0]  # Beware: Assumption that there is only a single cell
        self._get_data()

    def _get_data(self):
        """Collect the responses of all FICurve runs of this cell."""
        if not self.__dataset:
            return
        self.__repros = RePro.find("FICurve", cell_id=self.__cell.id)
        for r in self.__repros:
            sd, c, eods, time = self.__read_spike_data(r)
            if sd is None or len(sd) < 2:
                continue
            self.__spike_data.extend(sd)
            self.__eod_data.extend(eods)
            self.__contrasts.extend(c)
            self.__eod_times.extend(time)

    def __read_spike_data(self, repro: RePro):
        """Read spikes, contrasts, EOD traces and time axes of all presentations.

        :param repro: the RePro run whose presentations should be loaded.
        :return: tuple (spikes, contrasts, eods, times); (None, None, None, None)
                 if the data cannot be read.
        """
        if self.__dataset.has_nix:
            return self.__read_spikes_from_nix(repro)
        print("Sorry, so far only from nix!!!")
        return None, None, None, None

    def __do_read_spike_data_from_nix(self, mtag: nix.pycore.MultiTag, stimulus: Stimulus, repro: RePro):
        """Cut out spikes and local EOD around a single stimulus presentation.

        The segment starts `delay` seconds (read from the repro settings) before
        stimulus onset and ends at stimulus offset. Spike times are shifted to
        be relative to the segment start; the time axis starts at -delay.
        """
        r_settings = repro.settings.split("\n")
        s_settings = stimulus.settings.split("\n")
        delay = 0.0
        contrast = 0.0
        for s in r_settings:
            if "delay:" in s:
                delay = float(s.split(":")[-1])
                break
        for s in s_settings:
            # the settings contain several "Contrast" entries; pick the plain one
            if "Contrast:" in s and "PreContrast" not in s and "\t\t" not in s and "+-" not in s:
                contrast = float(s.split(":")[-1])
                break
        start_time = stimulus.start_time - delay
        end_time = stimulus.start_time + stimulus.duration
        spikes_da = mtag.references["Spikes-1"]
        eod_da = mtag.references["LocalEOD-1"]
        start_index_spikes = spikes_da.dimensions[0].index_of(start_time)
        end_index_spikes = spikes_da.dimensions[0].index_of(end_time)
        start_index_eod = eod_da.dimensions[0].index_of(start_time)
        end_index_eod = eod_da.dimensions[0].index_of(end_time)
        local_eod = eod_da[start_index_eod:end_index_eod]
        spikes = spikes_da[start_index_spikes:end_index_spikes] - start_time
        time = np.asarray(eod_da.dimensions[0].axis(len(local_eod))) - delay
        return spikes, local_eod, time, contrast

    def __read_spikes_from_nix(self, repro: RePro):
        """Load all presentations of a FICurve repro run from the nix file.

        :return: tuple of lists (spikes, contrasts, eods, times), one entry
                 per stimulus presentation.
        """
        spikes = []
        eods = []
        time = []
        contrasts = []
        stimuli = Stimulus.find(cell_id=repro.cell_id, repro_id=repro.id)
        if len(stimuli) == 0:
            return spikes, contrasts, eods, time
        data_source = os.path.join(self.__dataset.data_source, self.__dataset.id + ".nix")
        if not os.path.exists(data_source):
            # fixed: the previous fallback called self.__read_spike_data_from_directory,
            # a method FIData does not have (guaranteed AttributeError); reading
            # from the raw directory is not supported here, so return empty results
            print("Data not found!")
            return spikes, contrasts, eods, time
        f = nix.File.open(data_source, nix.FileMode.ReadOnly)
        try:  # fixed: close the nix file even if reading fails
            b = f.blocks[0]
            mt = None
            for i in tqdm(range(len(stimuli))):
                s = stimuli[i]
                # reuse the multi tag while consecutive stimuli reference the same one
                if not mt or mt.id != s.multi_tag_id:
                    mt = b.multi_tags[s.multi_tag_id]
                sp, eod, t, c = self.__do_read_spike_data_from_nix(mt, s, repro)
                spikes.append(sp)
                eods.append(eod)
                time.append(t)
                contrasts.append(c)
        finally:
            f.close()
        return spikes, contrasts, eods, time

    @property
    def size(self):
        """Number of stored stimulus presentations."""
        return len(self.__spike_data)

    def spikes(self, index=-1):
        """Spike times of presentation `index`; all presentations if index is out of range."""
        if 0 <= index < self.size:
            return self.__spike_data[index]
        return self.__spike_data

    def eod(self, index=-1):
        """Tuple (time, eod trace) of presentation `index`; all if index is out of range."""
        if 0 <= index < self.size:
            return self.__eod_times[index], self.__eod_data[index]
        return self.__eod_times, self.__eod_data

    def contrast(self, index=-1):
        """Stimulus contrast of presentation `index`; all if index is out of range."""
        if 0 <= index < self.size:
            return self.__contrasts[index]
        return self.__contrasts

    def time_axis(self, index=-1):
        """Time axis of presentation `index`; all if index is out of range."""
        if 0 <= index < self.size:
            return self.__eod_times[index]
        return self.__eod_times

    def rate(self, index=0, kernel_width=0.005):
        """
        Returns the firing rate for a single trial.

        :param index: The index of the trial. 0 <= index < size
        :param kernel_width: The width of the gaussian kernel in seconds
        :return: tuple of time and rate
        """
        t = self.time_axis(index)
        dt = np.mean(np.diff(t))
        sp = self.spikes(index)
        binary = np.zeros(t.shape)
        # NOTE(review): sp / dt maps spike times to indices as if t started at 0,
        # but the time axis starts at -delay — confirm the intended alignment
        spike_indices = (sp / dt).astype(int)
        # only mark spikes whose index falls onto the time axis
        binary[spike_indices[(spike_indices >= 0) & (spike_indices < len(binary))]] = 1
        g = gaussian_kernel(kernel_width, dt)
        rate = np.convolve(binary, g, mode='same')
        return t, rate

    def boltzmann_fit(self, start_time=0.01, end_time=0.05, kernel_width=0.005):
        """
        Fit a Boltzmann function to the mean firing rate in the analysis window
        [start_time, end_time) as a function of stimulus contrast.

        :param start_time: start of the analysis window in seconds.
        :param end_time: end of the analysis window in seconds.
        :param kernel_width: width of the gaussian kernel used for the firing rate.
        :return: object of type BoltzmannFit
        """
        contrasts = np.zeros(self.size)
        rates = np.zeros(self.size)
        for i in range(self.size):
            contrasts[i] = np.round(self.contrast(i))
            t, r = self.rate(i, kernel_width)
            rates[i] = np.mean(r[(t >= start_time) & (t < end_time)])
        return BoltzmannFit(contrasts, rates)
if __name__ == "__main__":
    # Manual smoke test: load a known dataset and drop into an IPython shell
    # for interactive exploration of the loaded baseline data.
    # dataset = Dataset(dataset_id='2011-06-14-ag')
    dataset = Dataset(dataset_id="2018-01-19-ac-invivo-1")
    # dataset = Dataset(dataset_id='2018-11-09-aa-invivo-1')
    baseline = BaselineData(dataset)
    embed()