From 2e6a4af7fd59f7a0ebd1ce691f325df8beb56c19 Mon Sep 17 00:00:00 2001 From: "a.ott" Date: Thu, 23 Jul 2020 10:35:06 +0200 Subject: [PATCH] adapt for ModelFit --- sam_experiments.py | 154 +++++++++++++++++++++++++++++++++++++-------- 1 file changed, 129 insertions(+), 25 deletions(-) diff --git a/sam_experiments.py b/sam_experiments.py index 12f0dac..2bdfcd5 100644 --- a/sam_experiments.py +++ b/sam_experiments.py @@ -7,25 +7,21 @@ import numpy as np import matplotlib.pyplot as plt import helperFunctions as hF from CellData import CellData - +from ModelFit import ModelFit, get_best_fit +import os def main(): - # 2012-07-12-ag-invivo-1 fit and eod frequency: - # parameters = {'refractory_period': 0.00080122694889117, 'v_base': 0, 'v_zero': 0, 'a_zero': 20, 'step_size': 5e-05, - # 'delta_a': 0.23628384937392385, 'threshold': 1, 'input_scaling': 100.66894113671654, - # 'mem_tau': 0.012388673630113763, 'tau_a': 0.09106579031822526, 'v_offset': -6.25, - # 'noise_strength': 0.0404417932620334, 'dend_tau': 0.00122153436141022} - # cell_data = CellData("./data/2012-07-12-ag-invivo-1/") - - parameters = {'delta_a': 0.08820130374685671, 'refractory_period': 0.0006, 'a_zero': 15, 'step_size': 5e-05, - 'v_base': 0, 'noise_strength': 0.03622523883042496, 'v_zero': 0, 'threshold': 1, - 'input_scaling': 77.75367190909581, 'tau_a': 0.07623731247799118, 'v_offset': -10.546875, - 'mem_tau': 0.008741976196676469, 'dend_tau': 0.0012058986118892773} + sam_analysis("results/invivo_results/2013-01-08-ad-invivo-1/") + quit() + modelfit = get_best_fit("results/invivo_results/2013-01-08-ad-invivo-1/") - cell_data = CellData("./data/2012-12-13-an-invivo-1/") + if not os.path.exists(os.path.join(modelfit.get_cell_path(), "samallspikes1.dat")): + print("Cell: {} \n Has no measured sam stimuli.") + return + cell_data = CellData(modelfit.get_cell_path()) eod_freq = cell_data.get_eod_frequency() - model = LifacNoiseModel(parameters) + model = modelfit.get_model() # base_cell = get_baseline_class(cell_data) # base_model = get_baseline_class(model, cell_data.get_eod_frequency()) @@ -45,6 +41,7 @@ def main(): u_durations = np.unique(durations) mean_duration = np.mean(durations) contrasts = cell_data.get_sam_contrasts() + u_contrasts = np.unique(contrasts) contrast = contrasts[0] # are all the same in this test case spiketimes = cell_data.get_sam_spiketimes() delta_freqs = cell_data.get_sam_delta_frequencies() @@ -66,6 +63,7 @@ def main(): # plt.plot(prob_density_function_model) # plt.show() # plt.close() + fig, axes = plt.subplots(1, 4) cuts = cut_pdf_into_periods(prob_density_function_model, 1/float(m_freq), step_size) for c in cuts: @@ -78,12 +76,15 @@ def main(): for spikes_cell in spikes_dictionary[m_freq]: prob_density_cell = spiketimes_calculate_pdf(spikes_cell[0], step_size) + if len(prob_density_cell) < 3 * (eod_freq / step_size): + continue cuts_cell = cut_pdf_into_periods(prob_density_cell, 1/float(m_freq), step_size) for c in cuts_cell: axes[1].plot(c, color="gray", alpha=0.15) print(cuts_cell.shape) means_cell.append(np.mean(cuts_cell, axis=0)) - + if len(means_cell) == 0: + continue means_cell = np.array(means_cell) total_mean_cell = np.mean(means_cell, axis=0) axes[1].set_title("cell") @@ -100,6 +101,108 @@ def main(): plt.close() +def sam_analysis(fit_path): + modelfit = get_best_fit(fit_path) + + if not os.path.exists(os.path.join(modelfit.get_cell_path(), "samallspikes1.dat")): + print("Cell: {} \n Has no measured sam stimuli.") + return + cell_data = CellData(modelfit.get_cell_path()) + model = modelfit.get_model() + + # parameters = {'delta_a': 0.08820130374685671, 'refractory_period': 0.0006, 'a_zero': 15, 'step_size': 5e-05, + # 'v_base': 0, 'noise_strength': 0.03622523883042496, 'v_zero': 0, 'threshold': 1, + # 'input_scaling': 77.75367190909581, 'tau_a': 0.07623731247799118, 'v_offset': -10.546875, + # 'mem_tau': 0.008741976196676469, 'dend_tau': 0.0012058986118892773} + # model = LifacNoiseModel(parameters) + # cell_data = CellData("./data/test_data/2012-12-13-an-invivo-1/") + + eod_freq = cell_data.get_eod_frequency() + step_size = cell_data.get_sampling_interval() + + durations = cell_data.get_sam_durations() + u_durations = np.unique(durations) + contrasts = cell_data.get_sam_contrasts() + u_contrasts = np.unique(contrasts) + spiketimes = cell_data.get_sam_spiketimes() + delta_freqs = cell_data.get_sam_delta_frequencies() + u_delta_freqs = np.unique(delta_freqs) + + all_data = [] + for mod_freq in sorted(u_delta_freqs): + # TODO problem of cutting the pdf as in some cases the pdf is shorter than 1 modulation frequency period! + + if 1/mod_freq > durations[0] / 4: + print("skipped mod_freq: {}".format(mod_freq)) + print("Duration: {} while mod_freq period: {:.2f}".format(durations[0], 1/mod_freq)) + print("Maybe long enough duration? unique durations:", u_durations) + continue + mfreq_data = {} + cell_means = [] + model_means = [] + for c in u_contrasts: + mfreq_data[c] = [] + + for i in range(len(delta_freqs)): + if delta_freqs[i] != mod_freq: + continue + + if len(spiketimes[i]) > 1: + print("There are more spiketimes in one 'point'! Only the first was used! ") + spikes = spiketimes[i][0] + + + cell_pdf = spiketimes_calculate_pdf(spikes, step_size) + + cell_cuts = cut_pdf_into_periods(cell_pdf, 1/mod_freq, step_size, factor=1.1, use_all=True) + cell_mean = np.mean(cell_cuts, axis=0) + cell_means.append(cell_mean) + # fig, axes = plt.subplots(1, 2) + # for c in cell_cuts: + # axes[0].plot(c, color="grey", alpha=0.2) + # axes[0].plot(np.mean(cell_means, axis=0), color="black") + + stimulus = SAM(eod_freq, contrasts[i] / 100, mod_freq) + v1, spikes_model = model.simulate_fast(stimulus, durations[i] * 4) + model_pdf = spiketimes_calculate_pdf(spikes_model, step_size) + model_cuts = cut_pdf_into_periods(model_pdf, 1/mod_freq, step_size, factor=1.1) + model_mean = np.mean(model_cuts, axis=0) + model_means.append(model_mean) + + # for c in model_cuts: + # axes[1].plot(c, color="grey", alpha=0.2) + # axes[1].plot(np.mean(model_cuts, axis=0), color="black") + # plt.title("mod_freq: {}".format(mod_freq)) + # plt.show() + # plt.close() + + + fig, axes = plt.subplots(1, 4) + for c in cell_means: + axes[0].plot(c, color="grey", alpha=0.2) + axes[0].plot(np.mean(cell_means, axis=0), color="black") + axis_cell = axes[0].axis() + + for m in model_means: + axes[1].plot(m, color="grey", alpha=0.2) + axes[1].plot(np.mean(model_means, axis=0), color="black") + axis_model = axes[1].axis() + ylim_top = max(axis_cell[3], axis_model[3]) + axes[1].set_ylim(0, ylim_top) + axes[0].set_ylim(0, ylim_top) + + axes[2].plot((np.mean(model_means, axis=0) - np.mean(cell_means, axis=0)) / np.mean(model_means, axis=0)) + + plt.title("modulation frequency: {}".format(mod_freq)) + plt.show() + plt.close() + + + + + + + def generate_pdf(model, stimulus, trials=4, sim_length=3, kernel_width=0.005): trials_rate_list = [] @@ -135,27 +238,28 @@ def spiketimes_calculate_pdf(spikes, step_size, kernel_width=0.005): return rate -def cut_pdf_into_periods(pdf, period, step_size, factor=1.5): +def cut_pdf_into_periods(pdf, period, step_size, factor=1.5, use_all=False): + + if period / step_size > len(pdf): + return [pdf] + idx_period_length = int(period/float(step_size)) offset_per_step = period/float(step_size) - idx_period_length cut_length = int(period / float(step_size) * factor) - cuts = [] - - num_of_cuts = int(len(pdf) / idx_period_length) + num_of_cuts = int(len(pdf) / (idx_period_length+offset_per_step)) if len(pdf) - (num_of_cuts * idx_period_length + (num_of_cuts * offset_per_step)) < cut_length - idx_period_length: num_of_cuts -= 1 - if num_of_cuts <= 0: + if num_of_cuts <= 1: raise RuntimeError("Probability density function to short to cut.") - - for i in np.arange(0, num_of_cuts, 1): + cuts = np.zeros((num_of_cuts-1, cut_length)) + for i in np.arange(1, num_of_cuts, 1): offset_correction = int(offset_per_step * i) start_idx = i*idx_period_length + offset_correction end_idx = (i*idx_period_length)+cut_length + offset_correction - cuts.append(np.array(pdf[start_idx: end_idx])) - - cuts = np.array(cuts) + cut = np.array(pdf[start_idx: end_idx]) + cuts[i-1] = cut if len(cuts.shape) < 2: print("Fishy....")