# fishBook/fishbook/frontend/relacs_classes.py

from .frontend_classes import Dataset, RePro, Stimulus
from .util import BoltzmannFit, unzip_if_needed, gaussian_kernel, zero_crossings
import numpy as np
import nixio as nix
from scipy.stats import circstd
# from scipy.optimize import curve_fit
import os
import subprocess
from tqdm import tqdm
from IPython import embed
class BaselineData:
"""
Class representing the Baseline data that has been recorded within a given Dataset.
"""
def __init__(self, dataset: Dataset):
self.__spike_data = []
self.__eod_data = []
self.__eod_times = []
self.__dataset = dataset
self.__repros = None
self.__cell = dataset.cells[0] # Beware: Assumption that there is only a single cell
self._get_data()
def _get_data(self):
if not self.__dataset:
return
self.__repros, _ = RePro.find("BaselineActivity", cell_id=self.__cell.id)
for i in tqdm(range(len(self.__repros)), desc="loading data"):
r = self.__repros[i]
sd = self.__read_spike_data(r)
if sd is not None and len(sd) > 1:
self.__spike_data.append(sd)
else:
continue
self.__eod_data.append(self.__read_eod_data(r, self.__spike_data[-1][-1]))
def valid(self):
# fixme implement me!
pass
def __read_spike_data(self, r: RePro):
if self.__dataset.has_nix:
return self.__read_spike_data_from_nix(r)
else:
return self.__read_spike_data_from_directory(r)
def __read_eod_data(self, r: RePro, duration):
if self.__dataset.has_nix:
return self.__read_eod_data_from_nix(r, duration)
else:
return self.__read_eod_data_from_directory(r, duration)
def __get_serial_correlation(self, times, max_lags=50):
if times is None or len(times) < max_lags:
return None
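        # serial correlation: autocorrelation of the mean-subtracted interspike
        # intervals, normalized to 1 at lag zero; only non-negative lags are kept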
isis = np.diff(times)
unbiased = isis - np.mean(isis, 0)
norm = sum(unbiased ** 2)
a_corr = np.correlate(unbiased, unbiased, "same") / norm
a_corr = a_corr[int(len(a_corr) / 2):]
return a_corr[:max_lags]
def serial_correlation(self, max_lags=50):
"""
Returns the serial correlation of the interspike intervals.
@param max_lags: The number of lags to take into account
@return: the serial correlation as a function of the lag
"""
scs = []
for sd in self.__spike_data:
if sd is None or len(sd) < 100:
continue
corr = self.__get_serial_correlation(sd, max_lags=max_lags)
if corr is not None:
scs.append(corr)
return scs
def circular_std(self):
circular_stds = []
for i in range(self.size):
phases = self.__spike_phases(index=i)
circular_stds.append(circstd(phases))
return circular_stds
def eod_frequency(self):
eod_frequencies = []
for i in range(self.size):
eod, time = self.eod(i)
xings = zero_crossings(eod, time, interpolate=False)
eod_frequencies.append(len(xings)/xings[-1])
return 0.0 if len(eod_frequencies) < 1 else np.mean(eod_frequencies)
def __spike_phases(self, index=0): # fixme buffer this stuff
e_times = self.eod_times(index=index)
eod_period = np.mean(np.diff(e_times))
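        # the phase of a spike is its delay to the preceding EOD zero crossing,
        # expressed as a fraction of the average EOD period, scaled to [0, 2*pi)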
phases = np.zeros(len(self.spikes(index)))
for i, st in enumerate(self.spikes(index)):
last_eod_index = np.where(e_times <= st)[0]
if len(last_eod_index) == 0:
continue
phases[i] = (st - e_times[last_eod_index[-1]]) / eod_period * 2 * np.pi
return phases
    def eod_times(self, index=0, interpolate=True):
        if index >= self.size:
            return None
        if index >= len(self.__eod_times):
            eod, time = self.eod(index)
            times = zero_crossings(eod, time, interpolate=interpolate)
        else:
            times = self.__eod_times[index]
        return times
@property
def dataset(self):
return self.__dataset
@property
def cell(self):
cells = self.__dataset.cells
return cells if len(cells) > 1 else cells[0]
@property
def subject(self):
subjects = self.__dataset.subjects
return subjects if len(subjects) > 1 else subjects[0]
    def spikes(self, index: int=0):
        """Get the spike times of the spikes recorded in the given baseline recording.
        Args:
            index (int, optional): If the baseline activity has been recorded several times, the index of the recording can be given. Defaults to 0.
        Returns:
            np.ndarray: the spike times in seconds, or None if the index is out of bounds.
        """
        return self.__spike_data[index] if index < len(self.__spike_data) else None
def membrane_voltage(self, index: int=0):
if index >= self.size:
raise IndexError("Index %i out of bounds for size %i!" % (index, self.size))
if not self.__dataset.has_nix:
print('Sorry, this is not supported for non-nixed datasets. Implement it at '
'fishbook.reproclasses.BaselineData.membrane_voltage and send a pull request!')
return None, None
else:
rp = self.__repros[index]
data_source = os.path.join(self.__dataset.data_source, self.__dataset.id + ".nix")
f = nix.File.open(data_source, nix.FileMode.ReadOnly)
b = f.blocks[0]
t = b.tags[rp.id]
            if not t:
                print("Tag not found!")
            try:
                data = t.retrieve_data("V-1")[:]
                time = np.asarray(t.references["V-1"].dimensions[0].axis(len(data)))
            except Exception:
                data = np.empty(0)
                time = np.empty(0)
            f.close()
            return time, data
    def eod(self, index: int=0):
        if index >= len(self.__eod_data):
            return None, None
        eod = self.__eod_data[index]
        time = np.arange(len(eod)) / self.__dataset.samplerate
        return eod, time
@property
def burst_index(self):
bi = []
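        # the burst index is the fraction of interspike intervals that are
        # shorter than 1.5 EOD periods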
for i, sd in enumerate(self.__spike_data):
if len(sd) < 2:
continue
et = self.eod_times(index=i)
eod_period = np.mean(np.diff(et))
isis = np.diff(sd)
bi.append(np.sum(isis < (1.5 * eod_period))/len(isis))
return bi
@property
def coefficient_of_variation(self):
cvs = []
for d in self.__spike_data:
isis = np.diff(d)
cvs.append(np.std(isis)/np.mean(isis))
return cvs
@property
def vector_strength(self):
vss = []
spike_phases = []
for i, sd in enumerate(self.__spike_data):
phases = self.__spike_phases(i)
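            # vector strength: the length of the mean phase vector,
            # vs = sqrt(<cos(phi)>^2 + <sin(phi)>^2); 1 means perfect phase
            # locking to the EOD, 0 means no locking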
ms_sin_alpha = np.mean(np.sin(phases)) ** 2
ms_cos_alpha = np.mean(np.cos(phases)) ** 2
vs = np.sqrt(ms_cos_alpha + ms_sin_alpha)
vss.append(vs)
spike_phases.append(phases)
return vss, spike_phases
@property
def size(self):
return len(self.__spike_data)
def __str__(self):
return "Baseline data of cell %s " % self.__cell.id
def __read_eod_data_from_nix(self, r: RePro, duration) -> np.ndarray:
data_source = os.path.join(self.__dataset.data_source, self.__dataset.id + ".nix")
if not os.path.exists(data_source):
print("Data not found! Trying from directory")
return self.__read_eod_data_from_directory(r, duration)
f = nix.File.open(data_source, nix.FileMode.ReadOnly)
b = f.blocks[0]
t = b.tags[r.id]
if not t:
print("Tag not found!")
try:
data = t.retrieve_data("EOD")[:]
        except Exception:
            data = np.empty(0)
f.close()
return data
def __read_eod_data_from_directory(self, r: RePro, duration) -> np.ndarray:
sr = self.__dataset.samplerate
unzip_if_needed(self.__dataset.data_source, "trace-2.raw")
eod = np.fromfile(self.__dataset.data_source + "/trace-2.raw", np.float32)
eod = eod[:int(duration * sr)]
return eod
def __read_spike_data_from_nix(self, r: RePro) -> np.ndarray:
data_source = os.path.join(self.__dataset.data_source, self.__dataset.id + ".nix")
if not os.path.exists(data_source):
print("Data not found! Trying from directory")
return self.__read_spike_data_from_directory(r)
f = nix.File.open(data_source, nix.FileMode.ReadOnly)
b = f.blocks[0]
t = b.tags[r.id]
if not t:
print("Tag not found!")
try:
data = t.retrieve_data("Spikes-1")[:]
        except Exception:
            data = None
        f.close()
        if data is None or len(data) < 100:
            data = None
return data
def __read_spike_data_from_directory(self, r) -> np.ndarray:
data = []
data_source = os.path.join(self.__dataset.data_source, "basespikes1.dat")
if os.path.exists(data_source):
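            # relacs ASCII format: each run starts with an "#index" header and
            # the spike times follow after a "#Key" section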
found_run = False
with open(data_source, 'r') as f:
l = f.readline()
while l:
if "index" in l:
index = int(l.strip("#").strip().split(":")[-1])
found_run = index == r.run
if l.startswith("#Key") and found_run:
data = self.__do_read(f)
break
l = f.readline()
if len(data) < 100:
return None
return np.asarray(data)
@staticmethod
def __do_read(f) -> np.ndarray:
data = []
f.readline()
unit = f.readline().strip("#").strip()
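        # spike times are stored either in ms or in s; convert everything to seconds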
scale = 0.001 if unit == "ms" else 1
l = f.readline()
while l and "#" not in l and len(l.strip()) > 0:
data.append(float(l.strip())*scale)
l = f.readline()
return np.asarray(data)
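
# A minimal usage sketch for BaselineData (not part of the original analysis
# code). The dataset id is a placeholder and assumes a configured fishbook
# database; adapt it to your own recordings.
def _baseline_example(dataset_id="2018-09-13-ac-invivo-1"):
    dataset = Dataset(dataset_id=dataset_id)
    baseline = BaselineData(dataset)
    print(baseline)
    for i in range(baseline.size):
        spikes = baseline.spikes(i)
        print("trial %i: %i spikes" % (i, 0 if spikes is None else len(spikes)))
    print("coefficients of variation:", baseline.coefficient_of_variation)
    vs, _ = baseline.vector_strength
    print("vector strengths:", vs)
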
class FIData:
    """
    Class representing the data recorded with the relacs FICurve repro. The instance will load the data upon
    construction, which may take a while.
    FIData offers convenient access to the spike and local EOD data and provides convenience methods to estimate the
    firing rate and to fit a Boltzmann function to the FI curve.
    """
def __init__(self, dataset: Dataset):
"""
Constructor.
:param dataset: The dataset entity for which the fi curve repro data should be loaded.
"""
self.__spike_data = []
self.__contrasts = []
self.__eod_data = []
self.__eod_times = []
self.__dataset = dataset
self.__repros = None
self.__cell = dataset.cells[0] # Beware: Assumption that there is only a single cell
        self._get_data()
def _get_data(self):
if not self.__dataset:
return
        self.__repros, _ = RePro.find("FICurve", cell_id=self.__cell.id)
for r in self.__repros:
sd, c, eods, time = self.__read_spike_data(r)
if sd is not None and len(sd) > 1:
self.__spike_data.extend(sd)
if eods:
self.__eod_data.extend(eods)
self.__contrasts.extend(c)
if time:
self.__eod_times.extend(time)
else:
continue
    def __read_spike_data(self, repro: RePro):
        """
        :param repro: the RePro entity for which the spike data should be loaded
        :return: spike times, contrasts, local EOD traces, and time axes of the respective trials
        """
if self.__dataset.has_nix:
return self.__read_spikes_from_nix(repro)
else:
return self.__read_spikes_from_directory(repro)
def __do_read_spike_data_from_nix(self, mtag: nix.pycore.MultiTag, stimulus: Stimulus, repro: RePro):
r_settings = repro.settings.split("\n")
s_settings = stimulus.settings.split("\n")
delay = 0.0
contrast = 0.0
for s in r_settings:
if "delay:" in s:
delay = float(s.split(":")[-1])
break
for s in s_settings:
if "Contrast:" in s and "PreContrast" not in s and "\t\t" not in s and "+-" not in s:
contrast = float(s.split(":")[-1])
break
start_time = stimulus.start_time - delay
end_time = stimulus.start_time + stimulus.duration
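        # cut the local EOD and the spike times out of the continuous traces;
        # times are re-referenced such that 0 is stimulus onset and the trace
        # starts delay seconds before it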
eod_da = mtag.references["LocalEOD-1"]
start_index_eod = eod_da.dimensions[0].index_of(start_time)
end_index_eod = eod_da.dimensions[0].index_of(end_time)
local_eod = eod_da[start_index_eod:end_index_eod]
spikes = self.__all_spikes[(self.__all_spikes >= start_time) & (self.__all_spikes < end_time)] - start_time - delay
time = np.asarray(eod_da.dimensions[0].axis(end_index_eod - start_index_eod)) - delay
return spikes, local_eod, time, contrast
def __read_spikes_from_nix(self, repro: RePro):
spikes = []
eods = []
time = []
contrasts = []
stimuli, count = Stimulus.find(cell_id=repro.cell_id, repro_id=repro.id)
if count == 0:
return spikes, contrasts, eods, time
data_source = os.path.join(self.__dataset.data_source, self.__dataset.id + ".nix")
if not os.path.exists(data_source):
print("Data not found! Trying from directory")
return self.__read_spikes_from_directory(repro)
f = nix.File.open(data_source, nix.FileMode.ReadOnly)
b = f.blocks[0]
self.__all_spikes = b.data_arrays["Spikes-1"][:]
mt = None
for i in tqdm(range(count), desc="Loading data"):
s = stimuli[i]
if not mt or mt.id != s.multi_tag_id:
mt = b.multi_tags[s.multi_tag_id]
sp, eod, t, c = self.__do_read_spike_data_from_nix(mt, s, repro)
spikes.append(sp)
eods.append(eod)
time.append(t)
contrasts.append(c)
f.close()
return spikes, contrasts, eods, time
def __do_read_data_block(self, f, l):
spikes = []
while len(l.strip()) > 0 and "#" not in l:
spikes.append(float(l.strip())/1000)
l = f.readline()
return spikes, l
def __read_spikes_from_directory(self, repro: RePro):
spikes = []
contrasts = []
times = []
print("Warning! Exact reconstruction of stimulus order not possible for old relacs files!")
data_source = os.path.join(self.__dataset.data_source, "fispikes1.dat")
delay = 0.0
pause = 0.0
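        # parse the pause and delay settings (given in ms or s) from the repro
        # settings to reconstruct the trial time axis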
for s in repro.settings.split(", "):
if "pause" in s:
s = s.split(":")[-1].strip()
pause = float(s[:-2])/1000 if "ms" in s else float(s[:-1])
if "delay" in s:
s = s.split(":")[-1].strip()
delay = float(s[:-2])/1000 if "ms" in s else float(s[:-1])
t_start = -delay
t_end = pause + delay
time = np.arange(t_start, t_end, 1./repro.dataset.samplerate)
if os.path.exists(data_source):
with open(data_source, 'r') as f:
line = f.readline()
fish_intensity = None
stim_intensity = None
while line:
line = line.strip().lower()
if "index" in line:
fish_intensity = None
stim_intensity = None
if "intensity = " in line:
if "true intensity = " in line:
fish_intensity = float(line.split("=")[-1].strip()[:-2])
elif "pre" not in line:
stim_intensity = float(line.split("=")[-1].strip()[:-2])
                    if len(line) > 0 and "#" not in line:  # data line
                        sp, line = self.__do_read_data_block(f, line)
                        spikes.append(sp)
                        times.append(time)
                        if fish_intensity is not None and stim_intensity is not None:
                            contrasts.append((stim_intensity / fish_intensity - 1) * 100)
                        else:
                            contrasts.append(None)
                        continue
line = f.readline()
return spikes, contrasts, None, times
@property
def size(self) -> int:
"""
The number of recorded trials
:return: An integer with the number of trials.
"""
return len(self.__spike_data)
def spikes(self, index=-1):
"""
The spike times recorded in the specified trial(s)
:param index: the index of the trial. Default of -1 indicates that all data should be returned.
:return:
"""
if 0 <= index < self.size:
return self.__spike_data[index]
else:
return self.__spike_data
def eod(self, index=-1):
"""
The local eod (including the stimulus) measurement of the selected trial(s).
:param index: the index of the trial. Default of -1 indicates that all data should be returned.
:return: Either two vectors representing time and the local eod or two lists of such vectors
"""
if len(self.__eod_data) == 0:
print("EOD data not available for old-style relacs data.")
return None, None
if 0 <= index < self.size:
return self.__eod_times[index], self.__eod_data[index]
else:
return self.__eod_times, self.__eod_data
def contrast(self, index=-1):
"""
The stimulus contrast used in the respective trial(s).
:param index: the index of the trial. Default of -1 indicates that all data should be returned.
:return: Either a single scalar representing the contrast, or a list of such scalars, one entry for each trial.
"""
if 0 <= index < self.size:
return self.__contrasts[index]
else:
return self.__contrasts
def time_axis(self, index=-1):
"""
Get the time axis of a single trial or a list of time-vectors for all trials.
:param index: the index of the trial. Default of -1 indicates that all data should be returned.
:return: Either a single vector representing time, or a list of such vectors, one for each trial.
"""
if 0 <= index < self.size:
return self.__eod_times[index]
else:
return self.__eod_times
def rate(self, index=0, kernel_width=0.005):
"""
Returns the firing rate for a single trial.
:param index: The index of the trial. 0 <= index < size
:param kernel_width: The width of the gaussian kernel in seconds
:return: tuple of time and rate
"""
t = self.time_axis(index)
dt = np.mean(np.diff(t))
sp = self.spikes(index)
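        # convert the spike times to a binary vector on the trial's time axis and
        # convolve it with a Gaussian kernel to estimate the firing rate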
binary = np.zeros(t.shape)
spike_indices = ((sp - t[0]) / dt).astype(int)
binary[spike_indices[(spike_indices >= 0) & (spike_indices < len(binary))]] = 1
g = gaussian_kernel(kernel_width, dt)
rate = np.convolve(binary, g, mode='same')
return t, rate
    def boltzmann_fit(self, start_time=0.01, end_time=0.05, kernel_width=0.005):
        """
        Fits a Boltzmann function to the FI curve, i.e. to the firing rate as a function of stimulus contrast.
        For each trial the firing rate is estimated by convolution with a Gaussian kernel of the given width and
        averaged within the analysis time window specified by start_time and end_time. All parameters are given in
        seconds.
        :param start_time: the start of the analysis window.
        :param end_time: the end of the analysis window.
        :param kernel_width: standard deviation of the Gaussian kernel used for firing rate estimation.
        :return: object of type BoltzmannFit
        """
contrasts = np.zeros(self.size)
rates = np.zeros(self.size)
for i in range(self.size):
contrasts[i] = np.round(self.contrast(i))
t, r = self.rate(i, kernel_width)
rates[i] = np.mean(r[(t >= start_time) & (t < end_time)])
boltzmann_fit = BoltzmannFit(contrasts, rates)
return boltzmann_fit
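
# A minimal usage sketch for FIData (not part of the original analysis code):
# estimate the firing rate of the first trial and fit a Boltzmann function to
# the FI curve. The dataset id is a placeholder; the analysis window is just
# the method's default.
def _fi_curve_example(dataset_id="2018-09-13-ac-invivo-1"):
    dataset = Dataset(dataset_id=dataset_id)
    fi_data = FIData(dataset)
    if fi_data.size < 1:
        return
    time, rate = fi_data.rate(index=0, kernel_width=0.005)
    print("peak rate of first trial: %.1f Hz" % np.max(rate))
    fit = fi_data.boltzmann_fit(start_time=0.01, end_time=0.05)
    print(fit)
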
class FileStimulusData:
    """The FileStimulusData class provides access to the data recorded and the stimulus presented (if accessible)
    during runs of the FileStimulus repro. Since the FileStimulus repro can present arbitrary stimuli, this class
    does not provide any further analyses.
    As any other relacs class it is instantiated with a Dataset entity.
    """
    def __init__(self, dataset: Dataset):
        """
        Constructor.
        Args:
            dataset (fishbook.Dataset): The dataset entity for which the FileStimulus repro data should be loaded.
        """
self.__spike_data = []
self.__contrasts = []
self.__stimuli = []
self.__dataset = dataset
self.__repros = None
self.__cell = dataset.cells[0] # Beware: Assumption that there is only a single cell
self._get_data()
def _get_data(self):
if not self.__dataset:
return
self.__repros, _ = RePro.find("FileStimulus", cell_id=self.__cell.id)
for r in self.__repros:
sd, c, stims = self.__read_spike_data_from_nix(r) if self.__dataset.has_nix else self.__read_spike_data_from_directory(r)
if sd is not None and len(sd) > 1:
self.__spike_data.extend(sd)
self.__contrasts.extend(c)
self.__stimuli.extend(stims)
else:
continue
    def __do_read_spike_data_from_nix(self, mt: nix.pycore.MultiTag, stimulus: Stimulus, repro: RePro):
        contrast = 0.0
        stim_file = ""  # fixme: the stimulus file is not read out yet
        r_settings = repro.settings.split("\n")
        s_settings = stimulus.settings.split("\n")
        delay = 0.0
        for s in r_settings:
            if "delay:" in s:
                delay = float(s.split(":")[-1])
                break
        start_time = stimulus.start_time - delay
        end_time = stimulus.start_time + mt.extents[stimulus.index]
        for s in s_settings:
            if "Contrast:" in s and "PreContrast" not in s and "\t\t" not in s and "+-" not in s:
                contrast = float(s.split(":")[-1])
                break
        spikes = self.__all_spikes[(self.__all_spikes >= start_time) & (self.__all_spikes < end_time)] - start_time - delay
        return spikes, contrast, stim_file
def __read_spike_data_from_nix(self, repro: RePro):
spikes = []
contrasts = []
stim_files = []
        stimuli, count = Stimulus.find(cell_id=repro.cell_id, repro_id=repro.id)
        if count == 0:
            return spikes, contrasts, stim_files
data_source = os.path.join(self.__dataset.data_source, self.__dataset.id + ".nix")
if not os.path.exists(data_source):
print("Data not found! Trying from directory")
return self.__read_spike_data_from_directory(repro)
f = nix.File.open(data_source, nix.FileMode.ReadOnly)
b = f.blocks[0]
self.__all_spikes = b.data_arrays["Spikes-1"][:]
mt = None
        for i in tqdm(range(count), desc="Loading data"):
s = stimuli[i]
if not mt or mt.id != s.multi_tag_id:
mt = b.multi_tags[s.multi_tag_id]
sp, c, stim = self.__do_read_spike_data_from_nix(mt, s, repro)
spikes.append(sp)
contrasts.append(c)
stim_files.append(stim)
f.close()
return spikes, contrasts, stim_files
    def __read_spike_data_from_directory(self, repro: RePro):
        print("Not yet, my friend! Reading FileStimulus data from raw directories is not supported.")
        spikes = []
        contrasts = []
        stims = []
        return spikes, contrasts, stims
def read_stimulus(self, index=0):
pass
if __name__ == "__main__":
# dataset = Dataset(dataset_id='2011-06-14-ag')
dataset = Dataset(dataset_id="2018-09-13-ac-invivo-1")
# dataset = Dataset(dataset_id='2013-04-18-ac')
    file_stimulus_data = FileStimulusData(dataset)
embed()