import os
from parser.CellData import CellData
from parser.DataParserFactory import DatParser
import numpy as np
from fitting.ModelFit import ModelFit, get_best_fit
# from plottools.axes import labelaxes_params
import matplotlib.pyplot as plt
from run_Fitter import iget_start_parameters
from experiments.FiCurve import FICurve, FICurveCellData, FICurveModel
from Figures_results import scatter_hist
from my_util.functions import exponential_function

colors = ["black", "red", "blue", "orange", "green"]

# Directory that receives every tau figure written by this script.
TAU_FIGURE_DIR = "figures/tau_images/"
# Fit results analysed by the step-response and tau comparisons.
BEST_FIT_RESULTS_DIR = "results/sam_cells_only_best/"


def main():
    """Run the analysis currently enabled in this script.

    Only the per-cell step-response comparison runs. The cell/model tau
    comparison (``compare_cell_model_taus``) and the SAM spike inspection
    (``sam_tests``) were unreachable in the original workflow and can be
    called from here instead.
    """
    # Disabled: export each cell's f-I curve as an image and as a CSV file.
    # results_dir = "data/final/"
    # for folder in sorted(os.listdir(results_dir)):
    #     folder_path = os.path.join(results_dir, folder)
    #
    #     if not os.path.isdir(folder_path):
    #         continue
    #
    #     cell_data = CellData(folder_path)
    #     cell_name = cell_data.get_cell_name()
    #
    #     fi_cell = FICurveCellData(cell_data, cell_data.get_fi_contrasts(), cell_data.data_path)
    #
    #     fi_cell.plot_fi_curve(title=cell_name, save_path="temp/cell_fi_curves_images/" + cell_name + "_")
    #
    #     steady_state = fi_cell.get_f_inf_frequencies()
    #     onset = fi_cell.get_f_zero_frequencies()
    #     baseline = fi_cell.get_f_baseline_frequencies()
    #     contrasts = fi_cell.stimulus_values
    #
    #     headers = ["contrasts", "f_baseline", "f_steady_state", "f_onset"]
    #     with open("temp/cell_fi_curves_csvs/" + cell_name + ".csv", 'w') as f:
    #         for i in range(len(headers)):
    #             if i == 0:
    #                 f.write(headers[i])
    #             else:
    #                 f.write("," + headers[i])
    #         f.write("\n")
    #
    #         for i in range(len(contrasts)):
    #             f.write(str(contrasts[i]) + ",")
    #             f.write(str(baseline[i]) + ",")
    #             f.write(str(steady_state[i]) + ",")
    #             f.write(str(onset[i]) + "\n")

    # Returning (instead of the former quit()) ends the script the same way
    # when run directly, without killing an interpreter that imported main().
    step_response_comparison()


def _iter_best_fits(results_dir):
    """Yield ``(folder_path, fit)`` for each sub-directory of ``results_dir``.

    Folders are visited in sorted order, non-directories are skipped, and the
    fit routine error of every fit is printed (as both analyses did before).
    """
    for folder in sorted(os.listdir(results_dir)):
        folder_path = os.path.join(results_dir, folder)
        if not os.path.isdir(folder_path):
            continue

        fit = get_best_fit(folder_path)
        print(fit.get_fit_routine_error())
        yield folder_path, fit


def compare_cell_model_taus(results_dir=BEST_FIT_RESULTS_DIR, contrast_idx=-2, border=1):
    """Compare onset time constants (tau) of cells and their fitted models.

    Args:
        results_dir: directory whose sub-folders each hold fit results.
        contrast_idx: index into the f-I contrasts used for the tau fit
            (-2 = second-to-last contrast, as before).
        border: pairs where either |tau| >= border are excluded from the plot
            (presumably failed fits — confirm).

    Saves the comparison figure via ``plot_cell_model_comp_taus`` and prints
    the unfiltered tau lists.
    """
    cell_taus = []
    model_taus = []
    for folder_path, fit in _iter_best_fits(results_dir):
        model = fit.get_model()
        cell_data = fit.get_cell_data()
        fi_model = FICurveModel(model, cell_data.get_fi_contrasts(), cell_data.get_eod_frequency(),
                                save_dir=folder_path)
        # Index 1 of the returned fit parameters is the tau (as used for the plot labels).
        model_taus.append(fi_model.calculate_time_constant(contrast_idx)[1])

        fi_cell = FICurveCellData(cell_data, cell_data.get_fi_contrasts(), cell_data.data_path)
        cell_taus.append(fi_cell.calculate_time_constant(contrast_idx)[1])

    kept_pairs = [(cell_tau, model_tau) for cell_tau, model_tau in zip(cell_taus, model_taus)
                  if np.abs(model_tau) < border and np.abs(cell_tau) < border]
    cell_taus_c = [cell_tau for cell_tau, _ in kept_pairs]
    model_taus_c = [model_tau for _, model_tau in kept_pairs]

    print("model removed:", len(model_taus) - len(model_taus_c))
    print("cell removed:", len(cell_taus) - len(cell_taus_c))
    plot_cell_model_comp_taus(cell_taus_c, model_taus_c)

    print(model_taus)
    print(cell_taus)


def _plot_step_response_panel(ax, fi_curve, contrasts, mean_frequencies, baseline_freqs,
                              sampling_interval, pre_duration, contrast_idx, window=(-0.05, 0.15)):
    """Fit tau for one contrast, plot the response window with the fit on ``ax``.

    The time axis is ``arange(len(trace)) * sampling_interval - pre_duration``;
    only samples strictly inside ``window`` (seconds) are drawn. The x-label
    shows the fitted tau. Returns the fitted exponential parameters.
    """
    tau_params, tau_x = fi_curve.__calculate_time_constant_internal__(
        contrasts[contrast_idx], mean_frequencies[contrast_idx], baseline_freqs[contrast_idx],
        sampling_interval, pre_duration, plot=False, plot_data=True)

    trace = np.asarray(mean_frequencies[contrast_idx])
    time = np.arange(len(trace)) * sampling_interval - pre_duration
    in_window = (time > window[0]) & (time < window[1])
    plot_freq_with_tau_fit(ax, time[in_window], trace[in_window], tau_x, tau_params)
    ax.set_xlabel("Time [s]; tau: {:.3f}".format(tau_params[1]))
    return tau_params


def step_response_comparison(results_dir=BEST_FIT_RESULTS_DIR):
    """Plot cell vs. model step responses together with their exponential tau fits.

    For every best fit in ``results_dir`` a 2x2 figure is saved to
    ``TAU_FIGURE_DIR``: top row cell, bottom row model; left column the
    contrast at index -2, right column the contrast at index -3.
    """
    os.makedirs(TAU_FIGURE_DIR, exist_ok=True)
    for folder_path, fit in _iter_best_fits(results_dir):
        model = fit.get_model()
        cell_data = fit.get_cell_data()
        fi_model = FICurveModel(model, cell_data.get_fi_contrasts(), cell_data.get_eod_frequency(),
                                save_dir=folder_path)
        fi_cell = FICurveCellData(cell_data, cell_data.get_fi_contrasts(), cell_data.data_path)
        contrasts = cell_data.get_fi_contrasts()

        # Cell: shift the time axis so it starts at get_recording_times()[0].
        cell_traces = cell_data.get_mean_fi_curve_isi_frequencies()
        cell_baselines = fi_cell.get_f_baseline_frequencies()
        cell_pre_duration = -1 * cell_data.get_recording_times()[0]
        cell_sampling_interval = cell_data.get_sampling_interval()

        fig, axes = plt.subplots(2, 2, figsize=(12, 8))

        for col, contrast_idx in enumerate((-2, -3)):
            _plot_step_response_panel(axes[0, col], fi_cell, contrasts, cell_traces, cell_baselines,
                                      cell_sampling_interval, cell_pre_duration, contrast_idx)
        axes[0, 0].set_title("2nd highest Contrast: {:.2f}".format(contrasts[-2]))
        axes[0, 1].set_title("3rd highest Contrast: {:.2f}".format(contrasts[-3]))
        axes[0, 0].set_ylabel("Cell")

        # Model: the 0.5 s pre-stimulus duration is hard-coded (unchanged from before).
        model_traces = fi_model.mean_frequency_traces
        model_baselines = fi_model.get_f_baseline_frequencies()
        model_pre_duration = 0.5
        model_sampling_interval = fi_model.model.get_sampling_interval()

        for col, contrast_idx in enumerate((-2, -3)):
            # NOTE(review): the cell's fitting routine is reused for the model
            # traces (as before); all model data is passed explicitly — confirm
            # the routine does not read fi_cell's own state.
            _plot_step_response_panel(axes[1, col], fi_cell, contrasts, model_traces, model_baselines,
                                      model_sampling_interval, model_pre_duration, contrast_idx)
        axes[1, 0].set_ylabel("Model")

        for ax in axes.flat:
            ax.set_ylim((0, 800))
        plt.tight_layout()
        plt.savefig(os.path.join(TAU_FIGURE_DIR, cell_data.get_cell_name() + "_tau.png"))
        plt.close(fig)


def plot_freq_with_tau_fit(ax, time, freq, tau_x, tau_params):
    """Plot a frequency trace and the exponential fit evaluated at ``tau_x`` on ``ax``."""
    ax.plot(time, freq)
    ax.plot(tau_x, exponential_function(tau_x, tau_params[0], tau_params[1], tau_params[2]))


def plot_cell_model_comp_taus(cell_taus, model_taus, num_of_bins=20):
    """Save a scatter + histogram comparison of cell vs. model taus.

    Args:
        cell_taus: tau values of the cells (x axis).
        model_taus: tau values of the corresponding models (y axis).
        num_of_bins: number of histogram bins spanning the joint value range.

    Does nothing (prints a notice) when either list is empty.
    """
    if len(cell_taus) == 0 or len(model_taus) == 0:
        print("plot_cell_model_comp_taus: no taus to plot, skipping.")
        return

    fig = plt.figure(figsize=(3, 4))
    gs = fig.add_gridspec(2, 1, height_ratios=[3, 7], left=0.1, right=0.95, bottom=0.1, top=0.9,
                          wspace=0.4, hspace=0.2)

    minimum = min(min(cell_taus), min(model_taus))
    maximum = max(max(cell_taus), max(model_taus))
    if maximum == minimum:
        # A zero-width range cannot be binned; widen it around the single value.
        minimum -= 0.5
        maximum += 0.5
    # linspace yields exactly num_of_bins + 1 edges; float-step arange does not.
    bins = np.linspace(minimum, maximum, num_of_bins + 1)

    ax = fig.add_subplot(gs[1, 0])
    ax_histx = fig.add_subplot(gs[0, 0], sharex=ax)
    scatter_hist(cell_taus, model_taus, ax, ax_histx, "Tau Comparison", bins)  # , cmap, cell_bursting)
    ax.set_xlabel(r"Cell [s]")
    ax.set_ylabel(r"Model [s]")
    ax_histx.set_ylabel("Count")
    plt.tight_layout()
    os.makedirs(TAU_FIGURE_DIR, exist_ok=True)
    plt.savefig(os.path.join(TAU_FIGURE_DIR, "fit_tau_comparison.pdf"), transparent=True)
    plt.close(fig)


def sam_tests(data_folder="./data/final_sam/", cell_name_filter="2018-05-08-aa-invivo-1",
              max_abs_delta_freq=50):
    """Interactively inspect SAM recordings.

    For each cell folder under ``data_folder`` that contains ``samspikes1.dat``
    and whose name contains ``cell_name_filter`` (``None`` disables the name
    filter), every SAM stimulus with |dF| <= ``max_abs_delta_freq`` is shown as
    a figure with the local EOD and the v1 trace. The spike-time entries with
    the largest mean v1 value at their spike times (see ``get_x_best``) are
    drawn as event rows stacked above the trace.
    """
    for cell in sorted(os.listdir(data_folder)):
        print(cell)
        cell_folder = os.path.join(data_folder, cell)
        if not os.path.exists(os.path.join(cell_folder, "samspikes1.dat")):
            continue
        if cell_name_filter is not None and cell_name_filter not in cell:
            continue

        cell_data = CellData(cell_folder)
        sampling_rate = int(round(1 / cell_data.get_sampling_interval()))
        sam_spikes = cell_data.get_sam_spiketimes()
        delta_freqs = cell_data.get_sam_delta_frequencies()
        time_traces, v1_traces, _eod_traces, local_eod_traces, _stimulus_traces = cell_data.get_sam_traces()
        print(len(time_traces))

        for i, delta_freq in enumerate(delta_freqs):
            if abs(delta_freq) > max_abs_delta_freq:
                continue

            fig, axes = plt.subplots(2, 1, sharex="all")
            axes[0].plot(time_traces[i], local_eod_traces[i])
            axes[0].set_title("Local EOD - dF {}".format(delta_freq))
            axes[1].plot(time_traces[i], v1_traces[i])
            axes[1].set_title("v1 trace")

            ah_spike = average_spike_height(sam_spikes, v1_traces[i], sampling_rate)
            # Stack the ranked spike trains above the trace, 1.5 units apart.
            # np.max works for lists and arrays (list + float would raise).
            top_of_trace = np.max(v1_traces[i])
            for j, idx in enumerate(get_x_best(ah_spike)):
                axes[1].eventplot(sam_spikes[idx], lineoffsets=top_of_trace + 1.5 * (j + 1),
                                  colors=colors[j % len(colors)])
            plt.show()
            plt.close(fig)


def average_spike_height(spike_trains, local_eod, sampling_rate):
    """Return, per entry of ``spike_trains``, the mean trace value at its spike times.

    Args:
        spike_trains: sequence whose elements hold spike times (seconds) in
            element ``[0]``.
        local_eod: 1-D trace sampled at ``sampling_rate`` (``sam_tests``
            passes the v1 trace).
        sampling_rate: samples per second used to convert times to indices.

    Spike times are converted to indices by truncation. Indices outside the
    trace (including negative ones, which previously wrapped around to the
    trace end) are ignored. An entry with no usable spikes yields ``nan``.
    """
    trace = np.asarray(local_eod)
    average_height = []
    for spikes_train in spike_trains:
        indices = (np.asarray(spikes_train[0], dtype=float) * sampling_rate).astype(int)
        indices = indices[(indices >= 0) & (indices < len(trace))]
        spike_values = trace[indices]
        average_height.append(np.mean(spike_values) if spike_values.size else np.nan)
    return average_height


def get_x_best(average_heights, x=5):
    """Return the indices of the ``x`` largest values of ``average_heights``.

    Indices are ordered by decreasing value; among equal values the larger
    index comes first. When several equal values compete for the last slot,
    the earlier index is kept. The selected values are printed. Returns ``[]``
    for empty input or ``x <= 0``.
    """
    if x <= 0 or len(average_heights) == 0:
        print([])
        return []

    biggest_idx = []
    biggest_heights = []
    for i, height in enumerate(average_heights):
        if len(biggest_idx) < x:
            biggest_idx.append(i)
            biggest_heights.append(height)
        elif height > min(biggest_heights):
            # Replace the currently smallest kept value.
            mini = int(np.argmin(biggest_heights))
            biggest_heights[mini] = height
            biggest_idx[mini] = i

    ranked = sorted(zip(biggest_heights, biggest_idx), reverse=True)
    biggest_heights = [height for height, _ in ranked]
    biggest_idx = [idx for _, idx in ranked]
    print(biggest_heights)
    return biggest_idx
# Run the enabled analysis only when executed as a script, not on import.
if __name__ == '__main__':
    main()