# fishBook/fishbook/frontend/relacs_classes.py
from fishbook.fishbook import Dataset, RePro, Stimulus
import numpy as np
import nixio as nix
from scipy.stats import circstd
from scipy.optimize import curve_fit
import os
import subprocess
from tqdm import tqdm
from IPython import embed
def _zero_crossings(x, t, interpolate=False):
    """Get the times at which a signal x crosses zero from below (rising edge).
    Args:
        x (np.ndarray): the signal.
        t (np.ndarray): the corresponding time axis.
        interpolate (bool, optional): if True, crossing times are linearly interpolated between samples. Defaults to False.
    Returns:
        np.ndarray: the times of the rising zero crossings.
    """
    dt = t[1] - t[0]
    x_shift = np.roll(x, 1)
    x_shift[0] = 0.0
    xings = np.where((x >= 0) & (x_shift < 0))[0]
    crossings = np.zeros(len(xings))
    if interpolate:
        for i, tf in enumerate(xings):
            if x[tf] > 0.001:
                # linearly interpolate the crossing time from the local slope
                m = (x[tf] - x[tf - 1]) / dt
                crossings[i] = t[tf] - x[tf] / m
            else:
                # the sample is already close enough to zero
                crossings[i] = t[tf]
    else:
        crossings = t[xings]
    return crossings
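
# Hedged usage sketch for _zero_crossings (illustration only, not part of the
# original API): for a 1 kHz sine wave sampled at 20 kHz, the rising zero
# crossings are spaced one period, i.e. about 1 ms, apart.
def _zero_crossings_example():
    t = np.arange(0.0, 0.01, 1. / 20000.)
    x = np.sin(2 * np.pi * 1000. * t)
    crossings = _zero_crossings(x, t, interpolate=True)
    return np.diff(crossings)  # approximately 0.001 s everywhere
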
def _unzip_if_needed(dataset, tracename='trace-1.raw'):
file_name = os.path.join(dataset, tracename)
if os.path.exists(file_name):
return
if os.path.exists(file_name + '.gz'):
print("\tunzip: %s" % tracename)
subprocess.check_call(["gunzip", os.path.join(dataset, tracename + ".gz")])
def gaussian_kernel(sigma, dt):
x = np.arange(-4. * sigma, 4. * sigma, dt)
y = np.exp(-0.5 * (x / sigma) ** 2) / np.sqrt(2. * np.pi) / sigma
return y
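
# Hedged sketch of how the kernel is used below (compare FIData.rate): a binary
# spike train convolved with this kernel yields a firing-rate estimate in Hz,
# because the kernel sums to approximately 1/dt. All values are hypothetical.
def _gaussian_kernel_example():
    dt = 1. / 20000.                # sampling interval in seconds
    binary = np.zeros(20000)        # one second of data
    binary[::200] = 1               # regular 100 Hz spiking
    rate = np.convolve(binary, gaussian_kernel(0.005, dt), mode="same")
    return rate.mean()              # close to 100 Hz
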
class BaselineData:
"""
Class representing the Baseline data that has been recorded within a given Dataset.
"""
def __init__(self, dataset: Dataset):
self.__spike_data = []
self.__eod_data = []
self.__eod_times = []
self.__dataset = dataset
self.__repros = None
self.__cell = dataset.cells[0] # Beware: Assumption that there is only a single cell
self._get_data()
def _get_data(self):
if not self.__dataset:
return
self.__repros = RePro.find("BaselineActivity", cell_id=self.__cell.id)
for i in tqdm(range(len(self.__repros)), desc="loading data"):
r = self.__repros[i]
sd = self.__read_spike_data(r)
if sd is not None and len(sd) > 1:
self.__spike_data.append(sd)
else:
continue
self.__eod_data.append(self.__read_eod_data(r, self.__spike_data[-1][-1]))
def valid(self):
# fixme implement me!
pass
def __read_spike_data(self, r: RePro):
if self.__dataset.has_nix:
return self.__read_spike_data_from_nix(r)
else:
return self.__read_spike_data_from_directory(r)
def __read_eod_data(self, r: RePro, duration):
if self.__dataset.has_nix:
return self.__read_eod_data_from_nix(r, duration)
else:
return self.__read_eod_data_from_directory(r, duration)
def __get_serial_correlation(self, times, max_lags=50):
if times is None or len(times) < max_lags:
return None
isis = np.diff(times)
unbiased = isis - np.mean(isis, 0)
norm = sum(unbiased ** 2)
a_corr = np.correlate(unbiased, unbiased, "same") / norm
a_corr = a_corr[int(len(a_corr) / 2):]
return a_corr[:max_lags]
def serial_correlation(self, max_lags=50):
"""
Returns the serial correlation of the interspike intervals.
@param max_lags: The number of lags to take into account
@return: the serial correlation as a function of the lag
"""
scs = []
for sd in self.__spike_data:
if sd is None or len(sd) < 100:
continue
corr = self.__get_serial_correlation(sd, max_lags=max_lags)
if corr is not None:
scs.append(corr)
return scs
def circular_std(self):
circular_stds = []
for i in range(self.size):
phases = self.__spike_phases(index=i)
circular_stds.append(circstd(phases))
return circular_stds
def eod_frequency(self):
eod_frequencies = []
for i in range(self.size):
eod, time = self.eod(i)
xings = _zero_crossings(eod, time, interpolate=False)
eod_frequencies.append(len(xings)/xings[-1])
return 0.0 if len(eod_frequencies) < 1 else np.mean(eod_frequencies)
def __spike_phases(self, index=0): # fixme buffer this stuff
e_times = self.eod_times(index=index)
eod_period = np.mean(np.diff(e_times))
phases = np.zeros(len(self.spikes(index)))
for i, st in enumerate(self.spikes(index)):
last_eod_index = np.where(e_times <= st)[0]
if len(last_eod_index) == 0:
continue
phases[i] = (st - e_times[last_eod_index[-1]]) / eod_period * 2 * np.pi
return phases
    def eod_times(self, index=0, interpolate=True):
        if index >= self.size:
            return None
        if index < len(self.__eod_times):
            return self.__eod_times[index]
        eod, time = self.eod(index)
        times = _zero_crossings(eod, time, interpolate=interpolate)
        if index == len(self.__eod_times):  # cache the times when filled in order
            self.__eod_times.append(times)
        return times
@property
def dataset(self):
return self.__dataset
@property
def cell(self):
cells = self.__dataset.cells
return cells if len(cells) > 1 else cells[0]
@property
def subject(self):
subjects = self.__dataset.subjects
return subjects if len(subjects) > 1 else subjects[0]
    def spikes(self, index: int=0):
        """Get the spike times of the spikes recorded in the given baseline recording.
        Args:
            index (int, optional): If the baseline activity has been recorded several times, the index can be given. Defaults to 0.
        Returns:
            np.ndarray: the spike times, or None if the index is out of bounds.
        """
        return self.__spike_data[index] if index < len(self.__spike_data) else None
    def membrane_voltage(self, index: int=0):
        if index >= self.size:
            raise IndexError("Index %i out of bounds for size %i!" % (index, self.size))
        if not self.__dataset.has_nix:
            print('Sorry, this is not supported for non-nixed datasets. Implement it at '
                  'fishbook.reproclasses.BaselineData.membrane_voltage and send a pull request!')
            return None, None
        rp = self.__repros[index]
        data_source = os.path.join(self.__dataset.data_source, self.__dataset.id + ".nix")
        f = nix.File.open(data_source, nix.FileMode.ReadOnly)
        b = f.blocks[0]
        t = b.tags[rp.id]
        if not t:
            print("Tag not found!")
        try:
            data = t.retrieve_data("V-1")[:]
            time = np.asarray(t.references["V-1"].dimensions[0].axis(len(data)))
        except Exception:
            data = np.array([])
            time = np.array([])
        f.close()
        return time, data
    def eod(self, index: int=0):
        if index >= len(self.__eod_data):
            return None, None
        eod = self.__eod_data[index]
        time = np.arange(len(eod)) / self.__dataset.samplerate
        return eod, time
@property
def burst_index(self):
bi = []
for i, sd in enumerate(self.__spike_data):
if len(sd) < 2:
continue
et = self.eod_times(index=i)
eod_period = np.mean(np.diff(et))
isis = np.diff(sd)
bi.append(np.sum(isis < (1.5 * eod_period))/len(isis))
return bi
@property
def coefficient_of_variation(self):
cvs = []
for d in self.__spike_data:
isis = np.diff(d)
cvs.append(np.std(isis)/np.mean(isis))
return cvs
@property
def vector_strength(self):
vss = []
spike_phases = []
for i, sd in enumerate(self.__spike_data):
phases = self.__spike_phases(i)
ms_sin_alpha = np.mean(np.sin(phases)) ** 2
ms_cos_alpha = np.mean(np.cos(phases)) ** 2
vs = np.sqrt(ms_cos_alpha + ms_sin_alpha)
vss.append(vs)
spike_phases.append(phases)
return vss, spike_phases
@property
def size(self):
return len(self.__spike_data)
def __str__(self):
return "Baseline data of cell %s " % self.__cell.id
def __read_eod_data_from_nix(self, r: RePro, duration) -> np.ndarray:
data_source = os.path.join(self.__dataset.data_source, self.__dataset.id + ".nix")
if not os.path.exists(data_source):
print("Data not found! Trying from directory")
return self.__read_eod_data_from_directory(r, duration)
f = nix.File.open(data_source, nix.FileMode.ReadOnly)
b = f.blocks[0]
t = b.tags[r.id]
if not t:
print("Tag not found!")
        try:
            data = t.retrieve_data("EOD")[:]
        except Exception:
            data = np.array([])
f.close()
return data
    def __read_eod_data_from_directory(self, r: RePro, duration) -> np.ndarray:
        sr = self.__dataset.samplerate
        _unzip_if_needed(self.__dataset.data_source, "trace-2.raw")
        eod = np.fromfile(os.path.join(self.__dataset.data_source, "trace-2.raw"), np.float32)
        eod = eod[:int(duration * sr)]
        return eod
def __read_spike_data_from_nix(self, r: RePro) -> np.ndarray:
data_source = os.path.join(self.__dataset.data_source, self.__dataset.id + ".nix")
if not os.path.exists(data_source):
print("Data not found! Trying from directory")
return self.__read_spike_data_from_directory(r)
f = nix.File.open(data_source, nix.FileMode.ReadOnly)
b = f.blocks[0]
t = b.tags[r.id]
if not t:
print("Tag not found!")
        try:
            data = t.retrieve_data("Spikes-1")[:]
        except Exception:
            data = None
        f.close()
        if data is None or len(data) < 100:
            data = None
        return data
def __read_spike_data_from_directory(self, r) -> np.ndarray:
data = []
data_source = os.path.join(self.__dataset.data_source, "basespikes1.dat")
if os.path.exists(data_source):
found_run = False
with open(data_source, 'r') as f:
l = f.readline()
while l:
if "index" in l:
index = int(l.strip("#").strip().split(":")[-1])
found_run = index == r.run
if l.startswith("#Key") and found_run:
data = self.__do_read(f)
break
l = f.readline()
if len(data) < 100:
return None
return np.asarray(data)
@staticmethod
def __do_read(f) -> np.ndarray:
data = []
f.readline()
unit = f.readline().strip("#").strip()
scale = 0.001 if unit == "ms" else 1
l = f.readline()
while l and "#" not in l and len(l.strip()) > 0:
data.append(float(l.strip())*scale)
l = f.readline()
return np.asarray(data)
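
# Hedged usage sketch for BaselineData; the dataset id is taken from the
# __main__ block below and must exist in the local fishbook database.
def _baseline_data_example():
    dataset = Dataset(dataset_id="2018-09-13-ac-invivo-1")
    baseline = BaselineData(dataset)
    spike_times = baseline.spikes(0)
    cvs = baseline.coefficient_of_variation
    vs, phases = baseline.vector_strength
    return spike_times, cvs, vs
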
class BoltzmannFit:
"""
Class representing a fit of a Boltzmann function to some data.
"""
def __init__(self, xvalues: np.ndarray, yvalues: np.ndarray, initial_params=None):
"""
Constructor. Takes the x and the y data and tries to fit a Boltzmann to it.
:param xvalues: numpy array of x (e.g. contrast) values
:param yvalues: numpy array of y (e.g. firing rate) values
:param initial_params: list of initial parameters, default None to autogenerate
"""
assert(len(xvalues) == len(yvalues))
self.__xvals = xvalues
self.__yvals = yvalues
self.__fit_params = None
self.__initial_params = initial_params
self.__x_sorted = np.unique(self.__xvals)
self.__y_avg = None
self.__y_err = None
self.__do_fit()
@staticmethod
def boltzmann(x, y_max, slope, inflection):
"""
The underlying Boltzmann function.
.. math::
f(x) = y_max / \exp{-slope*(x-inflection}
:param x: The x values.
:param y_max: The maximum value.
:param slope: The slope parameter k
:param inflection: the position of the inflection point.
:return: the y values.
"""
y = y_max / (1 + np.exp(-slope * (x - inflection)))
return y
def __do_fit(self):
self.__y_avg = np.zeros(self.__x_sorted.shape)
self.__y_err = np.zeros(self.__x_sorted.shape)
for i, c in enumerate(self.__x_sorted):
self.__y_avg[i] = np.mean(self.__yvals[self.__xvals == c])
self.__y_err[i] = np.std(self.__yvals[self.__xvals == c])
if self.__initial_params:
p = self.__initial_params
else:
p = [np.max(self.__y_avg), 0, 0]
self.__fit_params, _ = curve_fit(self.boltzmann, self.__x_sorted, self.__y_avg, p)
@property
def slope(self) -> float:
r"""
The slope of the linear part of the Boltzmann, i.e.
.. math::
s = f_max $\cdot$ k / 4
:return: the slope.
"""
return self.__fit_params[0] * self.__fit_params[1] / 4
@property
def parameters(self):
""" fit parameters
:return: The fit parameters.
"""
return self.__fit_params
@property
def x_data(self):
""" The x data sorted and unique used for fitting.
:return: the x data
"""
return self.__x_sorted
@property
def y_data(self):
"""
the Y data used for fitting, i.e. the average rate in the specified time window sorted by the x data.
:return: the average and the standard deviation of the y data
"""
return self.__y_avg, self.__y_err
    def solve(self, xvalues=None):
        if xvalues is None:
            xvalues = self.__x_sorted
        return self.boltzmann(xvalues, *self.__fit_params)
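
# Hedged sketch of BoltzmannFit on synthetic data (all values hypothetical);
# the recovered parameters should be close to the generating ones.
def _boltzmann_fit_example():
    x = np.repeat(np.arange(-20., 21., 5.), 10)                  # contrasts in %
    y = 100. / (1 + np.exp(-0.2 * x)) + np.random.randn(len(x))  # noisy rates
    fit = BoltzmannFit(x, y)
    print(fit.parameters)   # approximately [100, 0.2, 0]
    print(fit.slope)        # approximately 100 * 0.2 / 4 = 5
    return fit.solve()
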
class FIData:
    """
    Class representing the data recorded with the relacs FI-Curve repro. The instance will load the data upon
    construction, which may take a while.
    FIData offers convenient access to the spike and local EOD data as well as convenience methods to get the
    firing rate and to fit a Boltzmann function to the FI curve.
    """
def __init__(self, dataset: Dataset):
"""
Constructor.
:param dataset: The dataset entity for which the fi curve repro data should be loaded.
"""
self.__spike_data = []
self.__contrasts = []
self.__eod_data = []
self.__eod_times = []
self.__dataset = dataset
self.__repros = None
self.__cell = dataset.cells[0] # Beware: Assumption that there is only a single cell
        self._get_data()
def _get_data(self):
if not self.__dataset:
return
self.__repros = RePro.find("FICurve", cell_id=self.__cell.id)
for r in self.__repros:
sd, c, eods, time = self.__read_spike_data(r)
if sd is not None and len(sd) > 1:
self.__spike_data.extend(sd)
if eods:
self.__eod_data.extend(eods)
self.__contrasts.extend(c)
if time:
self.__eod_times.extend(time)
else:
continue
    def __read_spike_data(self, repro: RePro):
        """
        :param repro: the RePro entity for which the data should be loaded.
        :return: the spike data, the respective contrasts, the local EOD data, and the time axes.
        """
if self.__dataset.has_nix:
return self.__read_spikes_from_nix(repro)
else:
return self.__read_spikes_from_directory(repro)
def __do_read_spike_data_from_nix(self, mtag: nix.pycore.MultiTag, stimulus: Stimulus, repro: RePro):
r_settings = repro.settings.split("\n")
s_settings = stimulus.settings.split("\n")
delay = 0.0
contrast = 0.0
for s in r_settings:
if "delay:" in s:
delay = float(s.split(":")[-1])
break
for s in s_settings:
if "Contrast:" in s and "PreContrast" not in s and "\t\t" not in s and "+-" not in s:
contrast = float(s.split(":")[-1])
break
start_time = stimulus.start_time - delay
end_time = stimulus.start_time + stimulus.duration
eod_da = mtag.references["LocalEOD-1"]
start_index_eod = eod_da.dimensions[0].index_of(start_time)
end_index_eod = eod_da.dimensions[0].index_of(end_time)
local_eod = eod_da[start_index_eod:end_index_eod]
spikes = self.__all_spikes[(self.__all_spikes >= start_time) & (self.__all_spikes < end_time)] - start_time - delay
time = np.asarray(eod_da.dimensions[0].axis(end_index_eod - start_index_eod)) - delay
return spikes, local_eod, time, contrast
def __read_spikes_from_nix(self, repro: RePro):
spikes = []
eods = []
time = []
contrasts = []
stimuli = Stimulus.find(cell_id=repro.cell_id, repro_id=repro.id)
if len(stimuli) == 0:
return spikes, contrasts, eods, time
data_source = os.path.join(self.__dataset.data_source, self.__dataset.id + ".nix")
if not os.path.exists(data_source):
print("Data not found! Trying from directory")
return self.__read_spikes_from_directory(repro)
f = nix.File.open(data_source, nix.FileMode.ReadOnly)
b = f.blocks[0]
self.__all_spikes = b.data_arrays["Spikes-1"][:]
mt = None
for i in tqdm(range(len(stimuli)), desc="Loading data"):
s = stimuli[i]
if not mt or mt.id != s.multi_tag_id:
mt = b.multi_tags[s.multi_tag_id]
sp, eod, t, c = self.__do_read_spike_data_from_nix(mt, s, repro)
spikes.append(sp)
eods.append(eod)
time.append(t)
contrasts.append(c)
f.close()
return spikes, contrasts, eods, time
def __do_read_data_block(self, f, l):
spikes = []
while len(l.strip()) > 0 and "#" not in l:
spikes.append(float(l.strip())/1000)
l = f.readline()
return spikes, l
def __read_spikes_from_directory(self, repro: RePro):
spikes = []
contrasts = []
times = []
print("Warning! Exact reconstruction of stimulus order not possible for old relacs files!")
data_source = os.path.join(self.__dataset.data_source, "fispikes1.dat")
delay = 0.0
pause = 0.0
for s in repro.settings.split(", "):
if "pause" in s:
s = s.split(":")[-1].strip()
pause = float(s[:-2])/1000 if "ms" in s else float(s[:-1])
if "delay" in s:
s = s.split(":")[-1].strip()
delay = float(s[:-2])/1000 if "ms" in s else float(s[:-1])
t_start = -delay
t_end = pause + delay
time = np.arange(t_start, t_end, 1./repro.dataset.samplerate)
if os.path.exists(data_source):
with open(data_source, 'r') as f:
line = f.readline()
fish_intensity = None
stim_intensity = None
while line:
line = line.strip().lower()
if "index" in line:
fish_intensity = None
stim_intensity = None
if "intensity = " in line:
if "true intensity = " in line:
fish_intensity = float(line.split("=")[-1].strip()[:-2])
elif "pre" not in line:
stim_intensity = float(line.split("=")[-1].strip()[:-2])
if len(line) > 0 and "#" not in line: # data line
sp, line = self.__do_read_data_block(f, line)
spikes.append(sp)
times.append(time)
contrasts.append((stim_intensity/fish_intensity-1)*100)
continue
line = f.readline()
return spikes, contrasts, None, times
@property
def size(self) -> int:
"""
The number of recorded trials
:return: An integer with the number of trials.
"""
return len(self.__spike_data)
def spikes(self, index=-1):
"""
The spike times recorded in the specified trial(s)
:param index: the index of the trial. Default of -1 indicates that all data should be returned.
:return:
"""
if 0 <= index < self.size:
return self.__spike_data[index]
else:
return self.__spike_data
def eod(self, index=-1):
"""
The local eod (including the stimulus) measurement of the selected trial(s).
:param index: the index of the trial. Default of -1 indicates that all data should be returned.
:return: Either two vectors representing time and the local eod or two lists of such vectors
"""
if len(self.__eod_data) == 0:
print("EOD data not available for old-style relacs data.")
return None, None
if 0 <= index < self.size:
return self.__eod_times[index], self.__eod_data[index]
else:
return self.__eod_times, self.__eod_data
def contrast(self, index=-1):
"""
The stimulus contrast used in the respective trial(s).
:param index: the index of the trial. Default of -1 indicates that all data should be returned.
:return: Either a single scalar representing the contrast, or a list of such scalars, one entry for each trial.
"""
if 0 <= index < self.size:
return self.__contrasts[index]
else:
return self.__contrasts
def time_axis(self, index=-1):
"""
Get the time axis of a single trial or a list of time-vectors for all trials.
:param index: the index of the trial. Default of -1 indicates that all data should be returned.
:return: Either a single vector representing time, or a list of such vectors, one for each trial.
"""
if 0 <= index < self.size:
return self.__eod_times[index]
else:
return self.__eod_times
def rate(self, index=0, kernel_width=0.005):
"""
Returns the firing rate for a single trial.
:param index: The index of the trial. 0 <= index < size
:param kernel_width: The width of the gaussian kernel in seconds
:return: tuple of time and rate
"""
t = self.time_axis(index)
dt = np.mean(np.diff(t))
sp = self.spikes(index)
binary = np.zeros(t.shape)
spike_indices = ((sp - t[0]) / dt).astype(int)
binary[spike_indices[(spike_indices >= 0) & (spike_indices < len(binary))]] = 1
g = gaussian_kernel(kernel_width, dt)
rate = np.convolve(binary, g, mode='same')
return t, rate
def boltzmann_fit(self, start_time=0.01, end_time=0.05, kernel_width=0.005):
"""
Extracts the average firing rate within a time window from the averaged across trial firing rate.
The analysis time window is specified by the start_time and end_time parameters. Firing rate is estimated by
convolution with a Gaussian kernel of a given width. All parameters are given in 's'.
:param start_time: the start of the analysis window.
:param end_time: the end of the analysis window.
:param kernel_width: standard deviation of the Gaussian kernel used for firing rate estimation.
:return: object of type BoltzmannFit
"""
contrasts = np.zeros(self.size)
rates = np.zeros(self.size)
for i in range(self.size):
contrasts[i] = np.round(self.contrast(i))
t, r = self.rate(i, kernel_width)
rates[i] = np.mean(r[(t >= start_time) & (t < end_time)])
boltzmann_fit = BoltzmannFit(contrasts, rates)
return boltzmann_fit
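
# Hedged usage sketch for FIData; assumes the dataset contains FICurve runs.
def _fi_data_example():
    dataset = Dataset(dataset_id="2018-09-13-ac-invivo-1")
    fi_data = FIData(dataset)
    t, r = fi_data.rate(0, kernel_width=0.005)                   # rate of trial 0
    fit = fi_data.boltzmann_fit(start_time=0.01, end_time=0.05)
    return t, r, fit.slope
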
class FileStimulusData:
def __init__(self, dataset: Dataset):
"""
Constructor.
:param dataset: The dataset entity for which the filestimulus repro data should be loaded.
"""
self.__spike_data = []
self.__contrasts = []
self.__stimuli = []
self.__dataset = dataset
self.__repros = None
self.__cell = dataset.cells[0] # Beware: Assumption that there is only a single cell
self._get_data()
def _get_data(self):
if not self.__dataset:
return
self.__repros = RePro.find("FileStimulus", cell_id=self.__cell.id)
for r in self.__repros:
sd, c, stims = self.__read_spike_data_from_nix(r) if self.__dataset.has_nix else self.__read_spike_data_from_directory(r)
if sd is not None and len(sd) > 1:
self.__spike_data.extend(sd)
self.__contrasts.extend(c)
self.__stimuli.extend(stims)
else:
continue
    def __do_read_spike_data_from_nix(self, mt: nix.pycore.MultiTag, stimulus: Stimulus, repro: RePro):
        stim_file = ""  # fixme: the stimulus file name is not read yet
        r_settings = repro.settings.split("\n")
        s_settings = stimulus.settings.split("\n")
        delay = 0.0
        for s in r_settings:
            if "delay:" in s:
                delay = float(s.split(":")[-1])
                break
        start_time = stimulus.start_time - delay
        end_time = stimulus.start_time + mt.extents[stimulus.index]
        contrast = 0.0  # this is a quick fix!!!
        for s in s_settings:
            if "Contrast:" in s and "PreContrast" not in s and "\t\t" not in s and "+-" not in s:
                contrast = float(s.split(":")[-1])
                break
        spikes = self.__all_spikes[(self.__all_spikes >= start_time) & (self.__all_spikes < end_time)] - start_time - delay
        return spikes, contrast, stim_file
def __read_spike_data_from_nix(self, repro: RePro):
spikes = []
contrasts = []
stim_files = []
stimuli = Stimulus.find(cell_id=repro.cell_id, repro_id=repro.id)
if len(stimuli) == 0:
return spikes, contrasts, stim_files
data_source = os.path.join(self.__dataset.data_source, self.__dataset.id + ".nix")
if not os.path.exists(data_source):
print("Data not found! Trying from directory")
return self.__read_spike_data_from_directory(repro)
f = nix.File.open(data_source, nix.FileMode.ReadOnly)
b = f.blocks[0]
self.__all_spikes = b.data_arrays["Spikes-1"][:]
mt = None
for i in tqdm(range(len(stimuli)), desc="Loading data"):
s = stimuli[i]
if not mt or mt.id != s.multi_tag_id:
mt = b.multi_tags[s.multi_tag_id]
sp, c, stim = self.__do_read_spike_data_from_nix(mt, s, repro)
spikes.append(sp)
contrasts.append(c)
stim_files.append(stim)
f.close()
return spikes, contrasts, stim_files
    def __read_spike_data_from_directory(self, repro: RePro):
        print("Reading FileStimulus data from old-style relacs directories is not implemented yet!")
        spikes = []
        contrasts = []
        stims = []
        return spikes, contrasts, stims
def read_stimulus(self, index=0):
pass
if __name__ == "__main__":
# dataset = Dataset(dataset_id='2011-06-14-ag')
dataset = Dataset(dataset_id="2018-09-13-ac-invivo-1")
# dataset = Dataset(dataset_id='2013-04-18-ac')
fi_curve = FileStimulusData(dataset)
embed()