diff --git a/run_Fitter.py b/run_Fitter.py index 4893a22..fd378c8 100644 --- a/run_Fitter.py +++ b/run_Fitter.py @@ -1,12 +1,14 @@ from models.LIFACnoise import LifacNoiseModel -from CellData import icelldata_of_dir +from CellData import icelldata_of_dir, CellData from Baseline import get_baseline_class from FiCurve import get_fi_curve_class from Fitter import Fitter +import numpy as np import time import os +import copy import multiprocessing as mp @@ -14,17 +16,54 @@ import multiprocessing as mp SAVE_PATH_PREFIX = "" FIT_ROUTINE = "" + def main(): - count = 0 - for data in icelldata_of_dir("./data/"): - count += 1 - # if count <= 3: - # continue - fit_cell_parrallel(data, [p for p in iget_start_parameters()]) + test_effect_of_refractory_period() + + quit() + cells = [data for data in icelldata_of_dir("./data/")] + + start_parameter = [p for p in iget_start_parameters()] + + fit_all_cells_parallel_sync(cells, start_parameter, ) + + +def qfit_cell_base(parameter): + # parameter = (cell_data, start_parameter_index, start_parameter, results_base_folder) + time1 = time.time() + fitter = Fitter() + fitter.set_data_reference_values(parameter[0]) + fmin, res_par = fitter.fit_routine_const_ref_period(parameter[2]) + + cell_data = parameter[0] + cell_path = os.path.split(cell_data.get_data_path())[-1] + + error = fitter.calculate_errors(model=LifacNoiseModel(res_par)) + save_path = parameter[3] + "/" + cell_path + "/start_parameter_{:}_err_{:.2f}/".format(parameter[1], sum(error)) + save_fitting_run_info(parameter[0], res_par, parameter[2], plot=True, save_path=save_path) + time2 = time.time() + + print("Time taken for " + cell_path + + "\n and start parameters ({:}): {:.2f}s thread time".format(parameter[1]+1, time2 - time1) + + "\n error: {:.2f}".format(sum(error))) + + +def fit_all_cells_parallel_sync(cells, start_parameters, results_base_folder): + parameter = [] + for cell in cells: + for i, s_pars in enumerate(start_parameters): + parameter.append((cell, i, s_pars, results_base_folder)) + + core_count = mp.cpu_count() + pool = mp.Pool(core_count - 1) + time1 = time.time() + pool.map(fit_cell_base, parameter) + time2 = time.time() + print("Time taken for all cells and start parameters ({:}): {:.2f}s".format(len(parameter), time2 - time1)) -def fit_cell_parrallel(cell_data, start_parameters): +def fit_cell_parallel(cell_data, start_parameters): cell_path = os.path.basename(cell_data.get_data_path()) print(cell_path) core_count = mp.cpu_count() @@ -71,7 +110,7 @@ def test_fit_routines(): for i, routine_results in enumerate(best): res_file.write(names[i]) for cell_best in routine_results: - res_file.write("," + str(cell_best)) + res_file.write("," + str(cell_best)) def find_best_run(cell_path): @@ -89,7 +128,7 @@ def iget_start_parameters(): # expand by tau_a, delta_a ? mem_tau_list = [0.01] - input_scaling_list = [40, 60] + input_scaling_list = [60] noise_strength_list = [0.03] # [0.02, 0.06] dend_tau_list = [0.001, 0.002] delta_a_list = [0.035, 0.065] @@ -233,5 +272,18 @@ def save_fitting_run_info(cell_data, parameters, start_parameters, plot=False, s model_ficurve.plot_fi_curve_comparision(data_fi_curve, model_ficurve, save_path) +def test_effect_of_refractory_period(): + ref_periods = np.arange(0.0006, 0.001, 0.0015) + + cells = [c for c in icelldata_of_dir("./data/")] + start_parameters_base = [p for p in iget_start_parameters()] + for ref_period in ref_periods: + results_base_folder = "./test_routines/ref_period_{:.3f}/".format(ref_period) + all_start_parameters = copy.deepcopy(start_parameters_base) + + for par_set in all_start_parameters: + par_set["refractory_period"] = ref_period + fit_all_cells_parallel_sync(cells, all_start_parameters, results_base_folder) + if __name__ == '__main__': main()