add and change functions to run parrelal fitting of cells with const refractory period
This commit is contained in:
parent
5a58483edb
commit
c6e07207f3
@ -1,12 +1,14 @@
|
|||||||
|
|
||||||
from models.LIFACnoise import LifacNoiseModel
|
from models.LIFACnoise import LifacNoiseModel
|
||||||
from CellData import icelldata_of_dir
|
from CellData import icelldata_of_dir, CellData
|
||||||
from Baseline import get_baseline_class
|
from Baseline import get_baseline_class
|
||||||
from FiCurve import get_fi_curve_class
|
from FiCurve import get_fi_curve_class
|
||||||
from Fitter import Fitter
|
from Fitter import Fitter
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
import time
|
import time
|
||||||
import os
|
import os
|
||||||
|
import copy
|
||||||
|
|
||||||
import multiprocessing as mp
|
import multiprocessing as mp
|
||||||
|
|
||||||
@ -14,17 +16,54 @@ import multiprocessing as mp
|
|||||||
SAVE_PATH_PREFIX = ""
|
SAVE_PATH_PREFIX = ""
|
||||||
FIT_ROUTINE = ""
|
FIT_ROUTINE = ""
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
count = 0
|
|
||||||
for data in icelldata_of_dir("./data/"):
|
|
||||||
count += 1
|
|
||||||
# if count <= 3:
|
|
||||||
# continue
|
|
||||||
|
|
||||||
fit_cell_parrallel(data, [p for p in iget_start_parameters()])
|
test_effect_of_refractory_period()
|
||||||
|
|
||||||
|
quit()
|
||||||
|
cells = [data for data in icelldata_of_dir("./data/")]
|
||||||
|
|
||||||
|
start_parameter = [p for p in iget_start_parameters()]
|
||||||
|
|
||||||
|
fit_all_cells_parallel_sync(cells, start_parameter, )
|
||||||
|
|
||||||
|
|
||||||
|
def qfit_cell_base(parameter):
|
||||||
|
# parameter = (cell_data, start_parameter_index, start_parameter, results_base_folder)
|
||||||
|
time1 = time.time()
|
||||||
|
fitter = Fitter()
|
||||||
|
fitter.set_data_reference_values(parameter[0])
|
||||||
|
fmin, res_par = fitter.fit_routine_const_ref_period(parameter[2])
|
||||||
|
|
||||||
|
cell_data = parameter[0]
|
||||||
|
cell_path = os.path.split(cell_data.get_data_path())[-1]
|
||||||
|
|
||||||
|
error = fitter.calculate_errors(model=LifacNoiseModel(res_par))
|
||||||
|
save_path = parameter[3] + "/" + cell_path + "/start_parameter_{:}_err_{:.2f}/".format(parameter[1], sum(error))
|
||||||
|
save_fitting_run_info(parameter[0], res_par, parameter[2], plot=True, save_path=save_path)
|
||||||
|
time2 = time.time()
|
||||||
|
|
||||||
|
print("Time taken for " + cell_path +
|
||||||
|
"\n and start parameters ({:}): {:.2f}s thread time".format(parameter[1]+1, time2 - time1) +
|
||||||
|
"\n error: {:.2f}".format(sum(error)))
|
||||||
|
|
||||||
def fit_cell_parrallel(cell_data, start_parameters):
|
|
||||||
|
def fit_all_cells_parallel_sync(cells, start_parameters, results_base_folder):
|
||||||
|
parameter = []
|
||||||
|
for cell in cells:
|
||||||
|
for i, s_pars in enumerate(start_parameters):
|
||||||
|
parameter.append((cell, i, s_pars, results_base_folder))
|
||||||
|
|
||||||
|
core_count = mp.cpu_count()
|
||||||
|
pool = mp.Pool(core_count - 1)
|
||||||
|
time1 = time.time()
|
||||||
|
pool.map(fit_cell_base, parameter)
|
||||||
|
time2 = time.time()
|
||||||
|
print("Time taken for all cells and start parameters ({:}): {:.2f}s".format(len(parameter), time2 - time1))
|
||||||
|
|
||||||
|
|
||||||
|
def fit_cell_parallel(cell_data, start_parameters):
|
||||||
cell_path = os.path.basename(cell_data.get_data_path())
|
cell_path = os.path.basename(cell_data.get_data_path())
|
||||||
print(cell_path)
|
print(cell_path)
|
||||||
core_count = mp.cpu_count()
|
core_count = mp.cpu_count()
|
||||||
@ -89,7 +128,7 @@ def iget_start_parameters():
|
|||||||
# expand by tau_a, delta_a ?
|
# expand by tau_a, delta_a ?
|
||||||
|
|
||||||
mem_tau_list = [0.01]
|
mem_tau_list = [0.01]
|
||||||
input_scaling_list = [40, 60]
|
input_scaling_list = [60]
|
||||||
noise_strength_list = [0.03] # [0.02, 0.06]
|
noise_strength_list = [0.03] # [0.02, 0.06]
|
||||||
dend_tau_list = [0.001, 0.002]
|
dend_tau_list = [0.001, 0.002]
|
||||||
delta_a_list = [0.035, 0.065]
|
delta_a_list = [0.035, 0.065]
|
||||||
@ -233,5 +272,18 @@ def save_fitting_run_info(cell_data, parameters, start_parameters, plot=False, s
|
|||||||
model_ficurve.plot_fi_curve_comparision(data_fi_curve, model_ficurve, save_path)
|
model_ficurve.plot_fi_curve_comparision(data_fi_curve, model_ficurve, save_path)
|
||||||
|
|
||||||
|
|
||||||
|
def test_effect_of_refractory_period():
|
||||||
|
ref_periods = np.arange(0.0006, 0.001, 0.0015)
|
||||||
|
|
||||||
|
cells = [c for c in icelldata_of_dir("./data/")]
|
||||||
|
start_parameters_base = [p for p in iget_start_parameters()]
|
||||||
|
for ref_period in ref_periods:
|
||||||
|
results_base_folder = "./test_routines/ref_period_{:.3f}/".format(ref_period)
|
||||||
|
all_start_parameters = copy.deepcopy(start_parameters_base)
|
||||||
|
|
||||||
|
for par_set in all_start_parameters:
|
||||||
|
par_set["refractory_period"] = ref_period
|
||||||
|
fit_all_cells_parallel_sync(cells, all_start_parameters, results_base_folder)
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
main()
|
main()
|
||||||
|
Loading…
Reference in New Issue
Block a user