add and change functions to run parrelal fitting of cells with const refractory period

This commit is contained in:
alexanderott 2020-06-01 12:19:31 +02:00
parent 5a58483edb
commit c6e07207f3

View File

@ -1,12 +1,14 @@
from models.LIFACnoise import LifacNoiseModel from models.LIFACnoise import LifacNoiseModel
from CellData import icelldata_of_dir from CellData import icelldata_of_dir, CellData
from Baseline import get_baseline_class from Baseline import get_baseline_class
from FiCurve import get_fi_curve_class from FiCurve import get_fi_curve_class
from Fitter import Fitter from Fitter import Fitter
import numpy as np
import time import time
import os import os
import copy
import multiprocessing as mp import multiprocessing as mp
@ -14,17 +16,54 @@ import multiprocessing as mp
SAVE_PATH_PREFIX = "" SAVE_PATH_PREFIX = ""
FIT_ROUTINE = "" FIT_ROUTINE = ""
def main(): def main():
count = 0
for data in icelldata_of_dir("./data/"):
count += 1
# if count <= 3:
# continue
fit_cell_parrallel(data, [p for p in iget_start_parameters()]) test_effect_of_refractory_period()
quit()
cells = [data for data in icelldata_of_dir("./data/")]
start_parameter = [p for p in iget_start_parameters()]
fit_all_cells_parallel_sync(cells, start_parameter, )
def qfit_cell_base(parameter):
# parameter = (cell_data, start_parameter_index, start_parameter, results_base_folder)
time1 = time.time()
fitter = Fitter()
fitter.set_data_reference_values(parameter[0])
fmin, res_par = fitter.fit_routine_const_ref_period(parameter[2])
cell_data = parameter[0]
cell_path = os.path.split(cell_data.get_data_path())[-1]
error = fitter.calculate_errors(model=LifacNoiseModel(res_par))
save_path = parameter[3] + "/" + cell_path + "/start_parameter_{:}_err_{:.2f}/".format(parameter[1], sum(error))
save_fitting_run_info(parameter[0], res_par, parameter[2], plot=True, save_path=save_path)
time2 = time.time()
print("Time taken for " + cell_path +
"\n and start parameters ({:}): {:.2f}s thread time".format(parameter[1]+1, time2 - time1) +
"\n error: {:.2f}".format(sum(error)))
def fit_cell_parrallel(cell_data, start_parameters):
def fit_all_cells_parallel_sync(cells, start_parameters, results_base_folder):
parameter = []
for cell in cells:
for i, s_pars in enumerate(start_parameters):
parameter.append((cell, i, s_pars, results_base_folder))
core_count = mp.cpu_count()
pool = mp.Pool(core_count - 1)
time1 = time.time()
pool.map(fit_cell_base, parameter)
time2 = time.time()
print("Time taken for all cells and start parameters ({:}): {:.2f}s".format(len(parameter), time2 - time1))
def fit_cell_parallel(cell_data, start_parameters):
cell_path = os.path.basename(cell_data.get_data_path()) cell_path = os.path.basename(cell_data.get_data_path())
print(cell_path) print(cell_path)
core_count = mp.cpu_count() core_count = mp.cpu_count()
@ -89,7 +128,7 @@ def iget_start_parameters():
# expand by tau_a, delta_a ? # expand by tau_a, delta_a ?
mem_tau_list = [0.01] mem_tau_list = [0.01]
input_scaling_list = [40, 60] input_scaling_list = [60]
noise_strength_list = [0.03] # [0.02, 0.06] noise_strength_list = [0.03] # [0.02, 0.06]
dend_tau_list = [0.001, 0.002] dend_tau_list = [0.001, 0.002]
delta_a_list = [0.035, 0.065] delta_a_list = [0.035, 0.065]
@ -233,5 +272,18 @@ def save_fitting_run_info(cell_data, parameters, start_parameters, plot=False, s
model_ficurve.plot_fi_curve_comparision(data_fi_curve, model_ficurve, save_path) model_ficurve.plot_fi_curve_comparision(data_fi_curve, model_ficurve, save_path)
def test_effect_of_refractory_period():
ref_periods = np.arange(0.0006, 0.001, 0.0015)
cells = [c for c in icelldata_of_dir("./data/")]
start_parameters_base = [p for p in iget_start_parameters()]
for ref_period in ref_periods:
results_base_folder = "./test_routines/ref_period_{:.3f}/".format(ref_period)
all_start_parameters = copy.deepcopy(start_parameters_base)
for par_set in all_start_parameters:
par_set["refractory_period"] = ref_period
fit_all_cells_parallel_sync(cells, all_start_parameters, results_base_folder)
if __name__ == '__main__': if __name__ == '__main__':
main() main()