improve starting parameters and weighting

This commit is contained in:
a.ott 2020-08-08 15:44:16 +02:00
parent 5f2597381c
commit 463fcb5997
3 changed files with 22 additions and 19 deletions

View File

@ -132,8 +132,6 @@ class Fitter:
args=(error_weights,), x0=x0, method="Nelder-Mead",
options={"initial_simplex": initial_simplex, "xatol": 0.001, "maxfev": 600, "maxiter": 400})
plot_errors(self.errors)
return fmin, self.base_model.get_parameters()
def cost_function_all(self, X, error_weights=None):
@ -197,10 +195,10 @@ class Fitter:
# calculate errors with reference values
error_bf = abs((baseline_freq - self.baseline_freq) / self.baseline_freq)
error_vs = abs((vector_strength - self.vector_strength) / 0.1)
error_vs = abs((vector_strength - self.vector_strength) / 0.01)
error_cv = abs((coefficient_of_variation - self.coefficient_of_variation) / 0.2)
error_bursty = (abs(burstiness - self.burstiness) / 0.2)
error_hist = np.mean((isi_bins - self.isi_bins) ** 2) / 200
error_hist = np.mean((isi_bins - self.isi_bins) ** 2) / 600
# print("error hist: {:.2f}".format(error_hist))
# print("Burstiness: cell {:.2f}, model: {:.2f}, error: {:.2f}".format(self.burstiness, burstiness, error_bursty))
@ -214,12 +212,12 @@ class Fitter:
# error_f_zero_slopes = calculate_list_error(f_zero_slopes, self.f_zero_slopes)
error_f_zero_slope_at_straight = abs(self.f_zero_slope_at_straight - f_zero_slope_at_straight) \
/ abs(self.f_zero_slope_at_straight+1 / 10)
error_f_zero = calculate_list_error(f_zeros, self.f_zero_values) / 25
/ abs(self.f_zero_slope_at_straight+1)
error_f_zero = calculate_list_error(f_zeros, self.f_zero_values) / 10
error_f0_curve = self.calculate_f0_curve_error(model, fi_curve_model) / 10
error_list = [error_bf, error_vs, error_sc, error_cv, error_hist, error_bursty,
error_list = [error_vs, error_sc, error_cv, error_hist, error_bursty,
error_f_inf, error_f_inf_slope, error_f_zero, error_f_zero_slope_at_straight, error_f0_curve]
self.errors.append(error_list)

View File

@ -8,8 +8,8 @@ import matplotlib.pyplot as plt
import time
def plot_errors(list_errors):
names = ["error_bf", "error_vs", "error_sc", "error_cv", "error_bursty",
def plot_errors(list_errors, save_path=None):
names = ["error_vs", "error_sc", "error_cv", "rms_isi_hist", "error_bursty",
"error_f_inf", "error_f_inf_s", "error_f_zero", "error_f_zero_s_straight", "error_f0_curve"]
data = np.array(list_errors)
@ -23,7 +23,10 @@ def plot_errors(list_errors):
axes[row, col].set_title(names[i])
axes[row, col].set_yscale('log')
plt.savefig("figures/error_distributions/error_distribution_{}.png".format(time.strftime("%H:%M:%S")))
if save_path is None:
plt.show()
else:
plt.savefig(save_path + "error_distribution.png")
plt.close()

View File

@ -10,16 +10,17 @@ import time
import os
import argparse
import numpy as np
from helperFunctions import plot_errors
import multiprocessing as mp
# SAVE_DIRECTORY = "./results/invivo_results/"
SAVE_DIRECTORY = "./results/test_data_isi_hist_err_added/"
SAVE_DIRECTORY = "./results/test_data/"
# SAVE_DIRECTORY_BEST = "./results/invivo_best/"
SAVE_DIRECTORY_BEST = "./results/test_data_best_isi_hist/"
SAVE_DIRECTORY_BEST = "./results/test_data_best/"
# [bf, vs, sc, cv, isi_hist, bursty, f_inf, f_inf_slope, f_zero, f_zero_slope, f0_curve]
ERROR_WEIGHTS = (0, 2, 2, 1, 1, 0, 1, 1, 1, 0, 1)
ERROR_WEIGHTS = (2, 2, 1, 1, 0, 1, 1, 1, 0, 1)
def main():
@ -70,6 +71,7 @@ def fit_cell_base(parameters):
error = fitter.calculate_errors(model=LifacNoiseModel(res_par))
save_path = SAVE_DIRECTORY + "/" + cell_path + "/start_parameter_{:}_err_{:.2f}/".format(parameters[1], sum(error))
save_fitting_run_info(parameters[0], res_par, parameters[2], plot=True, save_path=save_path)
plot_errors(fitter.errors, save_path)
time2 = time.time()
del fitter
@ -117,12 +119,12 @@ def iget_start_parameters():
# mem_tau, input_scaling, noise_strength, dend_tau,
# expand by tau_a, delta_a ?
mem_tau_list = [0.01]
input_scaling_list = [100]
noise_strength_list = [0.03] # [0.02, 0.06]
dend_tau_list = [0.0015]
delta_a_list = [0.035, 0.065, 0.1]
tau_a_list = [0.1, 0.4]
mem_tau_list = [0.001]
input_scaling_list = [80]
noise_strength_list = [0.01]
dend_tau_list = [0.002]
delta_a_list = [0.01, 0.03, 0.065]
tau_a_list = [0.02, 0.04]
ref_time_list = [0.00065, 0.0012]
for mem_tau in mem_tau_list: