adapt for ModelFit

This commit is contained in:
a.ott 2020-07-23 10:35:06 +02:00
parent 6992698323
commit 2e6a4af7fd

View File

@ -7,25 +7,21 @@ import numpy as np
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import helperFunctions as hF import helperFunctions as hF
from CellData import CellData from CellData import CellData
from ModelFit import ModelFit, get_best_fit
import os
def main(): def main():
# 2012-07-12-ag-invivo-1 fit and eod frequency: sam_analysis("results/invivo_results/2013-01-08-ad-invivo-1/")
# parameters = {'refractory_period': 0.00080122694889117, 'v_base': 0, 'v_zero': 0, 'a_zero': 20, 'step_size': 5e-05, quit()
# 'delta_a': 0.23628384937392385, 'threshold': 1, 'input_scaling': 100.66894113671654, modelfit = get_best_fit("results/invivo_results/2013-01-08-ad-invivo-1/")
# 'mem_tau': 0.012388673630113763, 'tau_a': 0.09106579031822526, 'v_offset': -6.25,
# 'noise_strength': 0.0404417932620334, 'dend_tau': 0.00122153436141022}
# cell_data = CellData("./data/2012-07-12-ag-invivo-1/")
parameters = {'delta_a': 0.08820130374685671, 'refractory_period': 0.0006, 'a_zero': 15, 'step_size': 5e-05,
'v_base': 0, 'noise_strength': 0.03622523883042496, 'v_zero': 0, 'threshold': 1,
'input_scaling': 77.75367190909581, 'tau_a': 0.07623731247799118, 'v_offset': -10.546875,
'mem_tau': 0.008741976196676469, 'dend_tau': 0.0012058986118892773}
cell_data = CellData("./data/2012-12-13-an-invivo-1/") if not os.path.exists(os.path.join(modelfit.get_cell_path(), "samallspikes1.dat")):
print("Cell: {} \n Has no measured sam stimuli.")
return
cell_data = CellData(modelfit.get_cell_path())
eod_freq = cell_data.get_eod_frequency() eod_freq = cell_data.get_eod_frequency()
model = LifacNoiseModel(parameters) model = modelfit.get_model()
# base_cell = get_baseline_class(cell_data) # base_cell = get_baseline_class(cell_data)
# base_model = get_baseline_class(model, cell_data.get_eod_frequency()) # base_model = get_baseline_class(model, cell_data.get_eod_frequency())
@ -45,6 +41,7 @@ def main():
u_durations = np.unique(durations) u_durations = np.unique(durations)
mean_duration = np.mean(durations) mean_duration = np.mean(durations)
contrasts = cell_data.get_sam_contrasts() contrasts = cell_data.get_sam_contrasts()
u_contrasts = np.unique(contrasts)
contrast = contrasts[0] # are all the same in this test case contrast = contrasts[0] # are all the same in this test case
spiketimes = cell_data.get_sam_spiketimes() spiketimes = cell_data.get_sam_spiketimes()
delta_freqs = cell_data.get_sam_delta_frequencies() delta_freqs = cell_data.get_sam_delta_frequencies()
@ -66,6 +63,7 @@ def main():
# plt.plot(prob_density_function_model) # plt.plot(prob_density_function_model)
# plt.show() # plt.show()
# plt.close() # plt.close()
fig, axes = plt.subplots(1, 4) fig, axes = plt.subplots(1, 4)
cuts = cut_pdf_into_periods(prob_density_function_model, 1/float(m_freq), step_size) cuts = cut_pdf_into_periods(prob_density_function_model, 1/float(m_freq), step_size)
for c in cuts: for c in cuts:
@ -78,12 +76,15 @@ def main():
for spikes_cell in spikes_dictionary[m_freq]: for spikes_cell in spikes_dictionary[m_freq]:
prob_density_cell = spiketimes_calculate_pdf(spikes_cell[0], step_size) prob_density_cell = spiketimes_calculate_pdf(spikes_cell[0], step_size)
if len(prob_density_cell) < 3 * (eod_freq / step_size):
continue
cuts_cell = cut_pdf_into_periods(prob_density_cell, 1/float(m_freq), step_size) cuts_cell = cut_pdf_into_periods(prob_density_cell, 1/float(m_freq), step_size)
for c in cuts_cell: for c in cuts_cell:
axes[1].plot(c, color="gray", alpha=0.15) axes[1].plot(c, color="gray", alpha=0.15)
print(cuts_cell.shape) print(cuts_cell.shape)
means_cell.append(np.mean(cuts_cell, axis=0)) means_cell.append(np.mean(cuts_cell, axis=0))
if len(means_cell) == 0:
continue
means_cell = np.array(means_cell) means_cell = np.array(means_cell)
total_mean_cell = np.mean(means_cell, axis=0) total_mean_cell = np.mean(means_cell, axis=0)
axes[1].set_title("cell") axes[1].set_title("cell")
@ -100,6 +101,108 @@ def main():
plt.close() plt.close()
def sam_analysis(fit_path):
modelfit = get_best_fit(fit_path)
if not os.path.exists(os.path.join(modelfit.get_cell_path(), "samallspikes1.dat")):
print("Cell: {} \n Has no measured sam stimuli.")
return
cell_data = CellData(modelfit.get_cell_path())
model = modelfit.get_model()
# parameters = {'delta_a': 0.08820130374685671, 'refractory_period': 0.0006, 'a_zero': 15, 'step_size': 5e-05,
# 'v_base': 0, 'noise_strength': 0.03622523883042496, 'v_zero': 0, 'threshold': 1,
# 'input_scaling': 77.75367190909581, 'tau_a': 0.07623731247799118, 'v_offset': -10.546875,
# 'mem_tau': 0.008741976196676469, 'dend_tau': 0.0012058986118892773}
# model = LifacNoiseModel(parameters)
# cell_data = CellData("./data/test_data/2012-12-13-an-invivo-1/")
eod_freq = cell_data.get_eod_frequency()
step_size = cell_data.get_sampling_interval()
durations = cell_data.get_sam_durations()
u_durations = np.unique(durations)
contrasts = cell_data.get_sam_contrasts()
u_contrasts = np.unique(contrasts)
spiketimes = cell_data.get_sam_spiketimes()
delta_freqs = cell_data.get_sam_delta_frequencies()
u_delta_freqs = np.unique(delta_freqs)
all_data = []
for mod_freq in sorted(u_delta_freqs):
# TODO problem of cutting the pdf as in some cases the pdf is shorter than 1 modulation frequency period!
if 1/mod_freq > durations[0] / 4:
print("skipped mod_freq: {}".format(mod_freq))
print("Duration: {} while mod_freq period: {:.2f}".format(durations[0], 1/mod_freq))
print("Maybe long enough duration? unique durations:", u_durations)
continue
mfreq_data = {}
cell_means = []
model_means = []
for c in u_contrasts:
mfreq_data[c] = []
for i in range(len(delta_freqs)):
if delta_freqs[i] != mod_freq:
continue
if len(spiketimes[i]) > 1:
print("There are more spiketimes in one 'point'! Only the first was used! ")
spikes = spiketimes[i][0]
cell_pdf = spiketimes_calculate_pdf(spikes, step_size)
cell_cuts = cut_pdf_into_periods(cell_pdf, 1/mod_freq, step_size, factor=1.1, use_all=True)
cell_mean = np.mean(cell_cuts, axis=0)
cell_means.append(cell_mean)
# fig, axes = plt.subplots(1, 2)
# for c in cell_cuts:
# axes[0].plot(c, color="grey", alpha=0.2)
# axes[0].plot(np.mean(cell_means, axis=0), color="black")
stimulus = SAM(eod_freq, contrasts[i] / 100, mod_freq)
v1, spikes_model = model.simulate_fast(stimulus, durations[i] * 4)
model_pdf = spiketimes_calculate_pdf(spikes_model, step_size)
model_cuts = cut_pdf_into_periods(model_pdf, 1/mod_freq, step_size, factor=1.1)
model_mean = np.mean(model_cuts, axis=0)
model_means.append(model_mean)
# for c in model_cuts:
# axes[1].plot(c, color="grey", alpha=0.2)
# axes[1].plot(np.mean(model_cuts, axis=0), color="black")
# plt.title("mod_freq: {}".format(mod_freq))
# plt.show()
# plt.close()
fig, axes = plt.subplots(1, 4)
for c in cell_means:
axes[0].plot(c, color="grey", alpha=0.2)
axes[0].plot(np.mean(cell_means, axis=0), color="black")
axis_cell = axes[0].axis()
for m in model_means:
axes[1].plot(m, color="grey", alpha=0.2)
axes[1].plot(np.mean(model_means, axis=0), color="black")
axis_model = axes[1].axis()
ylim_top = max(axis_cell[3], axis_model[3])
axes[1].set_ylim(0, ylim_top)
axes[0].set_ylim(0, ylim_top)
axes[2].plot((np.mean(model_means, axis=0) - np.mean(cell_means, axis=0)) / np.mean(model_means, axis=0))
plt.title("modulation frequency: {}".format(mod_freq))
plt.show()
plt.close()
def generate_pdf(model, stimulus, trials=4, sim_length=3, kernel_width=0.005): def generate_pdf(model, stimulus, trials=4, sim_length=3, kernel_width=0.005):
trials_rate_list = [] trials_rate_list = []
@ -135,27 +238,28 @@ def spiketimes_calculate_pdf(spikes, step_size, kernel_width=0.005):
return rate return rate
def cut_pdf_into_periods(pdf, period, step_size, factor=1.5): def cut_pdf_into_periods(pdf, period, step_size, factor=1.5, use_all=False):
if period / step_size > len(pdf):
return [pdf]
idx_period_length = int(period/float(step_size)) idx_period_length = int(period/float(step_size))
offset_per_step = period/float(step_size) - idx_period_length offset_per_step = period/float(step_size) - idx_period_length
cut_length = int(period / float(step_size) * factor) cut_length = int(period / float(step_size) * factor)
cuts = [] num_of_cuts = int(len(pdf) / (idx_period_length+offset_per_step))
num_of_cuts = int(len(pdf) / idx_period_length)
if len(pdf) - (num_of_cuts * idx_period_length + (num_of_cuts * offset_per_step)) < cut_length - idx_period_length: if len(pdf) - (num_of_cuts * idx_period_length + (num_of_cuts * offset_per_step)) < cut_length - idx_period_length:
num_of_cuts -= 1 num_of_cuts -= 1
if num_of_cuts <= 0: if num_of_cuts <= 1:
raise RuntimeError("Probability density function to short to cut.") raise RuntimeError("Probability density function to short to cut.")
cuts = np.zeros((num_of_cuts-1, cut_length))
for i in np.arange(0, num_of_cuts, 1): for i in np.arange(1, num_of_cuts, 1):
offset_correction = int(offset_per_step * i) offset_correction = int(offset_per_step * i)
start_idx = i*idx_period_length + offset_correction start_idx = i*idx_period_length + offset_correction
end_idx = (i*idx_period_length)+cut_length + offset_correction end_idx = (i*idx_period_length)+cut_length + offset_correction
cuts.append(np.array(pdf[start_idx: end_idx])) cut = np.array(pdf[start_idx: end_idx])
cuts[i-1] = cut
cuts = np.array(cuts)
if len(cuts.shape) < 2: if len(cuts.shape) < 2:
print("Fishy....") print("Fishy....")