import os

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

from fitting.ModelFit import get_best_fit

# Folder with one sub-directory of fit results per cell. The exported files
# (models.csv and the per-cell fi/spike files) are written here as well.
SAVE_DIR = "results/sam_cells_only_best/"

# CSV header names that differ from the parameter keys returned by the fit.
_HEADER_RENAMES = {
    "refractory_period": "ref_period",
    "step_size": "deltat",
}


def main():
    """Export the best-fit model parameters of every cell folder in SAVE_DIR."""
    res_folder = SAVE_DIR
    save_model_parameters(res_folder)
    # Alternative exports; enable as needed:
    # save_cell_info(res_folder)
    # test_save_cell_info()


def _cell_dirs(folder):
    """Yield ``(cell_name, cell_path)`` for every sub-directory of *folder*, sorted by name.

    Plain files are skipped. This matters because SAVE_DIR doubles as the
    results folder, and ``models.csv`` is written into it; without the
    filter it would be mistaken for a cell folder.
    """
    for cell in sorted(os.listdir(folder)):
        cell_path = os.path.join(folder, cell)
        if os.path.isdir(cell_path):
            yield cell, cell_path


def _spike_file_name(trial):
    """Return the file name holding the baseline spikes of 1-based *trial*."""
    return "baseline_spikes_trial_{}.npy".format(trial)


def save_model_parameters(res_folder):
    """Write the best-fit parameters of every cell in *res_folder* to ``SAVE_DIR/models.csv``.

    Args:
        res_folder: folder containing one sub-directory of fit results per cell.
    """
    cells = []
    eod_freqs = []
    parameters = []
    for cell, cell_dir in _cell_dirs(res_folder):
        model = get_best_fit(cell_dir, use_comparable_error=False)
        cells.append(cell)
        eod_freqs.append(model.get_cell_data().get_eod_frequency())
        parameters.append(model.get_final_parameters())

    save_csv(os.path.join(SAVE_DIR, "models.csv"), cells, eod_freqs, parameters)


def test_save_cell_info():
    """Visually inspect the files written by ``save_cell_info``.

    For every cell folder in SAVE_DIR this shows:

    - one plot of the stored f_inf ('o') and f_zero ('+') values against
      contrast;
    - one inter-spike-interval histogram per saved baseline spike trial.

    NOTE: despite the ``test_`` prefix this is a manual plotting helper,
    not an automated test.
    """
    for _cell, cell_dir in _cell_dirs(SAVE_DIR):
        fi_frame = pd.read_csv(os.path.join(cell_dir, "fi_curve_info.csv"))
        plt.plot(fi_frame["contrast"], fi_frame["f_inf"], 'o')
        plt.plot(fi_frame["contrast"], fi_frame["f_zero"], '+')
        plt.show()
        plt.close()

        # Trials are numbered from 1 without gaps; stop at the first missing file.
        trial = 1
        spike_path = os.path.join(cell_dir, _spike_file_name(trial))
        while os.path.exists(spike_path):
            # Scaled by 1000 -- presumably seconds -> ms, to match the 0.1 bins up to 50.
            spiketimes = np.load(spike_path) * 1000
            plt.hist(np.diff(spiketimes), bins=np.arange(0, 50, 0.1))
            plt.show()
            plt.close()
            trial += 1
            spike_path = os.path.join(cell_dir, _spike_file_name(trial))


def save_cell_info(res_folder):
    """Export the fi-curve values and baseline spike trains of every cell's best fit.

    For each cell sub-directory of *res_folder*, ``SAVE_DIR/<cell>/`` receives:

    - ``fi_curve_info.csv`` with columns contrast, f_inf and f_zero;
    - one ``baseline_spikes_trial_<i>.npy`` per baseline trial, with i
      starting at 1.

    Args:
        res_folder: folder containing one sub-directory of fit results per cell.
    """
    for cell, cell_dir in _cell_dirs(res_folder):
        fit = get_best_fit(cell_dir, use_comparable_error=False)
        save_path = os.path.join(SAVE_DIR, cell)
        # exist_ok: SAVE_DIR may be the results folder itself, so the cell
        # directory can already exist; makedirs also creates SAVE_DIR if missing.
        os.makedirs(save_path, exist_ok=True)

        # fi-curve: one row per contrast
        cell_data = fit.get_cell_data()
        f_zeros = fit.get_cell_f_zero_values()
        f_infs = fit.get_cell_f_inf_values()
        contrasts = cell_data.get_fi_contrasts()
        data_array = np.array([contrasts, f_infs, f_zeros]).T
        fi_frame = pd.DataFrame(data_array, columns=["contrast", "f_inf", "f_zero"])
        fi_frame.to_csv(os.path.join(save_path, "fi_curve_info.csv"))

        # baseline spike trains: one .npy file per trial
        spikes = cell_data.get_base_spikes()
        for trial, spike_list in enumerate(spikes, start=1):
            np.save(os.path.join(save_path, _spike_file_name(trial)), np.array(spike_list))


def save_csv(file, cells, eod_freqs, parameters):
    """Write one CSV row per cell with its EOD frequency and fitted parameters.

    Args:
        file: path of the CSV file to (over)write.
        cells: cell names, written to the ``cell`` column.
        eod_freqs: EOD frequency per cell, written with two decimals.
        parameters: per-cell mapping of parameter name -> value. The sorted
            keys of the first mapping define the remaining columns.
            ``refractory_period`` and ``step_size`` are renamed to
            ``ref_period`` and ``deltat`` in the header. With no cells, only
            the ``cell,EODf`` header is written.

    Raises:
        ValueError: if *cells*, *eod_freqs* and *parameters* differ in length.
    """
    if not len(cells) == len(eod_freqs) == len(parameters):
        raise ValueError(
            "cells, eod_freqs and parameters must have equal length, got "
            "{}, {}, {}".format(len(cells), len(eod_freqs), len(parameters))
        )

    keys = sorted(parameters[0].keys()) if parameters else []
    header = ["cell", "EODf"] + [_HEADER_RENAMES.get(k, k) for k in keys]

    with open(file, "w") as out:
        out.write(",".join(header) + "\n")
        for cell, eod_freq, params in zip(cells, eod_freqs, parameters):
            row = ["{}".format(cell), "{:.2f}".format(eod_freq)]
            row.extend("{}".format(params[k]) for k in keys)
            out.write(",".join(row) + "\n")


if __name__ == '__main__':
    main()