diff --git a/save_model_fits_as_csv.py b/save_model_fits_as_csv.py index 36fca43..0a87f73 100644 --- a/save_model_fits_as_csv.py +++ b/save_model_fits_as_csv.py @@ -1,33 +1,105 @@ from ModelFit import get_best_fit, ModelFit import os +import pandas as pd +import numpy as np +import matplotlib.pyplot as plt +from Baseline import BaselineCellData + + +SAVE_DIR = "results/lab_rotation/" def main(): - dir = "results/final_1/" + res_folder = "results/final_2/" + + # save_model_parameters(res_folder) + # save_cell_info(res_folder) + + # test_save_cell_info() + +def save_model_parameters(res_folder): cells = [] eod_freqs = [] parameters = [] - for cell in sorted(os.listdir(dir)): - cell_dir = dir + cell + for cell in sorted(os.listdir(res_folder)): + cell_dir = res_folder + cell model = get_best_fit(cell_dir, use_comparable_error=False) cells.append(cell) eod_freqs.append(model.get_cell_data().get_eod_frequency()) parameters.append(model.get_final_parameters()) - save_csv(dir + "models.csv", cells, eod_freqs, parameters) + save_csv(SAVE_DIR + "models.csv", cells, eod_freqs, parameters) + + +def test_save_cell_info(): + for cell in sorted(os.listdir(SAVE_DIR)): + cell_dir = SAVE_DIR + cell + "/" + if not os.path.isdir(cell_dir): + continue + + fi_frame = pd.read_csv(cell_dir + "fi_curve_info.csv") + + plt.plot(fi_frame["contrast"], fi_frame["f_inf"], 'o') + plt.plot(fi_frame["contrast"], fi_frame["f_zero"], '+') + plt.show() + plt.close() + + count = 1 + spike_file = "baseline_spikes_trial_{}.npy".format(count) + while os.path.exists(cell_dir + spike_file): + spiketimes = np.load(cell_dir + spike_file) * 1000 + + plt.hist(np.diff(spiketimes), bins=np.arange(0, 50, 0.1)) + plt.show() + plt.close() + + count += 1 + spike_file = "baseline_spikes_trial_{}.npy".format(count) + + +def save_cell_info(res_folder): + for cell in sorted(os.listdir(res_folder)): + cell_dir = res_folder + cell + + fit = get_best_fit(cell_dir, use_comparable_error=False) + + save_path = SAVE_DIR + cell + "/" + + if not os.path.exists(save_path): + os.mkdir(save_path) + + # fi-curve + cell_data = fit.get_cell_data() + f_zeros = fit.get_cell_f_zero_values() + f_infs = fit.get_cell_f_inf_values() + contrasts = cell_data.get_fi_contrasts() + + data_array = np.array([contrasts, f_infs, f_zeros]).T + fi_frame = pd.DataFrame(data_array, columns=["contrast", "f_inf", "f_zero"]) + fi_frame.to_csv(save_path + "fi_curve_info.csv") + + spikes = cell_data.get_base_spikes() + for i, spike_list in enumerate(spikes): + spike_array = np.array(spike_list) + np.save(save_path + "baseline_spikes_trial_{}.npy".format(i+1), spike_array) def save_csv(file, cells, eod_freqs, parameters): keys = sorted(parameters[0].keys()) with open(file, "w") as file: - header = "cell,eod_frequency" + header = "cell,EODf" for k in keys: - header += ",{}".format(k) + if k == "refractory_period": + header += ",ref_period" + elif k == "step_size": + header += ",deltat" + else: + header += ",{}".format(k) file.write(header + "\n") for i in range(len(cells)):