[reproclasses] new Boltzmann fit class

This commit is contained in:
Jan Grewe 2019-10-07 14:55:13 +02:00
parent 8d253c9c7a
commit 7ab965fe37

View File

@ -4,6 +4,7 @@ import nixio as nix
from scipy.stats import circstd
import os
import subprocess
from tqdm import tqdm
from IPython import embed
@ -38,6 +39,12 @@ def _unzip_if_needed(dataset, tracename='trace-1.raw'):
subprocess.check_call(["gunzip", os.path.join(dataset, tracename + ".gz")]) subprocess.check_call(["gunzip", os.path.join(dataset, tracename + ".gz")])
def gaussian_kernel(sigma, dt):
    """Build a discrete Gaussian kernel sampled on [-4*sigma, 4*sigma).

    The kernel carries the continuous normalization 1/(sqrt(2*pi)*sigma),
    so sum(kernel) * dt is approximately 1 — convolving a binary spike
    train (1 at spike bins) with it yields a rate in Hz.

    :param sigma: standard deviation of the Gaussian, in seconds
    :param dt: sampling step of the kernel, in seconds
    :return: 1-D numpy array with the kernel values
    """
    support = np.arange(-4. * sigma, 4. * sigma, dt)
    kernel = np.exp(-0.5 * (support / sigma) ** 2) / np.sqrt(2. * np.pi) / sigma
    return kernel
class BaselineData: class BaselineData:
def __init__(self, dataset: Dataset): def __init__(self, dataset: Dataset):
@ -244,7 +251,6 @@ class BaselineData:
data = None data = None
return data return data
def __read_spike_data_from_directory(self, r) -> np.ndarray: def __read_spike_data_from_directory(self, r) -> np.ndarray:
data = [] data = []
data_source = os.path.join(self.__dataset.data_source, "basespikes1.dat") data_source = os.path.join(self.__dataset.data_source, "basespikes1.dat")
@ -264,7 +270,8 @@ class BaselineData:
return None return None
return np.asarray(data) return np.asarray(data)
def __do_read(self, f)->np.ndarray: @staticmethod
def __do_read(f) -> np.ndarray:
data = [] data = []
f.readline() f.readline()
unit = f.readline().strip("#").strip() unit = f.readline().strip("#").strip()
@ -360,7 +367,8 @@ class FIData:
b = f.blocks[0] b = f.blocks[0]
mt = None mt = None
for s in stimuli: for i in tqdm(range(len(stimuli))):
s = stimuli[i]
if not mt or mt.id != s.multi_tag_id: if not mt or mt.id != s.multi_tag_id:
mt = b.multi_tags[s.multi_tag_id] mt = b.multi_tags[s.multi_tag_id]
sp, eod, t, c = self.__do_read_spike_data_from_nix(mt, s, repro) sp, eod, t, c = self.__do_read_spike_data_from_nix(mt, s, repro)
@ -377,23 +385,67 @@ class FIData:
return len(self.__spike_data) return len(self.__spike_data)
def spikes(self, index=-1):
    """Return the spike times of one trial, or of all trials.

    :param index: trial index; when 0 <= index < size, that single
        trial is returned, otherwise (default -1) all trials.
    :return: spike data for one trial, or the full per-trial collection
    """
    if 0 <= index < self.size:
        return self.__spike_data[index]
    return self.__spike_data
def eod(self, index=-1):
    """Return EOD time axis and EOD trace for one trial, or for all.

    :param index: trial index; when 0 <= index < size, that single
        trial is returned, otherwise (default -1) all trials.
    :return: tuple (times, data) — per-trial when index selects one
    """
    if 0 <= index < self.size:
        return self.__eod_times[index], self.__eod_data[index]
    return self.__eod_times, self.__eod_data
def contrast(self, index=-1):
    """Return the stimulus contrast of one trial, or of all trials.

    :param index: trial index; when 0 <= index < size, that single
        trial is returned, otherwise (default -1) all trials.
    :return: contrast value(s)
    """
    if 0 <= index < self.size:
        return self.__contrasts[index]
    return self.__contrasts
def time_axis(self, index=-1):
    """Return the time axis of one trial, or of all trials.

    The time axis is shared with the EOD data (same sampling grid).

    :param index: trial index; when 0 <= index < size, that single
        trial's time vector is returned, otherwise (default -1) all.
    :return: time vector(s)
    """
    if 0 <= index < self.size:
        return self.__eod_times[index]
    return self.__eod_times
def rate(self, index=0, kernel_width=0.005):
    """Estimate the firing rate of a single trial by kernel convolution.

    A binary spike vector on the trial's time axis is convolved with a
    Gaussian kernel; because the kernel integrates to ~1, the result is
    a rate in Hz.

    :param index: trial index, 0 <= index < size
    :param kernel_width: standard deviation of the Gaussian kernel, in seconds
    :return: tuple (time, rate)
    """
    time = self.time_axis(index)
    dt = np.mean(np.diff(time))
    spike_times = self.spikes(index)
    # NOTE(review): spike times are mapped to bins as spike/dt, i.e. the
    # time axis is assumed to start at 0 — confirm against the reader code.
    binary = np.zeros(time.shape)
    bins = (spike_times / dt).astype(int)
    in_range = (bins >= 0) & (bins < len(binary))
    binary[bins[in_range]] = 1
    kernel = gaussian_kernel(kernel_width, dt)
    return time, np.convolve(binary, kernel, mode='same')
def boltzmann_fit(self, start_time=0.01, end_time=0.05, kernel_width=0.005):
    """Fit a Boltzmann function to the onset firing-rate vs. contrast curve.

    For every trial the mean firing rate in the analysis window
    [start_time, end_time) is paired with the (rounded) stimulus
    contrast; the pairs are handed to BoltzmannFit.

    :param start_time: start of the analysis window, in seconds
    :param end_time: end of the analysis window (exclusive), in seconds
    :param kernel_width: Gaussian kernel width for the rate estimate, in seconds
    :return: object of type BoltzmannFit
    """
    contrasts = np.zeros(self.size)
    rates = np.zeros(self.size)
    for trial in range(self.size):
        contrasts[trial] = np.round(self.contrast(trial))
        time, r = self.rate(trial, kernel_width)
        window = (time >= start_time) & (time < end_time)
        rates[trial] = np.mean(r[window])
    return BoltzmannFit(contrasts, rates)
if __name__ == "__main__": if __name__ == "__main__":
#dataset = Dataset(dataset_id='2011-06-14-ag') #dataset = Dataset(dataset_id='2011-06-14-ag')
dataset = Dataset(dataset_id="2018-01-19-ac-invivo-1") dataset = Dataset(dataset_id="2018-01-19-ac-invivo-1")