[reproclasses] new Boltzmann fit class

This commit is contained in:
Jan Grewe 2019-10-07 14:55:13 +02:00
parent 8d253c9c7a
commit 7ab965fe37

View File

@ -4,6 +4,7 @@ import nixio as nix
from scipy.stats import circstd
import os
import subprocess
from tqdm import tqdm
from IPython import embed
@ -38,9 +39,15 @@ def _unzip_if_needed(dataset, tracename='trace-1.raw'):
subprocess.check_call(["gunzip", os.path.join(dataset, tracename + ".gz")])
def gaussian_kernel(sigma, dt):
x = np.arange(-4. * sigma, 4. * sigma, dt)
y = np.exp(-0.5 * (x / sigma) ** 2) / np.sqrt(2. * np.pi) / sigma
return y
class BaselineData:
def __init__(self, dataset:Dataset):
def __init__(self, dataset: Dataset):
self.__spike_data = []
self.__eod_data = []
self.__eod_times = []
@ -65,13 +72,13 @@ class BaselineData:
# fixme implement me!
pass
def __read_spike_data(self, r:RePro):
def __read_spike_data(self, r: RePro):
if self.__dataset.has_nix:
return self.__read_spike_data_from_nix(r)
else:
return self.__read_spike_data_from_directory(r)
def __read_eod_data(self, r:RePro, duration):
def __read_eod_data(self, r: RePro, duration):
if self.__dataset.has_nix:
return self.__read_eod_data_from_nix(r, duration)
else:
@ -152,10 +159,10 @@ class BaselineData:
subjects = self.__dataset.subjects
return subjects if len(subjects) > 1 else subjects[0]
def spikes(self, index:int=0):
def spikes(self, index: int=0):
return self.__spike_data[index] if len(self.__spike_data) >= index else None
def eod(self, index:int=0):
def eod(self, index: int=0):
eod = self.__eod_data[index] if len(self.__eod_data) >= index else None
time = np.arange(len(eod)) / self.__dataset.samplerate
return eod, time
@ -200,7 +207,7 @@ class BaselineData:
def __str__(self):
str = "Baseline data of cell %s " % self.__cell.id
def __read_eod_data_from_nix(self, r:RePro, duration)->np.ndarray:
def __read_eod_data_from_nix(self, r: RePro, duration) -> np.ndarray:
data_source = os.path.join(self.__dataset.data_source, self.__dataset.id + ".nix")
if not os.path.exists(data_source):
print("Data not found! Trying from directory")
@ -217,14 +224,14 @@ class BaselineData:
f.close()
return data
def __read_eod_data_from_directory(self, r:RePro, duration)->np.ndarray:
def __read_eod_data_from_directory(self, r: RePro, duration) -> np.ndarray:
sr = self.__dataset.samplerate
_unzip_if_needed(self.__dataset.data_source, "trace-2.raw")
eod = np.fromfile(self.__dataset.data_source + "/trace-2.raw", np.float32)
eod = eod[:int(duration * sr)]
return eod
def __read_spike_data_from_nix(self, r:RePro)->np.ndarray:
def __read_spike_data_from_nix(self, r: RePro) -> np.ndarray:
data_source = os.path.join(self.__dataset.data_source, self.__dataset.id + ".nix")
if not os.path.exists(data_source):
print("Data not found! Trying from directory")
@ -244,8 +251,7 @@ class BaselineData:
data = None
return data
def __read_spike_data_from_directory(self, r)->np.ndarray:
def __read_spike_data_from_directory(self, r) -> np.ndarray:
data = []
data_source = os.path.join(self.__dataset.data_source, "basespikes1.dat")
if os.path.exists(data_source):
@ -264,7 +270,8 @@ class BaselineData:
return None
return np.asarray(data)
def __do_read(self, f)->np.ndarray:
@staticmethod
def __do_read(f) -> np.ndarray:
data = []
f.readline()
unit = f.readline().strip("#").strip()
@ -277,7 +284,7 @@ class BaselineData:
class FIData:
def __init__(self, dataset:Dataset):
def __init__(self, dataset: Dataset):
self.__spike_data = []
self.__contrasts = []
self.__eod_data = []
@ -302,7 +309,7 @@ class FIData:
else:
continue
def __read_spike_data(self, repro:RePro):
def __read_spike_data(self, repro: RePro):
"""
:param repro:
@ -315,7 +322,7 @@ class FIData:
pass
return None, None, None, None
def __do_read_spike_data_from_nix(self, mtag:nix.pycore.MultiTag, stimulus:Stimulus, repro: RePro):
def __do_read_spike_data_from_nix(self, mtag: nix.pycore.MultiTag, stimulus: Stimulus, repro: RePro):
r_settings = repro.settings.split("\n")
s_settings = stimulus.settings.split("\n")
delay = 0.0
@ -344,7 +351,7 @@ class FIData:
return spikes, local_eod, time, contrast
def __read_spikes_from_nix(self, repro:RePro):
def __read_spikes_from_nix(self, repro: RePro):
spikes = []
eods = []
time = []
@ -360,7 +367,8 @@ class FIData:
b = f.blocks[0]
mt = None
for s in stimuli:
for i in tqdm(range(len(stimuli))):
s = stimuli[i]
if not mt or mt.id != s.multi_tag_id:
mt = b.multi_tags[s.multi_tag_id]
sp, eod, t, c = self.__do_read_spike_data_from_nix(mt, s, repro)
@ -370,30 +378,74 @@ class FIData:
time.append(t)
contrasts.append(c)
f.close()
return spikes, contrasts, eods, time
return spikes, contrasts, eods, time
@property
def size(self):
return len(self.__spike_data)
def spikes(self, index=-1):
if 0 < index < self.size:
if 0 <= index < self.size:
return self.__spike_data[index]
else:
return self.__spike_data
def eod(self, index=-1):
if 0 < index < self.size:
if 0 <= index < self.size:
return self.__eod_times[index], self.__eod_data[index]
else:
return self.__eod_times, self.__eod_data
def contrast(self, index=-1):
if 0 < index < self.size:
if 0 <= index < self.size:
return self.__contrasts[index]
else:
return self.__contrasts
def time_axis(self, index=-1):
if 0 <= index < self.size:
return self.__eod_times[index]
else:
return self.__eod_times
def rate(self, index=0, kernel_width=0.005):
"""
Returns the firing rate for a single trial.
:param index: The index of the trial. 0 <= index < size
:param kernel_width: The width of the gaussian kernel in seconds
:return: tuple of time and rate
"""
t = self.time_axis(index)
dt = np.mean(np.diff(t))
sp = self.spikes(index)
binary = np.zeros(t.shape)
spike_indices = (sp / dt).astype(int)
binary[spike_indices[(spike_indices >= 0) & (spike_indices < len(binary))]] = 1
g = gaussian_kernel(kernel_width, dt)
rate = np.convolve(binary, g, mode='same')
return t, rate
def boltzmann_fit(self, start_time=0.01, end_time=0.05, kernel_width=0.005):
"""
:param start_time:
:param end_time:
:param kernel_width:
:return: object of type BoltzmannFit
"""
contrasts = np.zeros(self.size)
rates = np.zeros(self.size)
for i in range(self.size):
contrasts[i] = np.round(self.contrast(i))
t, r = self.rate(i, kernel_width)
rates[i] = np.mean(r[(t >= start_time) & (t < end_time)])
boltzmann_fit = BoltzmannFit(contrasts, rates)
return boltzmann_fit
if __name__ == "__main__":
#dataset = Dataset(dataset_id='2011-06-14-ag')
dataset = Dataset(dataset_id="2018-01-19-ac-invivo-1")