[reproclasses] new Boltzmann fit class

This commit is contained in:
Jan Grewe 2019-10-07 14:55:13 +02:00
parent 8d253c9c7a
commit 7ab965fe37

View File

@ -4,6 +4,7 @@ import nixio as nix
from scipy.stats import circstd from scipy.stats import circstd
import os import os
import subprocess import subprocess
from tqdm import tqdm
from IPython import embed from IPython import embed
@ -38,9 +39,15 @@ def _unzip_if_needed(dataset, tracename='trace-1.raw'):
subprocess.check_call(["gunzip", os.path.join(dataset, tracename + ".gz")]) subprocess.check_call(["gunzip", os.path.join(dataset, tracename + ".gz")])
def gaussian_kernel(sigma, dt):
    """Return a Gaussian kernel sampled on the interval [-4*sigma, 4*sigma).

    The samples integrate to approximately one with respect to dt, so
    convolving a binary spike train sampled at dt with this kernel yields a
    firing rate in Hz.

    :param sigma: standard deviation of the Gaussian in seconds
    :param dt: sampling interval in seconds
    :return: 1-D numpy array of kernel values
    """
    support = np.arange(-4. * sigma, 4. * sigma, dt)
    norm = np.sqrt(2. * np.pi) * sigma
    return np.exp(-0.5 * (support / sigma) ** 2) / norm
class BaselineData: class BaselineData:
def __init__(self, dataset:Dataset): def __init__(self, dataset: Dataset):
self.__spike_data = [] self.__spike_data = []
self.__eod_data = [] self.__eod_data = []
self.__eod_times = [] self.__eod_times = []
@ -65,13 +72,13 @@ class BaselineData:
# fixme implement me! # fixme implement me!
pass pass
def __read_spike_data(self, r:RePro): def __read_spike_data(self, r: RePro):
if self.__dataset.has_nix: if self.__dataset.has_nix:
return self.__read_spike_data_from_nix(r) return self.__read_spike_data_from_nix(r)
else: else:
return self.__read_spike_data_from_directory(r) return self.__read_spike_data_from_directory(r)
def __read_eod_data(self, r:RePro, duration): def __read_eod_data(self, r: RePro, duration):
if self.__dataset.has_nix: if self.__dataset.has_nix:
return self.__read_eod_data_from_nix(r, duration) return self.__read_eod_data_from_nix(r, duration)
else: else:
@ -152,10 +159,10 @@ class BaselineData:
subjects = self.__dataset.subjects subjects = self.__dataset.subjects
return subjects if len(subjects) > 1 else subjects[0] return subjects if len(subjects) > 1 else subjects[0]
def spikes(self, index:int=0): def spikes(self, index: int=0):
return self.__spike_data[index] if len(self.__spike_data) >= index else None return self.__spike_data[index] if len(self.__spike_data) >= index else None
def eod(self, index:int=0): def eod(self, index: int=0):
eod = self.__eod_data[index] if len(self.__eod_data) >= index else None eod = self.__eod_data[index] if len(self.__eod_data) >= index else None
time = np.arange(len(eod)) / self.__dataset.samplerate time = np.arange(len(eod)) / self.__dataset.samplerate
return eod, time return eod, time
@ -200,7 +207,7 @@ class BaselineData:
def __str__(self): def __str__(self):
str = "Baseline data of cell %s " % self.__cell.id str = "Baseline data of cell %s " % self.__cell.id
def __read_eod_data_from_nix(self, r:RePro, duration)->np.ndarray: def __read_eod_data_from_nix(self, r: RePro, duration) -> np.ndarray:
data_source = os.path.join(self.__dataset.data_source, self.__dataset.id + ".nix") data_source = os.path.join(self.__dataset.data_source, self.__dataset.id + ".nix")
if not os.path.exists(data_source): if not os.path.exists(data_source):
print("Data not found! Trying from directory") print("Data not found! Trying from directory")
@ -217,14 +224,14 @@ class BaselineData:
f.close() f.close()
return data return data
def __read_eod_data_from_directory(self, r:RePro, duration)->np.ndarray: def __read_eod_data_from_directory(self, r: RePro, duration) -> np.ndarray:
sr = self.__dataset.samplerate sr = self.__dataset.samplerate
_unzip_if_needed(self.__dataset.data_source, "trace-2.raw") _unzip_if_needed(self.__dataset.data_source, "trace-2.raw")
eod = np.fromfile(self.__dataset.data_source + "/trace-2.raw", np.float32) eod = np.fromfile(self.__dataset.data_source + "/trace-2.raw", np.float32)
eod = eod[:int(duration * sr)] eod = eod[:int(duration * sr)]
return eod return eod
def __read_spike_data_from_nix(self, r:RePro)->np.ndarray: def __read_spike_data_from_nix(self, r: RePro) -> np.ndarray:
data_source = os.path.join(self.__dataset.data_source, self.__dataset.id + ".nix") data_source = os.path.join(self.__dataset.data_source, self.__dataset.id + ".nix")
if not os.path.exists(data_source): if not os.path.exists(data_source):
print("Data not found! Trying from directory") print("Data not found! Trying from directory")
@ -244,8 +251,7 @@ class BaselineData:
data = None data = None
return data return data
def __read_spike_data_from_directory(self, r) -> np.ndarray:
def __read_spike_data_from_directory(self, r)->np.ndarray:
data = [] data = []
data_source = os.path.join(self.__dataset.data_source, "basespikes1.dat") data_source = os.path.join(self.__dataset.data_source, "basespikes1.dat")
if os.path.exists(data_source): if os.path.exists(data_source):
@ -264,7 +270,8 @@ class BaselineData:
return None return None
return np.asarray(data) return np.asarray(data)
def __do_read(self, f)->np.ndarray: @staticmethod
def __do_read(f) -> np.ndarray:
data = [] data = []
f.readline() f.readline()
unit = f.readline().strip("#").strip() unit = f.readline().strip("#").strip()
@ -277,7 +284,7 @@ class BaselineData:
class FIData: class FIData:
def __init__(self, dataset:Dataset): def __init__(self, dataset: Dataset):
self.__spike_data = [] self.__spike_data = []
self.__contrasts = [] self.__contrasts = []
self.__eod_data = [] self.__eod_data = []
@ -302,7 +309,7 @@ class FIData:
else: else:
continue continue
def __read_spike_data(self, repro:RePro): def __read_spike_data(self, repro: RePro):
""" """
:param repro: :param repro:
@ -315,7 +322,7 @@ class FIData:
pass pass
return None, None, None, None return None, None, None, None
def __do_read_spike_data_from_nix(self, mtag:nix.pycore.MultiTag, stimulus:Stimulus, repro: RePro): def __do_read_spike_data_from_nix(self, mtag: nix.pycore.MultiTag, stimulus: Stimulus, repro: RePro):
r_settings = repro.settings.split("\n") r_settings = repro.settings.split("\n")
s_settings = stimulus.settings.split("\n") s_settings = stimulus.settings.split("\n")
delay = 0.0 delay = 0.0
@ -344,7 +351,7 @@ class FIData:
return spikes, local_eod, time, contrast return spikes, local_eod, time, contrast
def __read_spikes_from_nix(self, repro:RePro): def __read_spikes_from_nix(self, repro: RePro):
spikes = [] spikes = []
eods = [] eods = []
time = [] time = []
@ -360,7 +367,8 @@ class FIData:
b = f.blocks[0] b = f.blocks[0]
mt = None mt = None
for s in stimuli: for i in tqdm(range(len(stimuli))):
s = stimuli[i]
if not mt or mt.id != s.multi_tag_id: if not mt or mt.id != s.multi_tag_id:
mt = b.multi_tags[s.multi_tag_id] mt = b.multi_tags[s.multi_tag_id]
sp, eod, t, c = self.__do_read_spike_data_from_nix(mt, s, repro) sp, eod, t, c = self.__do_read_spike_data_from_nix(mt, s, repro)
@ -370,30 +378,74 @@ class FIData:
time.append(t) time.append(t)
contrasts.append(c) contrasts.append(c)
f.close() f.close()
return spikes, contrasts, eods, time return spikes, contrasts, eods, time
@property @property
def size(self): def size(self):
return len(self.__spike_data) return len(self.__spike_data)
def spikes(self, index=-1): def spikes(self, index=-1):
if 0 < index < self.size: if 0 <= index < self.size:
return self.__spike_data[index] return self.__spike_data[index]
else: else:
return self.__spike_data return self.__spike_data
def eod(self, index=-1): def eod(self, index=-1):
if 0 < index < self.size: if 0 <= index < self.size:
return self.__eod_times[index], self.__eod_data[index] return self.__eod_times[index], self.__eod_data[index]
else: else:
return self.__eod_times, self.__eod_data return self.__eod_times, self.__eod_data
def contrast(self, index=-1): def contrast(self, index=-1):
if 0 < index < self.size: if 0 <= index < self.size:
return self.__contrasts[index] return self.__contrasts[index]
else: else:
return self.__contrasts return self.__contrasts
def time_axis(self, index=-1):
    """Return the time axis of one trial, or of all trials.

    :param index: trial index; any value outside [0, size) selects all trials
    :return: the stored eod time axis for the trial (or the list of all axes)
    """
    valid = 0 <= index < self.size
    return self.__eod_times[index] if valid else self.__eod_times
def rate(self, index=0, kernel_width=0.005):
    """
    Estimate the firing rate of a single trial by kernel convolution.

    :param index: The index of the trial. 0 <= index < size
    :param kernel_width: The width of the gaussian kernel in seconds
    :return: tuple of time and rate
    """
    time = self.time_axis(index)
    step = np.mean(np.diff(time))
    spike_times = self.spikes(index)
    # binary spike train on the trial's time grid; spikes falling outside
    # the grid are discarded before assignment
    train = np.zeros(time.shape)
    indices = (spike_times / step).astype(int)
    in_range = indices[(indices >= 0) & (indices < len(train))]
    train[in_range] = 1
    kernel = gaussian_kernel(kernel_width, step)
    return time, np.convolve(train, kernel, mode='same')
def boltzmann_fit(self, start_time=0.01, end_time=0.05, kernel_width=0.005):
    """
    Collect the mean firing rate in a response window for every trial and
    fit a Boltzmann function to the resulting FI curve.

    :param start_time: start of the averaging window in seconds
    :param end_time: end of the averaging window in seconds
    :param kernel_width: standard deviation of the rate-estimation kernel in seconds
    :return: object of type BoltzmannFit
    """
    trial_count = self.size
    contrasts = np.zeros(trial_count)
    rates = np.zeros(trial_count)
    for trial in range(trial_count):
        # contrasts are rounded to whole percent before fitting
        contrasts[trial] = np.round(self.contrast(trial))
        time, trial_rate = self.rate(trial, kernel_width)
        window = (time >= start_time) & (time < end_time)
        rates[trial] = np.mean(trial_rate[window])
    return BoltzmannFit(contrasts, rates)
if __name__ == "__main__": if __name__ == "__main__":
#dataset = Dataset(dataset_id='2011-06-14-ag') #dataset = Dataset(dataset_id='2011-06-14-ag')
dataset = Dataset(dataset_id="2018-01-19-ac-invivo-1") dataset = Dataset(dataset_id="2018-01-19-ac-invivo-1")