P-unit_model/save_model_fits_as_csv.py
2021-05-22 13:10:15 +02:00

116 lines
3.2 KiB
Python

from fitting.ModelFit import get_best_fit
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
# Output root used by the functions below: save_model_parameters() writes
# "models.csv" here, save_cell_info() creates one sub-folder per cell here and
# test_save_cell_info() reads those sub-folders back. The trailing slash matters
# because paths are built by plain string concatenation.
SAVE_DIR = "results/sam_cells_only_best/"
def main():
    """Export the fitted model parameters found in the default results folder."""
    results_dir = "results/sam_cells_only_best/"
    save_model_parameters(results_dir)
    # Further steps, currently disabled:
    # save_cell_info(results_dir)
    # test_save_cell_info()
def save_model_parameters(res_folder):
    """Write the best-fit parameters of every cell in *res_folder* to a CSV file.

    Each sub-directory of *res_folder* is treated as one cell. For every cell
    the best fit is loaded with ``get_best_fit`` and its name, EOD frequency
    and final parameters are collected. The collected rows are then written to
    ``SAVE_DIR + "models.csv"`` via ``save_csv``.

    Args:
        res_folder: Folder whose sub-directories each hold the fit results of
            one cell.
    """
    cells = []
    eod_freqs = []
    parameters = []
    for cell in sorted(os.listdir(res_folder)):
        cell_dir = os.path.join(res_folder, cell)
        # Plain files are not cells. This matters when res_folder is the same
        # folder as SAVE_DIR: a "models.csv" left by an earlier run would
        # otherwise be passed to get_best_fit as if it were a cell directory.
        if not os.path.isdir(cell_dir):
            continue
        model = get_best_fit(cell_dir, use_comparable_error=False)
        cells.append(cell)
        eod_freqs.append(model.get_cell_data().get_eod_frequency())
        parameters.append(model.get_final_parameters())

    save_csv(SAVE_DIR + "models.csv", cells, eod_freqs, parameters)
def test_save_cell_info():
    """Show the files stored per cell under SAVE_DIR as plots for a manual check.

    For every sub-directory of SAVE_DIR the f_inf and f_zero columns of
    "fi_curve_info.csv" are plotted against the contrast column. Afterwards one
    histogram of the differences between consecutive baseline spike times is
    shown per "baseline_spikes_trial_<n>.npy" file, counting n up from 1 until
    a file is missing.
    """
    for entry in sorted(os.listdir(SAVE_DIR)):
        cell_dir = SAVE_DIR + entry + "/"
        if os.path.isdir(cell_dir):
            fi_frame = pd.read_csv(cell_dir + "fi_curve_info.csv")
            contrasts = fi_frame["contrast"]
            plt.plot(contrasts, fi_frame["f_inf"], 'o')
            plt.plot(contrasts, fi_frame["f_zero"], '+')
            plt.show()
            plt.close()

            trial = 1
            while True:
                spike_path = cell_dir + "baseline_spikes_trial_{}.npy".format(trial)
                if not os.path.exists(spike_path):
                    break
                # Scaled by 1000 before differencing - presumably s -> ms; confirm.
                intervals = np.diff(np.load(spike_path) * 1000)
                plt.hist(intervals, bins=np.arange(0, 50, 0.1))
                plt.show()
                plt.close()
                trial += 1
def save_cell_info(res_folder):
    """Export the f-I curve and baseline spikes of every cell in *res_folder*.

    For each sub-directory of *res_folder* the best fit is loaded and two kinds
    of files are written to ``SAVE_DIR + <cell> + "/"``:

    * ``fi_curve_info.csv`` with the columns ``contrast``, ``f_inf`` and
      ``f_zero`` (plus the default pandas index column),
    * ``baseline_spikes_trial_<n>.npy`` for every baseline trial, numbered
      from 1.

    Args:
        res_folder: Folder whose sub-directories each hold the fit results of
            one cell.
    """
    for cell in sorted(os.listdir(res_folder)):
        cell_dir = os.path.join(res_folder, cell)
        # Skip plain files (e.g. a "models.csv" stored next to the cell
        # folders); only directories hold fit results.
        if not os.path.isdir(cell_dir):
            continue
        fit = get_best_fit(cell_dir, use_comparable_error=False)

        save_path = SAVE_DIR + cell + "/"
        # exist_ok avoids the check-then-create race and also creates SAVE_DIR
        # itself when it does not exist yet.
        os.makedirs(save_path, exist_ok=True)

        # fi-curve
        cell_data = fit.get_cell_data()
        f_zeros = fit.get_cell_f_zero_values()
        f_infs = fit.get_cell_f_inf_values()
        contrasts = cell_data.get_fi_contrasts()
        # Stack as rows, then transpose so that each contrast becomes one row.
        data_array = np.array([contrasts, f_infs, f_zeros]).T
        fi_frame = pd.DataFrame(data_array, columns=["contrast", "f_inf", "f_zero"])
        fi_frame.to_csv(save_path + "fi_curve_info.csv")

        # baseline spikes, one file per trial
        spikes = cell_data.get_base_spikes()
        for trial, spike_list in enumerate(spikes, start=1):
            spike_array = np.array(spike_list)
            np.save(save_path + "baseline_spikes_trial_{}.npy".format(trial), spike_array)
def save_csv(file, cells, eod_freqs, parameters):
    """Write one CSV row per cell with its EOD frequency and model parameters.

    The header is ``cell,EODf`` followed by the alphabetically sorted parameter
    names taken from the first entry of *parameters*. The names
    ``refractory_period`` and ``step_size`` are written as ``ref_period`` and
    ``deltat``. If *parameters* is empty, only the fixed ``cell,EODf`` columns
    are written.

    Args:
        file: Path of the CSV file; an existing file is overwritten.
        cells: Cell names, one per row.
        eod_freqs: EOD frequency of each cell, written with two decimals.
        parameters: One parameter dict per cell. Every dict must contain the
            keys of the first one.

    Raises:
        IndexError: If *eod_freqs* or *parameters* is shorter than *cells*.
        KeyError: If a dict lacks a key that the first dict has.
    """
    # Column names that differ from the parameter names used by the model.
    renamed = {"refractory_period": "ref_period", "step_size": "deltat"}

    # Empty input yields a header-only file instead of failing on parameters[0].
    keys = sorted(parameters[0]) if parameters else []
    header = ",".join(["cell", "EODf"] + ["{}".format(renamed.get(k, k)) for k in keys])

    with open(file, "w") as out:
        out.write(header + "\n")
        for i, cell in enumerate(cells):
            row = ["{}".format(cell), "{:.2f}".format(eod_freqs[i])]
            row.extend("{}".format(parameters[i][k]) for k in keys)
            out.write(",".join(row) + "\n")
# Run the export only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()