diff --git a/Fitter.py b/Fitter.py index 7d1e325..baa48e5 100644 --- a/Fitter.py +++ b/Fitter.py @@ -132,8 +132,6 @@ class Fitter: args=(error_weights,), x0=x0, method="Nelder-Mead", options={"initial_simplex": initial_simplex, "xatol": 0.001, "maxfev": 600, "maxiter": 400}) - plot_errors(self.errors) - return fmin, self.base_model.get_parameters() def cost_function_all(self, X, error_weights=None): @@ -197,10 +195,10 @@ class Fitter: # calculate errors with reference values error_bf = abs((baseline_freq - self.baseline_freq) / self.baseline_freq) - error_vs = abs((vector_strength - self.vector_strength) / 0.1) + error_vs = abs((vector_strength - self.vector_strength) / 0.01) error_cv = abs((coefficient_of_variation - self.coefficient_of_variation) / 0.2) error_bursty = (abs(burstiness - self.burstiness) / 0.2) - error_hist = np.mean((isi_bins - self.isi_bins) ** 2) / 200 + error_hist = np.mean((isi_bins - self.isi_bins) ** 2) / 600 # print("error hist: {:.2f}".format(error_hist)) # print("Burstiness: cell {:.2f}, model: {:.2f}, error: {:.2f}".format(self.burstiness, burstiness, error_bursty)) @@ -214,12 +212,12 @@ class Fitter: # error_f_zero_slopes = calculate_list_error(f_zero_slopes, self.f_zero_slopes) error_f_zero_slope_at_straight = abs(self.f_zero_slope_at_straight - f_zero_slope_at_straight) \ - / abs(self.f_zero_slope_at_straight+1 / 10) - error_f_zero = calculate_list_error(f_zeros, self.f_zero_values) / 25 + / abs(self.f_zero_slope_at_straight+1) + error_f_zero = calculate_list_error(f_zeros, self.f_zero_values) / 10 error_f0_curve = self.calculate_f0_curve_error(model, fi_curve_model) / 10 - error_list = [error_bf, error_vs, error_sc, error_cv, error_hist, error_bursty, + error_list = [error_vs, error_sc, error_cv, error_hist, error_bursty, error_f_inf, error_f_inf_slope, error_f_zero, error_f_zero_slope_at_straight, error_f0_curve] self.errors.append(error_list) diff --git a/helperFunctions.py b/helperFunctions.py index ed8bbbd..8b35a42 100644 --- a/helperFunctions.py +++ b/helperFunctions.py @@ -8,8 +8,8 @@ import matplotlib.pyplot as plt import time -def plot_errors(list_errors): - names = ["error_bf", "error_vs", "error_sc", "error_cv", "error_bursty", +def plot_errors(list_errors, save_path=None): + names = ["error_vs", "error_sc", "error_cv", "rms_isi_hist", "error_bursty", "error_f_inf", "error_f_inf_s", "error_f_zero", "error_f_zero_s_straight", "error_f0_curve"] data = np.array(list_errors) @@ -23,7 +23,10 @@ def plot_errors(list_errors): axes[row, col].set_title(names[i]) axes[row, col].set_yscale('log') - plt.savefig("figures/error_distributions/error_distribution_{}.png".format(time.strftime("%H:%M:%S"))) + if save_path is None: + plt.show() + else: + plt.savefig(save_path + "error_distribution.png") plt.close() diff --git a/run_Fitter.py b/run_Fitter.py index 50bdce6..e5b57e7 100644 --- a/run_Fitter.py +++ b/run_Fitter.py @@ -10,16 +10,17 @@ import time import os import argparse import numpy as np +from helperFunctions import plot_errors import multiprocessing as mp # SAVE_DIRECTORY = "./results/invivo_results/" -SAVE_DIRECTORY = "./results/test_data_isi_hist_err_added/" +SAVE_DIRECTORY = "./results/test_data/" # SAVE_DIRECTORY_BEST = "./results/invivo_best/" -SAVE_DIRECTORY_BEST = "./results/test_data_best_isi_hist/" +SAVE_DIRECTORY_BEST = "./results/test_data_best/" # [bf, vs, sc, cv, isi_hist, bursty, f_inf, f_inf_slope, f_zero, f_zero_slope, f0_curve] -ERROR_WEIGHTS = (0, 2, 2, 1, 1, 0, 1, 1, 1, 0, 1) +ERROR_WEIGHTS = (2, 2, 1, 1, 0, 1, 1, 1, 0, 1) def main(): @@ -70,6 +71,7 @@ def fit_cell_base(parameters): error = fitter.calculate_errors(model=LifacNoiseModel(res_par)) save_path = SAVE_DIRECTORY + "/" + cell_path + "/start_parameter_{:}_err_{:.2f}/".format(parameters[1], sum(error)) save_fitting_run_info(parameters[0], res_par, parameters[2], plot=True, save_path=save_path) + plot_errors(fitter.errors, save_path) time2 = time.time() del fitter @@ -117,12 +119,12 @@ def iget_start_parameters(): # mem_tau, input_scaling, noise_strength, dend_tau, # expand by tau_a, delta_a ? - mem_tau_list = [0.01] - input_scaling_list = [100] - noise_strength_list = [0.03] # [0.02, 0.06] - dend_tau_list = [0.0015] - delta_a_list = [0.035, 0.065, 0.1] - tau_a_list = [0.1, 0.4] + mem_tau_list = [0.001] + input_scaling_list = [80] + noise_strength_list = [0.01] + dend_tau_list = [0.002] + delta_a_list = [0.01, 0.03, 0.065] + tau_a_list = [0.02, 0.04] ref_time_list = [0.00065, 0.0012] for mem_tau in mem_tau_list: