save more data after Fit

This commit is contained in:
a.ott 2020-07-05 11:06:05 +02:00
parent 574f4a80f2
commit 222af65bd6

View File

@ -4,11 +4,13 @@ from CellData import icelldata_of_dir, CellData
from Baseline import get_baseline_class from Baseline import get_baseline_class
from FiCurve import get_fi_curve_class from FiCurve import get_fi_curve_class
from Fitter import Fitter from Fitter import Fitter
from ModelFit import ModelFit
import time import time
import os import os
import copy import copy
import argparse import argparse
import numpy as np
import multiprocessing as mp import multiprocessing as mp
@ -32,13 +34,16 @@ def main():
# quit() # quit()
start_parameters = [p for p in iget_start_parameters()] start_parameters = [p for p in iget_start_parameters()]
start_data = 8 # start_data = 8
count = 0 # count = 0
for cell_data in icelldata_of_dir("./invivo_data/"): # for cell_data in icelldata_of_dir("./invivo_data/"):
count += 1 # count += 1
if count < start_data: # if count < start_data:
continue # continue
fit_cell_parallel(cell_data, start_parameters) # fit_cell_parallel(cell_data, start_parameters)
cell_data = CellData("invivo_data/2012-04-20-ab-invivo-1/")
fit_cell_parallel(cell_data, start_parameters)
def test_single_cell(path): def test_single_cell(path):
@ -94,13 +99,15 @@ def fit_all_cells_parallel_sync(cells, start_parameters, thread_pool, results_ba
def fit_cell_parallel(cell_data, start_parameters): def fit_cell_parallel(cell_data, start_parameters):
cell_path = os.path.basename(cell_data.get_data_path()) cell_path = os.path.basename(cell_data.get_data_path())
save_directory = "./results/invivo_results/"
save_path_cell = os.path.join(save_directory, cell_data.get_cell_name())
print(cell_path) print(cell_path)
core_count = mp.cpu_count() core_count = mp.cpu_count()
pool = mp.Pool(core_count - 1) pool = mp.Pool(core_count - 1)
parameters = [] parameters = []
for i, p in enumerate(start_parameters): for i, p in enumerate(start_parameters):
parameters.append((cell_data, i, p, "./results/invivo_results/")) parameters.append((cell_data, i, p, save_directory))
time1 = time.time() time1 = time.time()
pool.map(fit_cell_base, parameters) pool.map(fit_cell_base, parameters)
@ -109,6 +116,17 @@ def fit_cell_parallel(cell_data, start_parameters):
del pool del pool
del cell_data del cell_data
best_fit = None
min_err = np.inf
for fit in os.listdir(save_path_cell):
cur_fit = ModelFit(os.path.join(save_path_cell, fit))
if cur_fit.comparable_error() < min_err:
min_err = cur_fit.comparable_error()
best_fit = cur_fit
best_fit.generate_master_plot("./results/invivo_best/singles/")
def test_fit_routines(): def test_fit_routines():
fitter = Fitter() fitter = Fitter()
@ -238,6 +256,9 @@ def save_fitting_run_info(cell_data, parameters, start_parameters, plot=False, s
if not os.path.exists(save_path): if not os.path.exists(save_path):
os.makedirs(save_path) os.makedirs(save_path)
if save_path is None:
return
with open(save_path + "parameters_info.txt", "w") as file: with open(save_path + "parameters_info.txt", "w") as file:
file.writelines(["start_parameters:\t" + str(start_parameters), file.writelines(["start_parameters:\t" + str(start_parameters),
"\nfinal_parameters:\t" + str(parameters)]) "\nfinal_parameters:\t" + str(parameters)])
@ -245,6 +266,24 @@ def save_fitting_run_info(cell_data, parameters, start_parameters, plot=False, s
model = LifacNoiseModel(parameters) model = LifacNoiseModel(parameters)
eod_frequency = cell_data.get_eod_frequency() eod_frequency = cell_data.get_eod_frequency()
data_baseline = get_baseline_class(cell_data)
c_bf = data_baseline.get_baseline_frequency()
c_vs = data_baseline.get_vector_strength()
c_sc = data_baseline.get_serial_correlation(1)
c_cv = data_baseline.get_coefficient_of_variation()
c_burst = data_baseline.get_burstiness()
data_fi_curve = get_fi_curve_class(cell_data, cell_data.get_fi_contrasts())
c_f_inf_slope = data_fi_curve.get_f_inf_slope()
c_f_inf_values = data_fi_curve.f_inf_frequencies
c_f_zero_slope = data_fi_curve.get_f_zero_fit_slope_at_straight()
c_f_zero_values = data_fi_curve.f_zero_frequencies
if c_f_inf_slope < 0:
contrasts = np.array(cell_data.get_fi_contrasts()) * -1
else:
contrasts = np.array(cell_data.get_fi_contrasts())
model_baseline = get_baseline_class(model, eod_frequency) model_baseline = get_baseline_class(model, eod_frequency)
m_bf = model_baseline.get_baseline_frequency() m_bf = model_baseline.get_baseline_frequency()
m_vs = model_baseline.get_vector_strength() m_vs = model_baseline.get_vector_strength()
@ -258,19 +297,18 @@ def save_fitting_run_info(cell_data, parameters, start_parameters, plot=False, s
m_f_infinities_slope = model_ficurve.get_f_inf_slope() m_f_infinities_slope = model_ficurve.get_f_inf_slope()
m_f_zero_slope = model_ficurve.get_f_zero_fit_slope_at_straight() m_f_zero_slope = model_ficurve.get_f_zero_fit_slope_at_straight()
data_baseline = get_baseline_class(cell_data) np.save(os.path.join(save_path, "model_fi_inf_values.npy"), np.array(m_f_infinities))
c_bf = data_baseline.get_baseline_frequency() np.save(os.path.join(save_path, "cell_fi_inf_values.npy"), np.array(c_f_inf_values))
c_vs = data_baseline.get_vector_strength() np.save(os.path.join(save_path, "model_fi_zero_values.npy"), np.array(m_f_zeros))
c_sc = data_baseline.get_serial_correlation(1) np.save(os.path.join(save_path, "cell_fi_zero_values.npy"), np.array(c_f_zero_values))
c_cv = data_baseline.get_coefficient_of_variation()
c_burst = data_baseline.get_burstiness()
data_fi_curve = get_fi_curve_class(cell_data, cell_data.get_fi_contrasts()) with open(os.path.join(save_path, "cell_data_path.txt"), "w") as f:
c_f_inf_slope = data_fi_curve.get_f_inf_slope() path = cell_data.get_data_path() + "\n"
c_f_inf_values = data_fi_curve.f_inf_frequencies f.write(path)
if c_f_inf_slope < 0:
model_ficurve.stimulus_values = contrasts * -1
c_f_zero_slope = data_fi_curve.get_f_zero_fit_slope_at_straight()
c_f_zero_values = data_fi_curve.f_zero_frequencies
# print("EOD-frequency: {:.2f}".format(cell_data.get_eod_frequency())) # print("EOD-frequency: {:.2f}".format(cell_data.get_eod_frequency()))
# print("bf: cell - {:.2f} vs model {:.2f}".format(c_bf, m_bf)) # print("bf: cell - {:.2f} vs model {:.2f}".format(c_bf, m_bf))
# print("vs: cell - {:.2f} vs model {:.2f}".format(c_vs, m_vs)) # print("vs: cell - {:.2f} vs model {:.2f}".format(c_vs, m_vs))