from stimuli.SinusAmplitudeModulation import SinusAmplitudeModulationStimulus as SAM
from models.LIFACnoise import LifacNoiseModel
import numpy as np
import matplotlib.pyplot as plt
from my_util import helperFunctions as hF
from parser.CellData import CellData
from fitting.ModelFit import get_best_fit
import os


def main():
    """Entry point: run the SAM analysis for every fitted cell in results/final_sam."""
    run_sam_analysis_for_all_cells("results/final_sam")

    # Alternative analyses, enable as needed:
    # sam_analysis("results/final_2/2011-10-25-ad-invivo-1/")
    # plot_traces_with_spiketimes()
    # plot_mean_of_cuts()
    # run_model_response_test()


def run_model_response_test(fit_path="results/final_2/2011-10-25-ad-invivo-1/"):
    """Run test_model_response() with the best fit found in ``fit_path``.

    Uses a contrast of 0.1 and modulation frequencies 5, 10, ..., 2495.
    """
    modelfit = get_best_fit(fit_path)
    cell_data = CellData(modelfit.get_cell_path())
    eod_freq = cell_data.get_eod_frequency()
    model = modelfit.get_model()
    test_model_response(model, eod_freq, 0.1, np.arange(5, 2500, 5))


def run_sam_analysis_for_all_cells(folder):
    """Run sam_analysis() on every sub-folder of ``folder``; print how many were analysed."""
    count = 0
    for item in sorted(os.listdir(folder)):
        cell_folder = os.path.join(folder, item)
        if not os.path.isdir(cell_folder):
            # plain files in the results folder are not cell fit folders
            continue
        # Optionally restrict the analysis to cells with SAM recordings:
        # fit = get_best_fit(cell_folder, use_comparable_error=False)
        # if not fit.get_cell_data().has_sam_recordings():
        #     continue
        # print("Fit quality:", fit.get_fit_routine_error())
        sam_analysis(cell_folder)
        count += 1
    print(count)


def test_model_response(model: LifacNoiseModel, eod_freq, contrast, modulation_frequencies):
    """Simulate the model for each modulation frequency and plot the rate modulation.

    For every modulation frequency the model is simulated for 30 (time units of
    model.simulate, presumably seconds) with a SAM stimulus of ``contrast / 100``
    (so ``contrast`` is presumably given in percent - TODO confirm). The smoothed
    rate is saved to figures/sam/pdf_mfreq_<m_freq>.png; finally the std of each
    rate is plotted against ``modulation_frequencies / eod_freq``.

    NOTE: mutates ``model.parameters["step_size"]`` so that there are at least
    10 simulation steps per modulation period; the reduction persists after
    the call.
    """
    stds = []
    for m_freq in modulation_frequencies:
        max_step_size = (1 / m_freq) / 10
        if max_step_size <= model.parameters["step_size"]:
            model.parameters["step_size"] = max_step_size
        step_size = model.parameters["step_size"]
        print("mode_freq:", m_freq, "- step size:", step_size)

        stimulus = SAM(eod_freq, contrast / 100, m_freq)
        duration = 30
        _, spikes_model = model.simulate(stimulus, duration)
        prob_density_function_model = spiketimes_calculate_pdf(spikes_model, step_size, kernel_width=0.005)

        fig, ax = plt.subplots(1, 1)
        ax.plot(prob_density_function_model)
        ax.set_title("pdf with m_freq: {}".format(int(m_freq)))
        plt.savefig("figures/sam/pdf_mfreq_{}.png".format(m_freq))
        plt.close()
        stds.append(np.std(prob_density_function_model))

    plt.plot(np.array(modulation_frequencies) / eod_freq, stds)
    plt.show()
    plt.close()


def plot_traces_with_spiketimes():
    """Show each recorded SAM trace of one example cell with its spike times marked."""
    modelfit = get_best_fit("results/final_2/2011-10-25-ad-invivo-1/")
    cell_data = modelfit.get_cell_data()
    # [time_traces, v1_traces, eod_traces, local_eod_traces, stimulus_traces]
    traces = cell_data.parser.__get_traces__("SAM")
    sam_spiketimes = cell_data.get_sam_spiketimes()

    for i in range(len(traces[0])):
        fig, axes = plt.subplots(2, 1, sharex=True)
        axes[0].plot(traces[0][i], traces[1][i])
        # mark spikes at the height of the trace maximum
        axes[0].plot(list(sam_spiketimes[i]), [max(traces[1][i])] * len(sam_spiketimes[i]), 'o')
        axes[1].plot(traces[0][i], traces[3][i])
        plt.show()
        plt.close()


def plot_mean_of_cuts():
    """Show period-averaged model vs. cell SAM responses of one example cell.

    For each recorded modulation frequency: model rate cut into periods (panel
    0), cell rates cut into periods (panel 1), the difference of the means
    (panel 2) and both means overlaid (panel 3).
    """
    modelfit = get_best_fit("results/final_2/2018-05-08-ac-invivo-1/")
    if not os.path.exists(os.path.join(modelfit.get_cell_path(), "samallspikes1.dat")):
        print("Cell: {} \n Has no measured sam stimuli.".format(modelfit.get_cell_path()))
        return
    cell_data = CellData(modelfit.get_cell_path())
    eod_freq = cell_data.get_eod_frequency()
    model = modelfit.get_model()

    mean_duration = np.mean(cell_data.get_sam_durations())
    contrast = cell_data.get_sam_contrasts()[0]  # are all the same in this test case
    spiketimes = cell_data.get_sam_spiketimes()
    delta_freqs = cell_data.get_sam_delta_frequencies()
    step_size = cell_data.get_sampling_interval()

    # group the recorded trials by modulation frequency
    spikes_dictionary = {}
    for m_freq, trial_spikes in zip(delta_freqs, spiketimes):
        spikes_dictionary.setdefault(m_freq, []).append(trial_spikes)

    for m_freq in sorted(spikes_dictionary.keys()):
        period = 1 / float(m_freq)
        if mean_duration < 2 * period:
            print("meep")
            continue

        stimulus = SAM(eod_freq, contrast / 100, m_freq)
        _, spikes_model = model.simulate(stimulus, 4)
        prob_density_function_model = spiketimes_calculate_pdf(spikes_model, step_size)
        # ignore the first 2 time units of the simulation (presumably the onset transient)
        start_idx = int(2 / step_size)
        cuts = cut_pdf_into_periods(prob_density_function_model[start_idx:], period, step_size)
        if len(cuts) == 0:
            print("Model response too short for mod freq {}".format(m_freq))
            continue

        fig, axes = plt.subplots(1, 4)
        for c in cuts:
            axes[0].plot(c, color="gray", alpha=0.2)
        axes[0].set_title("model")
        mean_model = np.mean(cuts, axis=0)
        axes[0].plot(mean_model, color="black")

        means_cell = []
        for spikes_cell in spikes_dictionary[m_freq]:
            if len(spikes_cell) == 0:
                continue
            prob_density_cell = spiketimes_calculate_pdf(spikes_cell[0], step_size)
            cuts_cell = cut_pdf_into_periods(prob_density_cell, period, step_size)
            if len(cuts_cell) == 0:
                continue
            for c in cuts_cell:
                axes[1].plot(c, color="gray", alpha=0.15)

            # cut_pdf_into_periods may return a list, so use np.shape
            print(np.shape(cuts_cell))
            means_cell.append(np.mean(cuts_cell, axis=0))

        if len(means_cell) == 0:
            print("means cell length zero")
            plt.close(fig)
            continue

        # cell and model means may differ by a few samples; align them
        means_cell, (mean_model,) = _trim_to_common_length(means_cell, [mean_model])
        total_mean_cell = np.mean(np.array(means_cell), axis=0)
        axes[1].set_title("cell")
        axes[1].plot(total_mean_cell, color="black")

        axes[2].set_title("difference")
        axes[2].plot(total_mean_cell - mean_model)

        axes[3].plot(total_mean_cell)
        axes[3].plot(mean_model)
        plt.show()
        plt.close()


def _trim_to_common_length(*groups):
    """Truncate every array in every group to the shortest array length across all groups.

    Returns one list per input group. Every group must contain at least one array.
    """
    min_length = min(len(arr) for group in groups for arr in group)
    return [[arr[:min_length] for arr in group] for group in groups]


def _period_means_for_trial(model, eod_freq, contrast, mod_freq, spikes, step_size, sim_duration):
    """Return (cell_mean, model_mean) period-averaged rates for one SAM trial, or None.

    ``spikes`` are the cell's spike times of the trial (same time unit as
    ``step_size``). The model is simulated with the matching SAM stimulus
    (contrast given as ``contrast / 100``) for ``sim_duration``. None is returned
    when either rate trace is too short to yield a single period cut, so no NaN
    means reach the callers.
    """
    period = 1 / mod_freq
    cell_pdf = spiketimes_calculate_pdf(spikes, step_size)
    cell_cuts = cut_pdf_into_periods(cell_pdf, period, step_size)
    if len(cell_cuts) == 0:
        return None

    stimulus = SAM(eod_freq, contrast / 100, mod_freq)
    _, spikes_model = model.simulate(stimulus, sim_duration)
    model_pdf = spiketimes_calculate_pdf(spikes_model, step_size)
    model_cuts = cut_pdf_into_periods(model_pdf, period, step_size)
    if len(model_cuts) == 0:
        return None

    return np.mean(cell_cuts, axis=0), np.mean(model_cuts, axis=0)


def _plot_mod_freq_comparison(mod_freq, cell_means, model_means, final_cell_mean, final_model_mean,
                              final_model_mean_phase_corrected):
    """Show per-trial/averaged cell and model period responses and their relative/absolute errors."""
    fig, axes = plt.subplots(1, 5, figsize=(15, 5), sharex=True)
    for c in cell_means:
        axes[0].plot(c, color="grey", alpha=0.2)
    axes[0].plot(final_cell_mean, color="black")
    axes[0].set_title("Cell response")
    axis_cell = axes[0].axis()

    for m in model_means:
        axes[1].plot(m, color="grey", alpha=0.2)
    axes[1].plot(final_model_mean, color="black")
    axes[1].set_title("Model response")
    axis_model = axes[1].axis()

    # share the y-range of the three rate panels
    ylim_top = max(axis_cell[3], axis_model[3])
    axes[0].set_ylim(0, ylim_top)
    axes[1].set_ylim(0, ylim_top)
    axes[2].set_ylim(0, ylim_top)

    axes[2].plot(final_cell_mean, label="cell")
    axes[2].plot(final_model_mean, label="model")
    axes[2].plot(final_model_mean_phase_corrected, label="model p-cor")
    axes[2].legend()
    axes[2].set_title("cell-model overlapped")

    axes[3].plot((final_model_mean - final_cell_mean) / final_cell_mean, label="normal")
    axes[3].plot((final_model_mean_phase_corrected - final_cell_mean) / final_cell_mean, label="phase cor")
    axes[3].set_title("rel. error")
    axes[3].legend()

    axes[4].plot(final_model_mean - final_cell_mean, label="normal")
    axes[4].plot(final_model_mean_phase_corrected - final_cell_mean, label="phase cor")
    axes[4].set_title("abs. error (Hz)")
    axes[4].legend()

    fig.suptitle("modulation frequency: {}".format(mod_freq))
    plt.show()
    plt.close()


def sam_analysis(fit_path, plot_every_mod_freq=False):
    """Compare the cell's and the model's period-averaged responses to SAM stimuli.

    For every recorded modulation frequency, each trial's cell spike train and a
    10 (time unit of model.simulate) model simulation of the same stimulus are
    converted to rates, cut into modulation periods and averaged. The std of the
    averaged period ("response modulation depth") is plotted against modulation
    frequency for cell and model and saved to figures/sam/<cell name>.png.
    Modulation frequencies without any usable trial are skipped.

    Parameters
    ----------
    fit_path : str
        Folder with the model fit(s) of one cell (passed to get_best_fit).
    plot_every_mod_freq : bool, optional
        If True, also show a detailed cell/model comparison figure per
        modulation frequency (blocks on plt.show()).
    """
    modelfit = get_best_fit(fit_path)

    # fits stored under "final_sam" read their cell data from "final"
    cell_data_path = modelfit.get_cell_path()
    if "final_sam" in cell_data_path:
        cell_data_path = cell_data_path.replace("final_sam", "final")
    cell_data = CellData(cell_data_path)
    model = modelfit.get_model()

    eod_freq = cell_data.get_eod_frequency()
    step_size = cell_data.get_sampling_interval()
    contrasts = cell_data.get_sam_contrasts()
    spiketimes = cell_data.get_sam_spiketimes()
    delta_freqs = cell_data.get_sam_delta_frequencies()

    approx_offset = approximate_axon_delay_in_idx(cell_data, model)
    print("Approx offset idx:", approx_offset)
    print("Approx offset ms:", (approx_offset * step_size) * 1000)

    used_mod_freqs = []
    cell_stds = []
    model_stds = []
    for mod_freq in sorted(np.unique(delta_freqs)):
        # TODO: some recordings are shorter than one modulation period, so their
        #  pdf cannot be cut into periods; such trials are skipped.
        cell_means = []
        model_means = []
        for i in range(len(delta_freqs)):
            if delta_freqs[i] != mod_freq:
                continue
            if len(spiketimes[i]) == 0:
                print("No spiketimes found at index!")
                continue
            if len(spiketimes[i]) > 1:
                print("There are more spiketimes in one 'point'! Only the first was used! ")

            means = _period_means_for_trial(model, eod_freq, contrasts[i], mod_freq,
                                            spiketimes[i][0], step_size, sim_duration=10)
            if means is None:
                continue
            cell_means.append(means[0])
            model_means.append(means[1])

        if len(cell_means) == 0:
            print("skipped mod_freq: {} (no usable trials)".format(mod_freq))
            continue

        cell_means, model_means = _trim_to_common_length(cell_means, model_means)
        final_cell_mean = np.mean(cell_means, axis=0)
        final_model_mean = np.mean(model_means, axis=0)

        used_mod_freqs.append(mod_freq)
        cell_stds.append(np.std(final_cell_mean))
        model_stds.append(np.std(final_model_mean))

        if plot_every_mod_freq:
            # per-frequency alternative: correct_phase(final_cell_mean, final_model_mean, step_size)
            final_model_mean_phase_corrected = np.roll(final_model_mean, approx_offset)
            _plot_mod_freq_comparison(mod_freq, cell_means, model_means, final_cell_mean,
                                      final_model_mean, final_model_mean_phase_corrected)

    fig, ax = plt.subplots(1, 1)
    # plot against the frequencies actually analysed so x and y stay aligned
    ax.plot(used_mod_freqs, cell_stds, label="cell stds")
    ax.plot(used_mod_freqs, model_stds, label="model stds")
    ax.set_title("response modulation depth")
    ax.set_xlabel("Modulation frequency")
    ax.set_ylabel("STD")
    ax.legend()
    plt.savefig("figures/sam/" + cell_data.get_cell_name() + ".png")
    # plt.show()
    plt.close()


def correct_phase(cell_mean, model_mean, step_size):
    """Find the circular shift of ``model_mean`` that best matches ``cell_mean``.

    Shifts of i * 0.2 ms (assuming step_size is in seconds) are tried over the
    duration of ``cell_mean``; the shift with the smallest summed absolute
    difference wins (the first one on ties).

    Returns
    -------
    (int, numpy.ndarray)
        The best shift in samples and ``np.roll(model_mean, shift)``.
    """
    lowest_err = np.inf
    roll_idx = 0
    n_candidates = int(len(cell_mean) * step_size * 1000) * 5
    for i in range(n_candidates):
        roll_by = int((i / 5 / 1000) / step_size)
        rolled = np.roll(model_mean, roll_by)
        abs_err = np.sum(np.abs(cell_mean - rolled))
        if abs_err < lowest_err:
            lowest_err = abs_err
            roll_idx = roll_by

    return roll_idx, np.roll(model_mean, roll_idx)


def approximate_axon_delay_in_idx(cell_data, model, lowest_mod_freq=80, highest_mod_freq=150):
    """Estimate the cell-vs-model response delay in samples.

    For modulation frequencies in [lowest_mod_freq, highest_mod_freq] whose
    period is at most a quarter of the first SAM duration, cell and model
    (simulated for 4x the trial duration) period-averaged rates are aligned with
    correct_phase(); the mean of the resulting shifts is returned, rounded to an
    int. Returns 0 if no modulation frequency could be used.
    """
    eod_freq = cell_data.get_eod_frequency()
    step_size = cell_data.get_sampling_interval()
    durations = cell_data.get_sam_durations()
    contrasts = cell_data.get_sam_contrasts()
    spiketimes = cell_data.get_sam_spiketimes()
    delta_freqs = cell_data.get_sam_delta_frequencies()

    axon_delays = []
    for mod_freq in sorted(np.unique(delta_freqs)):
        # Only use "stable" mod_freqs to approximate the axon delay
        if not lowest_mod_freq <= mod_freq <= highest_mod_freq:
            continue
        if 1/mod_freq > durations[0] / 4:
            print("skipped mod_freq: {}".format(mod_freq))
            print("Duration: {} while mod_freq period: {:.2f}".format(durations[0], 1/mod_freq))
            continue

        cell_means = []
        model_means = []
        for i in range(len(delta_freqs)):
            if delta_freqs[i] != mod_freq or len(spiketimes[i]) == 0:
                continue
            if len(spiketimes[i]) > 1:
                print("There are more spiketimes in one 'point'! Only the first was used! ")

            means = _period_means_for_trial(model, eod_freq, contrasts[i], mod_freq,
                                            spiketimes[i][0], step_size, sim_duration=durations[i] * 4)
            if means is None:
                continue
            cell_means.append(means[0])
            model_means.append(means[1])

        if len(cell_means) == 0:
            continue

        cell_means, model_means = _trim_to_common_length(cell_means, model_means)
        offset, _ = correct_phase(np.mean(cell_means, axis=0), np.mean(model_means, axis=0), step_size)
        axon_delays.append(offset)

    if len(axon_delays) == 0:
        return 0
    return int(round(np.mean(axon_delays)))


def generate_pdf(model, stimulus, trials=4, sim_length=3, kernel_width=0.005):
    """Return the trial-averaged, Gaussian-smoothed firing rate of the model.

    The model is simulated ``trials`` times for ``sim_length`` with
    ``model.simulate_slow``; each spike train is binned at the model step size,
    convolved with a Gaussian of std ``kernel_width`` and the traces are averaged
    with hF.calculate_mean_of_frequency_traces.
    """
    step_size = model.get_parameters()["step_size"]
    kernel = gaussian_kernel(kernel_width, step_size)
    n_bins = int(sim_length / step_size)

    trials_rate_list = []
    for _ in range(trials):
        _, spikes = model.simulate_slow(stimulus, total_time_s=sim_length)
        binary = np.zeros(n_bins)
        for s_idx in (int(s / step_size) for s in spikes):
            # a spike at (or rounding to) the very end would fall outside the bins
            if 0 <= s_idx < n_bins:
                binary[s_idx] = 1
        trials_rate_list.append(np.convolve(binary, kernel, mode='same'))

    # time axes built from the bin count so they match the rate traces exactly
    times = [np.arange(n_bins) * step_size for _ in range(trials)]
    _, mean_rate = hF.calculate_mean_of_frequency_traces(times, trials_rate_list, step_size)
    return mean_rate


def spiketimes_calculate_pdf(spikes, step_size, kernel_width=0.001):
    """Convert spike times into a Gaussian-smoothed firing-rate trace.

    Spikes are binned at ``step_size`` (bin = int(t / step_size)) into an array
    reaching up to the latest spike, then convolved with gaussian_kernel(
    kernel_width, step_size). Returns an empty array for an empty spike train.

    NOTE: np.convolve(mode='same') returns max(len(binary), len(kernel))
    samples, so for very short spike trains the result is longer than the
    binned train.
    """
    spike_indices = [int(s / step_size) for s in spikes]
    if len(spike_indices) == 0:
        return np.zeros(0)

    # size from the maximum so unsorted spike times stay in range
    binary = np.zeros(max(spike_indices) + 1)
    binary[spike_indices] = 1

    kernel = gaussian_kernel(kernel_width, step_size)
    return np.convolve(binary, kernel, mode='same')


def cut_pdf_into_periods(pdf, period, step_size, factor=0.0):
    """Cut a rate trace into consecutive segments of one (modulation) period.

    The first period is not used (cuts start at period index 1 - presumably to
    drop the onset). The fractional part of the period length in samples is
    compensated by shifting cut i by int(i * fractional part) samples. Each cut
    is ``factor`` periods longer than one period. A negative period is treated
    as its absolute value.

    Returns
    -------
    numpy.ndarray of shape (n_cuts, cut_length)
        The cuts (n_cuts may be 0);
    []
        if ``pdf`` is shorter than 90 % of one period;
    [pdf]
        if ``pdf`` is too short for a full cut otherwise.
    """
    period = abs(period)

    idx_period_length = int(period / float(step_size))
    offset_per_step = period / float(step_size) - idx_period_length
    cut_length = idx_period_length + int(factor * idx_period_length)
    num_of_cuts = int(len(pdf) / (idx_period_length + offset_per_step))

    # the last cut must not run past the end when cuts are longer than a period
    if len(pdf) - (num_of_cuts * idx_period_length + (num_of_cuts * offset_per_step)) < cut_length - idx_period_length:
        num_of_cuts -= 1

    if idx_period_length * 0.9 > len(pdf):
        return []

    if cut_length > len(pdf) or num_of_cuts < 1:
        return [pdf]

    cuts = np.zeros((num_of_cuts - 1, cut_length))
    for i in range(1, num_of_cuts):
        offset_correction = int(offset_per_step * i)
        start_idx = i * idx_period_length + offset_correction
        cuts[i - 1] = pdf[start_idx: start_idx + cut_length]

    return cuts


def gaussian_kernel(sigma, dt):
    """Gaussian density with std ``sigma`` sampled at spacing ``dt`` on [-4 sigma, 4 sigma).

    The samples sum to about 1 / dt, so convolving a 0/1 spike train with it
    yields a rate in spikes per time unit of ``dt``.
    """
    x = np.arange(-4. * sigma, 4. * sigma, dt)
    y = np.exp(-0.5 * (x / sigma) ** 2) / np.sqrt(2. * np.pi) / sigma
    return y


if __name__ == '__main__':
    main()