save more info for lab rotation
This commit is contained in:
parent
24e91a5601
commit
aacdac9aad
@ -1,33 +1,105 @@
|
|||||||
|
|
||||||
from ModelFit import get_best_fit, ModelFit
|
from ModelFit import get_best_fit, ModelFit
|
||||||
import os
|
import os
|
||||||
|
import pandas as pd
|
||||||
|
import numpy as np
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
from Baseline import BaselineCellData
|
||||||
|
|
||||||
|
|
||||||
|
SAVE_DIR = "results/lab_rotation/"
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
|
|
||||||
dir = "results/final_1/"
|
res_folder = "results/final_2/"
|
||||||
|
|
||||||
|
# save_model_parameters(res_folder)
|
||||||
|
# save_cell_info(res_folder)
|
||||||
|
|
||||||
|
# test_save_cell_info()
|
||||||
|
|
||||||
|
|
||||||
|
def save_model_parameters(res_folder):
|
||||||
cells = []
|
cells = []
|
||||||
eod_freqs = []
|
eod_freqs = []
|
||||||
parameters = []
|
parameters = []
|
||||||
|
|
||||||
for cell in sorted(os.listdir(dir)):
|
for cell in sorted(os.listdir(res_folder)):
|
||||||
cell_dir = dir + cell
|
cell_dir = res_folder + cell
|
||||||
|
|
||||||
model = get_best_fit(cell_dir, use_comparable_error=False)
|
model = get_best_fit(cell_dir, use_comparable_error=False)
|
||||||
cells.append(cell)
|
cells.append(cell)
|
||||||
eod_freqs.append(model.get_cell_data().get_eod_frequency())
|
eod_freqs.append(model.get_cell_data().get_eod_frequency())
|
||||||
parameters.append(model.get_final_parameters())
|
parameters.append(model.get_final_parameters())
|
||||||
|
|
||||||
save_csv(dir + "models.csv", cells, eod_freqs, parameters)
|
save_csv(SAVE_DIR + "models.csv", cells, eod_freqs, parameters)
|
||||||
|
|
||||||
|
|
||||||
|
def test_save_cell_info():
|
||||||
|
for cell in sorted(os.listdir(SAVE_DIR)):
|
||||||
|
cell_dir = SAVE_DIR + cell + "/"
|
||||||
|
if not os.path.isdir(cell_dir):
|
||||||
|
continue
|
||||||
|
|
||||||
|
fi_frame = pd.read_csv(cell_dir + "fi_curve_info.csv")
|
||||||
|
|
||||||
|
plt.plot(fi_frame["contrast"], fi_frame["f_inf"], 'o')
|
||||||
|
plt.plot(fi_frame["contrast"], fi_frame["f_zero"], '+')
|
||||||
|
plt.show()
|
||||||
|
plt.close()
|
||||||
|
|
||||||
|
count = 1
|
||||||
|
spike_file = "baseline_spikes_trial_{}.npy".format(count)
|
||||||
|
while os.path.exists(cell_dir + spike_file):
|
||||||
|
spiketimes = np.load(cell_dir + spike_file) * 1000
|
||||||
|
|
||||||
|
plt.hist(np.diff(spiketimes), bins=np.arange(0, 50, 0.1))
|
||||||
|
plt.show()
|
||||||
|
plt.close()
|
||||||
|
|
||||||
|
count += 1
|
||||||
|
spike_file = "baseline_spikes_trial_{}.npy".format(count)
|
||||||
|
|
||||||
|
|
||||||
|
def save_cell_info(res_folder):
|
||||||
|
for cell in sorted(os.listdir(res_folder)):
|
||||||
|
cell_dir = res_folder + cell
|
||||||
|
|
||||||
|
fit = get_best_fit(cell_dir, use_comparable_error=False)
|
||||||
|
|
||||||
|
save_path = SAVE_DIR + cell + "/"
|
||||||
|
|
||||||
|
if not os.path.exists(save_path):
|
||||||
|
os.mkdir(save_path)
|
||||||
|
|
||||||
|
# fi-curve
|
||||||
|
cell_data = fit.get_cell_data()
|
||||||
|
f_zeros = fit.get_cell_f_zero_values()
|
||||||
|
f_infs = fit.get_cell_f_inf_values()
|
||||||
|
contrasts = cell_data.get_fi_contrasts()
|
||||||
|
|
||||||
|
data_array = np.array([contrasts, f_infs, f_zeros]).T
|
||||||
|
fi_frame = pd.DataFrame(data_array, columns=["contrast", "f_inf", "f_zero"])
|
||||||
|
fi_frame.to_csv(save_path + "fi_curve_info.csv")
|
||||||
|
|
||||||
|
spikes = cell_data.get_base_spikes()
|
||||||
|
for i, spike_list in enumerate(spikes):
|
||||||
|
spike_array = np.array(spike_list)
|
||||||
|
np.save(save_path + "baseline_spikes_trial_{}.npy".format(i+1), spike_array)
|
||||||
|
|
||||||
|
|
||||||
def save_csv(file, cells, eod_freqs, parameters):
|
def save_csv(file, cells, eod_freqs, parameters):
|
||||||
keys = sorted(parameters[0].keys())
|
keys = sorted(parameters[0].keys())
|
||||||
with open(file, "w") as file:
|
with open(file, "w") as file:
|
||||||
header = "cell,eod_frequency"
|
header = "cell,EODf"
|
||||||
for k in keys:
|
for k in keys:
|
||||||
header += ",{}".format(k)
|
if k == "refractory_period":
|
||||||
|
header += ",ref_period"
|
||||||
|
elif k == "step_size":
|
||||||
|
header += ",deltat"
|
||||||
|
else:
|
||||||
|
header += ",{}".format(k)
|
||||||
file.write(header + "\n")
|
file.write(header + "\n")
|
||||||
|
|
||||||
for i in range(len(cells)):
|
for i in range(len(cells)):
|
||||||
|
Loading…
Reference in New Issue
Block a user