improve starting parameters and weighting

This commit is contained in:
a.ott 2020-08-08 15:44:16 +02:00
parent 5f2597381c
commit 463fcb5997
3 changed files with 22 additions and 19 deletions

View File

@ -132,8 +132,6 @@ class Fitter:
args=(error_weights,), x0=x0, method="Nelder-Mead", args=(error_weights,), x0=x0, method="Nelder-Mead",
options={"initial_simplex": initial_simplex, "xatol": 0.001, "maxfev": 600, "maxiter": 400}) options={"initial_simplex": initial_simplex, "xatol": 0.001, "maxfev": 600, "maxiter": 400})
plot_errors(self.errors)
return fmin, self.base_model.get_parameters() return fmin, self.base_model.get_parameters()
def cost_function_all(self, X, error_weights=None): def cost_function_all(self, X, error_weights=None):
@ -197,10 +195,10 @@ class Fitter:
# calculate errors with reference values # calculate errors with reference values
error_bf = abs((baseline_freq - self.baseline_freq) / self.baseline_freq) error_bf = abs((baseline_freq - self.baseline_freq) / self.baseline_freq)
error_vs = abs((vector_strength - self.vector_strength) / 0.1) error_vs = abs((vector_strength - self.vector_strength) / 0.01)
error_cv = abs((coefficient_of_variation - self.coefficient_of_variation) / 0.2) error_cv = abs((coefficient_of_variation - self.coefficient_of_variation) / 0.2)
error_bursty = (abs(burstiness - self.burstiness) / 0.2) error_bursty = (abs(burstiness - self.burstiness) / 0.2)
error_hist = np.mean((isi_bins - self.isi_bins) ** 2) / 200 error_hist = np.mean((isi_bins - self.isi_bins) ** 2) / 600
# print("error hist: {:.2f}".format(error_hist)) # print("error hist: {:.2f}".format(error_hist))
# print("Burstiness: cell {:.2f}, model: {:.2f}, error: {:.2f}".format(self.burstiness, burstiness, error_bursty)) # print("Burstiness: cell {:.2f}, model: {:.2f}, error: {:.2f}".format(self.burstiness, burstiness, error_bursty))
@ -214,12 +212,12 @@ class Fitter:
# error_f_zero_slopes = calculate_list_error(f_zero_slopes, self.f_zero_slopes) # error_f_zero_slopes = calculate_list_error(f_zero_slopes, self.f_zero_slopes)
error_f_zero_slope_at_straight = abs(self.f_zero_slope_at_straight - f_zero_slope_at_straight) \ error_f_zero_slope_at_straight = abs(self.f_zero_slope_at_straight - f_zero_slope_at_straight) \
/ abs(self.f_zero_slope_at_straight+1 / 10) / abs(self.f_zero_slope_at_straight+1)
error_f_zero = calculate_list_error(f_zeros, self.f_zero_values) / 25 error_f_zero = calculate_list_error(f_zeros, self.f_zero_values) / 10
error_f0_curve = self.calculate_f0_curve_error(model, fi_curve_model) / 10 error_f0_curve = self.calculate_f0_curve_error(model, fi_curve_model) / 10
error_list = [error_bf, error_vs, error_sc, error_cv, error_hist, error_bursty, error_list = [error_vs, error_sc, error_cv, error_hist, error_bursty,
error_f_inf, error_f_inf_slope, error_f_zero, error_f_zero_slope_at_straight, error_f0_curve] error_f_inf, error_f_inf_slope, error_f_zero, error_f_zero_slope_at_straight, error_f0_curve]
self.errors.append(error_list) self.errors.append(error_list)

View File

@ -8,8 +8,8 @@ import matplotlib.pyplot as plt
import time import time
def plot_errors(list_errors): def plot_errors(list_errors, save_path=None):
names = ["error_bf", "error_vs", "error_sc", "error_cv", "error_bursty", names = ["error_vs", "error_sc", "error_cv", "rms_isi_hist", "error_bursty",
"error_f_inf", "error_f_inf_s", "error_f_zero", "error_f_zero_s_straight", "error_f0_curve"] "error_f_inf", "error_f_inf_s", "error_f_zero", "error_f_zero_s_straight", "error_f0_curve"]
data = np.array(list_errors) data = np.array(list_errors)
@ -23,7 +23,10 @@ def plot_errors(list_errors):
axes[row, col].set_title(names[i]) axes[row, col].set_title(names[i])
axes[row, col].set_yscale('log') axes[row, col].set_yscale('log')
plt.savefig("figures/error_distributions/error_distribution_{}.png".format(time.strftime("%H:%M:%S"))) if save_path is None:
plt.show()
else:
plt.savefig(save_path + "error_distribution.png")
plt.close() plt.close()

View File

@ -10,16 +10,17 @@ import time
import os import os
import argparse import argparse
import numpy as np import numpy as np
from helperFunctions import plot_errors
import multiprocessing as mp import multiprocessing as mp
# SAVE_DIRECTORY = "./results/invivo_results/" # SAVE_DIRECTORY = "./results/invivo_results/"
SAVE_DIRECTORY = "./results/test_data_isi_hist_err_added/" SAVE_DIRECTORY = "./results/test_data/"
# SAVE_DIRECTORY_BEST = "./results/invivo_best/" # SAVE_DIRECTORY_BEST = "./results/invivo_best/"
SAVE_DIRECTORY_BEST = "./results/test_data_best_isi_hist/" SAVE_DIRECTORY_BEST = "./results/test_data_best/"
# [bf, vs, sc, cv, isi_hist, bursty, f_inf, f_inf_slope, f_zero, f_zero_slope, f0_curve] # [bf, vs, sc, cv, isi_hist, bursty, f_inf, f_inf_slope, f_zero, f_zero_slope, f0_curve]
ERROR_WEIGHTS = (0, 2, 2, 1, 1, 0, 1, 1, 1, 0, 1) ERROR_WEIGHTS = (2, 2, 1, 1, 0, 1, 1, 1, 0, 1)
def main(): def main():
@ -70,6 +71,7 @@ def fit_cell_base(parameters):
error = fitter.calculate_errors(model=LifacNoiseModel(res_par)) error = fitter.calculate_errors(model=LifacNoiseModel(res_par))
save_path = SAVE_DIRECTORY + "/" + cell_path + "/start_parameter_{:}_err_{:.2f}/".format(parameters[1], sum(error)) save_path = SAVE_DIRECTORY + "/" + cell_path + "/start_parameter_{:}_err_{:.2f}/".format(parameters[1], sum(error))
save_fitting_run_info(parameters[0], res_par, parameters[2], plot=True, save_path=save_path) save_fitting_run_info(parameters[0], res_par, parameters[2], plot=True, save_path=save_path)
plot_errors(fitter.errors, save_path)
time2 = time.time() time2 = time.time()
del fitter del fitter
@ -117,12 +119,12 @@ def iget_start_parameters():
# mem_tau, input_scaling, noise_strength, dend_tau, # mem_tau, input_scaling, noise_strength, dend_tau,
# expand by tau_a, delta_a ? # expand by tau_a, delta_a ?
mem_tau_list = [0.01] mem_tau_list = [0.001]
input_scaling_list = [100] input_scaling_list = [80]
noise_strength_list = [0.03] # [0.02, 0.06] noise_strength_list = [0.01]
dend_tau_list = [0.0015] dend_tau_list = [0.002]
delta_a_list = [0.035, 0.065, 0.1] delta_a_list = [0.01, 0.03, 0.065]
tau_a_list = [0.1, 0.4] tau_a_list = [0.02, 0.04]
ref_time_list = [0.00065, 0.0012] ref_time_list = [0.00065, 0.0012]
for mem_tau in mem_tau_list: for mem_tau in mem_tau_list: