forked from jgrewe/fishbook
[reproclasses] new Boltzmann fit class
This commit is contained in:
parent
8d253c9c7a
commit
7ab965fe37
@ -4,6 +4,7 @@ import nixio as nix
|
||||
from scipy.stats import circstd
|
||||
import os
|
||||
import subprocess
|
||||
from tqdm import tqdm
|
||||
|
||||
from IPython import embed
|
||||
|
||||
@ -38,9 +39,15 @@ def _unzip_if_needed(dataset, tracename='trace-1.raw'):
|
||||
subprocess.check_call(["gunzip", os.path.join(dataset, tracename + ".gz")])
|
||||
|
||||
|
||||
def gaussian_kernel(sigma, dt):
    """Sample a normalized Gaussian kernel for firing-rate estimation.

    The kernel is evaluated on a grid spanning +/- 4 standard deviations.

    :param sigma: standard deviation of the Gaussian in seconds
    :param dt: sampling interval of the grid in seconds
    :return: 1-D numpy array with the sampled Gaussian density
    """
    support = np.arange(-4. * sigma, 4. * sigma, dt)
    kernel = np.exp(-0.5 * (support / sigma) ** 2) / np.sqrt(2. * np.pi) / sigma
    return kernel
|
||||
|
||||
|
||||
class BaselineData:
|
||||
|
||||
def __init__(self, dataset:Dataset):
|
||||
def __init__(self, dataset: Dataset):
|
||||
self.__spike_data = []
|
||||
self.__eod_data = []
|
||||
self.__eod_times = []
|
||||
@ -65,13 +72,13 @@ class BaselineData:
|
||||
# fixme implement me!
|
||||
pass
|
||||
|
||||
def __read_spike_data(self, r: RePro):
    """Load baseline spike times for the given RePro run.

    Dispatches to the nix-file reader when the dataset ships a nix file,
    otherwise falls back to reading from the raw data directory.

    :param r: the RePro run to load spikes for
    :return: whatever the selected reader returns (spike-time array or None)
    """
    if not self.__dataset.has_nix:
        return self.__read_spike_data_from_directory(r)
    return self.__read_spike_data_from_nix(r)
|
||||
|
||||
def __read_eod_data(self, r:RePro, duration):
|
||||
def __read_eod_data(self, r: RePro, duration):
|
||||
if self.__dataset.has_nix:
|
||||
return self.__read_eod_data_from_nix(r, duration)
|
||||
else:
|
||||
@ -152,10 +159,10 @@ class BaselineData:
|
||||
subjects = self.__dataset.subjects
|
||||
return subjects if len(subjects) > 1 else subjects[0]
|
||||
|
||||
def spikes(self, index: int = 0):
    """Return the spike times of trial `index`.

    Bug fix: the old guard (`len(...) >= index`) admitted `index == len`,
    which raised IndexError, and let negative indices wrap around silently.
    The guard now checks the valid half-open range explicitly.

    :param index: trial index, 0 <= index < number of trials
    :return: spike-time array of the trial, or None if index is out of range
    """
    if 0 <= index < len(self.__spike_data):
        return self.__spike_data[index]
    return None
|
||||
|
||||
def eod(self, index: int = 0):
    """Return the EOD trace and matching time axis of trial `index`.

    Bug fix: the old guard (`len(...) >= index`) admitted `index == len`
    (IndexError) and, when it yielded None, the subsequent
    `np.arange(len(None))` raised TypeError. Out-of-range indices now
    return (None, None) instead of crashing.

    :param index: trial index, 0 <= index < number of trials
    :return: tuple (eod, time); (None, None) if index is out of range
    """
    if 0 <= index < len(self.__eod_data):
        eod = self.__eod_data[index]
        time = np.arange(len(eod)) / self.__dataset.samplerate
        return eod, time
    return None, None
|
||||
@ -200,7 +207,7 @@ class BaselineData:
|
||||
def __str__(self):
    """Human-readable description of this baseline recording.

    Bug fix: the old body assigned the message to a local named `str`
    (shadowing the builtin) and never returned it, so `str(obj)` raised
    "TypeError: __str__ returned non-string (type NoneType)".
    """
    return "Baseline data of cell %s " % self.__cell.id
|
||||
|
||||
def __read_eod_data_from_nix(self, r:RePro, duration)->np.ndarray:
|
||||
def __read_eod_data_from_nix(self, r: RePro, duration) -> np.ndarray:
|
||||
data_source = os.path.join(self.__dataset.data_source, self.__dataset.id + ".nix")
|
||||
if not os.path.exists(data_source):
|
||||
print("Data not found! Trying from directory")
|
||||
@ -217,14 +224,14 @@ class BaselineData:
|
||||
f.close()
|
||||
return data
|
||||
|
||||
def __read_eod_data_from_directory(self, r: RePro, duration) -> np.ndarray:
    """Read the raw EOD trace from the dataset directory.

    Unzips trace-2.raw if only a gzipped copy is present, loads it as
    float32 samples and truncates it to the requested duration.

    :param r: the RePro run (unused here, kept for a uniform reader signature)
    :param duration: length of the trace to keep, in seconds
    :return: 1-D float32 array with the EOD trace
    """
    _unzip_if_needed(self.__dataset.data_source, "trace-2.raw")
    trace_file = self.__dataset.data_source + "/trace-2.raw"
    wanted_samples = int(duration * self.__dataset.samplerate)
    return np.fromfile(trace_file, np.float32)[:wanted_samples]
|
||||
|
||||
def __read_spike_data_from_nix(self, r:RePro)->np.ndarray:
|
||||
def __read_spike_data_from_nix(self, r: RePro) -> np.ndarray:
|
||||
data_source = os.path.join(self.__dataset.data_source, self.__dataset.id + ".nix")
|
||||
if not os.path.exists(data_source):
|
||||
print("Data not found! Trying from directory")
|
||||
@ -244,8 +251,7 @@ class BaselineData:
|
||||
data = None
|
||||
return data
|
||||
|
||||
|
||||
def __read_spike_data_from_directory(self, r)->np.ndarray:
|
||||
def __read_spike_data_from_directory(self, r) -> np.ndarray:
|
||||
data = []
|
||||
data_source = os.path.join(self.__dataset.data_source, "basespikes1.dat")
|
||||
if os.path.exists(data_source):
|
||||
@ -264,7 +270,8 @@ class BaselineData:
|
||||
return None
|
||||
return np.asarray(data)
|
||||
|
||||
def __do_read(self, f)->np.ndarray:
|
||||
@staticmethod
|
||||
def __do_read(f) -> np.ndarray:
|
||||
data = []
|
||||
f.readline()
|
||||
unit = f.readline().strip("#").strip()
|
||||
@ -277,7 +284,7 @@ class BaselineData:
|
||||
|
||||
|
||||
class FIData:
|
||||
def __init__(self, dataset:Dataset):
|
||||
def __init__(self, dataset: Dataset):
|
||||
self.__spike_data = []
|
||||
self.__contrasts = []
|
||||
self.__eod_data = []
|
||||
@ -302,7 +309,7 @@ class FIData:
|
||||
else:
|
||||
continue
|
||||
|
||||
def __read_spike_data(self, repro:RePro):
|
||||
def __read_spike_data(self, repro: RePro):
|
||||
"""
|
||||
|
||||
:param repro:
|
||||
@ -315,7 +322,7 @@ class FIData:
|
||||
pass
|
||||
return None, None, None, None
|
||||
|
||||
def __do_read_spike_data_from_nix(self, mtag:nix.pycore.MultiTag, stimulus:Stimulus, repro: RePro):
|
||||
def __do_read_spike_data_from_nix(self, mtag: nix.pycore.MultiTag, stimulus: Stimulus, repro: RePro):
|
||||
r_settings = repro.settings.split("\n")
|
||||
s_settings = stimulus.settings.split("\n")
|
||||
delay = 0.0
|
||||
@ -344,7 +351,7 @@ class FIData:
|
||||
|
||||
return spikes, local_eod, time, contrast
|
||||
|
||||
def __read_spikes_from_nix(self, repro:RePro):
|
||||
def __read_spikes_from_nix(self, repro: RePro):
|
||||
spikes = []
|
||||
eods = []
|
||||
time = []
|
||||
@ -360,7 +367,8 @@ class FIData:
|
||||
b = f.blocks[0]
|
||||
|
||||
mt = None
|
||||
for s in stimuli:
|
||||
for i in tqdm(range(len(stimuli))):
|
||||
s = stimuli[i]
|
||||
if not mt or mt.id != s.multi_tag_id:
|
||||
mt = b.multi_tags[s.multi_tag_id]
|
||||
sp, eod, t, c = self.__do_read_spike_data_from_nix(mt, s, repro)
|
||||
@ -370,30 +378,74 @@ class FIData:
|
||||
time.append(t)
|
||||
contrasts.append(c)
|
||||
f.close()
|
||||
return spikes, contrasts, eods, time
|
||||
return spikes, contrasts, eods, time
|
||||
|
||||
@property
def size(self):
    """Number of recorded trials in this FI dataset."""
    return len(self.__spike_data)
||||
|
||||
def spikes(self, index=-1):
    """Spike times of one trial, or of all trials.

    :param index: trial index; any out-of-range value (the default -1)
                  selects all trials
    :return: spike times of the trial, or the list of all trials
    """
    trials = self.__spike_data
    return trials[index] if 0 <= index < self.size else trials
|
||||
|
||||
def eod(self, index=-1):
    """EOD time axis and trace of one trial, or of all trials.

    :param index: trial index; any out-of-range value (the default -1)
                  selects all trials
    :return: tuple (times, data) for the trial, or the full lists
    """
    if not (0 <= index < self.size):
        return self.__eod_times, self.__eod_data
    return self.__eod_times[index], self.__eod_data[index]
|
||||
|
||||
def contrast(self, index=-1):
    """Stimulus contrast of one trial, or of all trials.

    :param index: trial index; any out-of-range value (the default -1)
                  selects all trials
    :return: the trial's contrast, or the list of all contrasts
    """
    in_range = 0 <= index < self.size
    return self.__contrasts[index] if in_range else self.__contrasts
|
||||
|
||||
def time_axis(self, index=-1):
    """Time axis of one trial, or the time axes of all trials.

    :param index: trial index; any out-of-range value (the default -1)
                  selects all trials
    :return: the trial's time axis, or the list of all time axes
    """
    if 0 <= index < self.size:
        return self.__eod_times[index]
    return self.__eod_times
|
||||
|
||||
def rate(self, index=0, kernel_width=0.005):
    """Estimate the firing rate of a single trial by kernel convolution.

    Spikes are binned onto the trial's time axis as a binary train and
    convolved with a Gaussian kernel of the given width.

    :param index: The index of the trial. 0 <= index < size
    :param kernel_width: The width of the gaussian kernel in seconds
    :return: tuple of time and rate
    """
    time = self.time_axis(index)
    step = np.mean(np.diff(time))  # sampling interval of the time axis
    spike_times = self.spikes(index)
    binary = np.zeros(time.shape)
    # convert spike times to sample indices, dropping any that fall
    # outside the trial's time axis
    indices = (spike_times / step).astype(int)
    valid = indices[(indices >= 0) & (indices < len(binary))]
    binary[valid] = 1
    kernel = gaussian_kernel(kernel_width, step)
    firing_rate = np.convolve(binary, kernel, mode='same')
    return time, firing_rate
||||
|
||||
def boltzmann_fit(self, start_time=0.01, end_time=0.05, kernel_width=0.005):
    """Fit a Boltzmann function to the onset FI curve.

    For every trial the firing rate is averaged in the analysis window
    [start_time, end_time) and paired with the trial's rounded contrast;
    the resulting contrast/rate pairs are handed to BoltzmannFit.

    :param start_time: start of the rate-averaging window in seconds
    :param end_time: end of the rate-averaging window in seconds
    :param kernel_width: width of the Gaussian kernel for rate estimation
    :return: object of type BoltzmannFit
    """
    contrasts = np.zeros(self.size)
    mean_rates = np.zeros(self.size)
    for trial in range(self.size):
        contrasts[trial] = np.round(self.contrast(trial))
        time, trial_rate = self.rate(trial, kernel_width)
        window = (time >= start_time) & (time < end_time)
        mean_rates[trial] = np.mean(trial_rate[window])
    return BoltzmannFit(contrasts, mean_rates)
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
#dataset = Dataset(dataset_id='2011-06-14-ag')
|
||||
dataset = Dataset(dataset_id="2018-01-19-ac-invivo-1")
|
||||
|
Loading…
Reference in New Issue
Block a user