"""Fit LifacNoiseModel parameters to recorded cells and save comparisons.

This is the content of ``run_Fitter.py`` as introduced by git commit
5ce24e182dbd4c19e177d346e913f72f62b7b5f8 (author "a.ott",
Wed, 27 May 2020 09:14:19 +0200, message "add file to run Fitter /
add parallelization").

Running the module as a script calls :func:`main`, which fits the cells
found in ``./data/`` with several start-parameter sets in parallel and
writes plots and a value comparison below ``./test_routines/``.
"""

import itertools
import multiprocessing as mp
import os
import time

from models.LIFACnoise import LifacNoiseModel
from CellData import CellData, icelldata_of_dir
from Baseline import get_baseline_class
from FiCurve import get_fi_curve_class
from Fitter import Fitter

# Name of the fit routine currently being evaluated; it selects the
# sub-directory of "test_routines/" that run_with_real_data() writes to.
# test_fit_routines() overwrites it before every run.  The default only
# exists so that run_with_real_data() can also be called on its own.
FIT_ROUTINE = "unspecified_routine"


def main():
    """Fit every usable cell in ``./data/`` with all start-parameter sets."""
    # The start parameters are identical for every cell, so build them once.
    start_parameters = list(iget_start_parameters())

    for index, data in enumerate(icelldata_of_dir("./data/"), start=1):
        # NOTE(review): the first three cells are skipped unconditionally;
        # the reason is not visible in this file.
        if index <= 3:
            continue

        trace = data.get_base_traces(trace_type=data.V1)
        if len(trace) == 0:
            print("NO V1 TRACE FOUND")
            continue

        fit_cell_parrallel(data, start_parameters)


def fit_cell_parrallel(cell_data, start_parameters):
    """Run ``Fitter.fit_routine_1`` for all start parameters in a process pool.

    For every result the model/cell comparison is written to
    ``./test_routines/<cell>/start_parameter_<n>_err_<summed error>/``.

    :param cell_data: cell to fit; must provide ``get_data_path()`` and
        whatever ``Fitter.set_data_reference_values`` needs.
    :param start_parameters: sequence of start-parameter dicts, one fit is
        started per entry.
    """
    cell_path = os.path.basename(cell_data.get_data_path())
    print(cell_path)

    # Leave three cores free for other work, but never request fewer than
    # one worker: Pool raises ValueError for a process count below 1.
    worker_count = max(1, mp.cpu_count() - 3)

    fitter = Fitter()
    fitter.set_data_reference_values(cell_data)

    # The context manager shuts the workers down again; otherwise every
    # fitted cell would leave a full pool of processes behind.
    with mp.Pool(worker_count) as pool:
        time1 = time.time()
        outputs = pool.map(fitter.fit_routine_1, start_parameters)
        time2 = time.time()

    print("Time taken for all start parameters ({:}): {:.2f}s".format(len(start_parameters), time2-time1))

    for i, (_fmin, fin_pars) in enumerate(outputs):
        error = fitter.calculate_errors(model=LifacNoiseModel(fin_pars))
        save_path = "./test_routines/" + cell_path \
                    + "/start_parameter_{:}_err_{:.2f}/".format(i+1, sum(error))
        print_comparision_cell_model(cell_data, fin_pars, plot=True, save_path=save_path)


def test_fit_routines():
    """Run all three fit routines on the real data and tabulate the best fmin.

    Each routine writes its results to ``test_routines/<routine name>/``.
    Afterwards ``test_routines/comparision.csv`` is written with one header
    line (``routine`` followed by the cell directory names of the first
    routine) and one line per routine holding the best value per cell.
    """
    global FIT_ROUTINE

    fitter = Fitter()
    names = ("routine_1", "routine_2", "routine_3")
    routines = (fitter.fit_routine_1, fitter.fit_routine_2, fitter.fit_routine_3)
    for name, routine in zip(names, routines):
        FIT_ROUTINE = name
        run_with_real_data(fitter, routine)

    # Only directories are cells; the value rows below apply the same
    # filter, so header and rows have matching columns.
    first_routine_path = os.path.join("test_routines", names[0])
    cells = [entry for entry in sorted(os.listdir(first_routine_path))
             if os.path.isdir(os.path.join(first_routine_path, entry))]

    best = []
    for name in names:
        save_path = os.path.join("test_routines", name)
        cell_best = []
        for directory in sorted(os.listdir(save_path)):
            path = os.path.join(save_path, directory)
            if os.path.isdir(path):
                cell_best.append(find_best_run(path))
        best.append(cell_best)

    with open("test_routines/comparision.csv", "w") as res_file:
        # Every record is terminated explicitly so that the header and the
        # routines end up on separate lines.
        res_file.write(",".join(["routine"] + cells) + "\n")
        for name, routine_results in zip(names, best):
            row = [name] + [str(value) for value in routine_results]
            res_file.write(",".join(row) + "\n")


def find_best_run(cell_path):
    """Return the smallest value encoded in the run directories of a cell.

    Run directories are expected to end in ``_<number>`` (as created by
    :func:`run_with_real_data`: ``start_par_set_<n>_fmin_<value>``).
    Sub-directories without a numeric suffix, such as the ``cell`` plot
    directory, are ignored.

    :param cell_path: directory containing one sub-directory per run.
    :return: the minimum of the parsed values as ``float``.
    :raises ValueError: if no sub-directory carries a numeric suffix.
    """
    values = []
    for directory in sorted(os.listdir(cell_path)):
        if not os.path.isdir(os.path.join(cell_path, directory)):
            continue
        # Parse the directory name only: underscores in the parent path
        # must not influence which part is taken as the number.
        suffix = directory.rsplit("_", 1)[-1]
        try:
            values.append(float(suffix))
        except ValueError:
            continue

    if not values:
        raise ValueError("no run directories with a numeric suffix in {!r}".format(cell_path))
    return min(values)


def iget_start_parameters():
    """Yield start-parameter dicts for every combination of the value lists.

    The keys are ``mem_tau``, ``input_scaling``, ``noise_strength``,
    ``dend_tau`` and ``delta_a``; ``delta_a`` varies fastest.
    """
    # expand by tau_a, delta_a ?
    mem_tau_list = [0.01]
    input_scaling_list = [40, 60]
    noise_strength_list = [0.03]  # [0.02, 0.06]
    dend_tau_list = [0.001, 0.002]
    delta_a_list = [0.035, 0.065]

    combinations = itertools.product(mem_tau_list, input_scaling_list, noise_strength_list,
                                     dend_tau_list, delta_a_list)
    for mem_tau, input_scaling, noise_strength, dend_tau, delta_a in combinations:
        yield {"mem_tau": mem_tau, "input_scaling": input_scaling,
               "noise_strength": noise_strength, "dend_tau": dend_tau,
               "delta_a": delta_a}


def run_with_real_data(fitter, fit_routine_func, parallel=False):
    """Fit all cells in ``./data/`` sequentially with one fit routine.

    For every cell with a V1 trace the cell plots are saved to
    ``test_routines/<FIT_ROUTINE>/<cell>/cell/`` and, for each start-parameter
    set, the fit result is saved to
    ``test_routines/<FIT_ROUTINE>/<cell>/start_par_set_<n>_fmin_<value>/``.

    :param fitter: ``Fitter`` instance used for all fits.
    :param fit_routine_func: routine handed to ``fitter.fit_model_to_data``.
    :param parallel: currently unused; kept so existing callers keep working.
    """
    for cell_data in icelldata_of_dir("./data/"):
        print("cell:", cell_data.get_data_path())
        trace = cell_data.get_base_traces(trace_type=cell_data.V1)
        if len(trace) == 0:
            print("NO V1 TRACE FOUND")
            continue

        # results_path = "results/" + os.path.split(cell_data.get_data_path())[-1] + "/"
        results_path = "test_routines/" + FIT_ROUTINE + "/" + os.path.split(cell_data.get_data_path())[-1] + "/"
        print("results at:", results_path)
        os.makedirs(results_path, exist_ok=True)

        # plot cell images:
        cell_save_path = results_path + "cell/"
        os.makedirs(cell_save_path, exist_ok=True)
        data_baseline = get_baseline_class(cell_data)
        data_baseline.plot_baseline(cell_save_path)
        data_baseline.plot_interspike_interval_histogram(cell_save_path)
        data_baseline.plot_serial_correlation(6, cell_save_path)

        data_fi_curve = get_fi_curve_class(cell_data, cell_data.get_fi_contrasts())
        data_fi_curve.plot_fi_curve(cell_save_path)

        for start_par_count, start_parameters in enumerate(iget_start_parameters(), start=1):
            print("START PARAMETERS:", start_par_count)

            start_time = time.time()
            fmin, parameters = fitter.fit_model_to_data(cell_data, start_parameters, fit_routine_func)

            print(fmin)
            print(parameters)
            end_time = time.time()

            # The fmin value in the directory name is what find_best_run()
            # parses again later.
            parameter_set_path = results_path + "start_par_set_{}_fmin_{:.2f}".format(start_par_count, fmin["fun"]) + "/"
            os.makedirs(parameter_set_path, exist_ok=True)
            with open(parameter_set_path + "parameters_info.txt", "w") as info_file:
                info_file.write("start_parameters:\t" + str(start_parameters)
                                + "\nfinal_parameters:\t" + str(parameters)
                                + "\nfinal_fmin:\t" + str(fmin))

            print('Fitting of cell took function took {:.3f} s'.format((end_time - start_time)))
            print_comparision_cell_model(cell_data, parameters,
                                         plot=True, save_path=parameter_set_path)


def print_comparision_cell_model(cell_data, parameters, plot=False, save_path=None):
    """Compare a model built from ``parameters`` with the recorded cell.

    :param cell_data: recorded cell the model is compared against.
    :param parameters: parameters passed to ``LifacNoiseModel``.
    :param plot: if true, the model's baseline and f-I curve plots are
        produced with ``save_path`` handed to the plot methods.
    :param save_path: directory for the output; created if missing.  When
        given, ``value_comparision.tsv`` is written there with one row per
        compared quantity (cell value, model value).
    """
    if save_path is not None:
        os.makedirs(save_path, exist_ok=True)

    model = LifacNoiseModel(parameters)
    eod_frequency = cell_data.get_eod_frequency()

    model_baseline = get_baseline_class(model, eod_frequency)
    m_bf = model_baseline.get_baseline_frequency()
    m_vs = model_baseline.get_vector_strength()
    m_sc = model_baseline.get_serial_correlation(1)
    m_cv = model_baseline.get_coefficient_of_variation()

    model_ficurve = get_fi_curve_class(model, cell_data.get_fi_contrasts(), eod_frequency)
    # NOTE(review): the f-infinity / f-zero values here and for the cell
    # below are fetched but not used afterwards; they are kept in case the
    # getters populate cached state - confirm and remove if they do not.
    m_f_infinities = model_ficurve.get_f_inf_frequencies()
    m_f_zeros = model_ficurve.get_f_zero_frequencies()
    m_f_infinities_slope = model_ficurve.get_f_inf_slope()
    m_f_zero_slope = model_ficurve.get_f_zero_fit_slope_at_straight()

    data_baseline = get_baseline_class(cell_data)
    c_bf = data_baseline.get_baseline_frequency()
    c_vs = data_baseline.get_vector_strength()
    c_sc = data_baseline.get_serial_correlation(1)
    c_cv = data_baseline.get_coefficient_of_variation()

    data_fi_curve = get_fi_curve_class(cell_data, cell_data.get_fi_contrasts())
    c_f_inf_slope = data_fi_curve.get_f_inf_slope()
    c_f_inf_values = data_fi_curve.f_inf_frequencies

    c_f_zero_slope = data_fi_curve.get_f_zero_fit_slope_at_straight()
    c_f_zero_values = data_fi_curve.f_zero_frequencies

    if save_path is not None:
        rows = (
            ("baseline_frequency", c_bf, m_bf),
            ("vector_strength", c_vs, m_vs),
            ("serial_correlation", c_sc[0], m_sc[0]),
            ("coefficient_of_variation", c_cv, m_cv),
            ("f_inf_slope", c_f_inf_slope, m_f_infinities_slope),
            ("f_zero_slope", c_f_zero_slope, m_f_zero_slope),
        )
        # os.path.join keeps the file inside save_path even when the caller
        # passes the directory without a trailing slash.
        with open(os.path.join(save_path, "value_comparision.tsv"), 'w') as value_file:
            value_file.write("Variable\tCell\tModel\n")
            for name, cell_value, model_value in rows:
                value_file.write("{}\t{:.2f}\t{:.2f}\n".format(name, cell_value, model_value))

    if plot:
        # plot model images
        model_baseline.plot_baseline(save_path)
        model_baseline.plot_interspike_interval_histogram(save_path)
        model_baseline.plot_serial_correlation(6, save_path)

        model_ficurve.plot_fi_curve(save_path)
        model_ficurve.plot_fi_curve_comparision(data_fi_curve, model_ficurve, save_path)


if __name__ == '__main__':
    main()