improve starting parameters and weighting
This commit is contained in:
parent
5f2597381c
commit
463fcb5997
12
Fitter.py
12
Fitter.py
@ -132,8 +132,6 @@ class Fitter:
|
||||
args=(error_weights,), x0=x0, method="Nelder-Mead",
|
||||
options={"initial_simplex": initial_simplex, "xatol": 0.001, "maxfev": 600, "maxiter": 400})
|
||||
|
||||
plot_errors(self.errors)
|
||||
|
||||
return fmin, self.base_model.get_parameters()
|
||||
|
||||
def cost_function_all(self, X, error_weights=None):
|
||||
@ -197,10 +195,10 @@ class Fitter:
|
||||
|
||||
# calculate errors with reference values
|
||||
error_bf = abs((baseline_freq - self.baseline_freq) / self.baseline_freq)
|
||||
error_vs = abs((vector_strength - self.vector_strength) / 0.1)
|
||||
error_vs = abs((vector_strength - self.vector_strength) / 0.01)
|
||||
error_cv = abs((coefficient_of_variation - self.coefficient_of_variation) / 0.2)
|
||||
error_bursty = (abs(burstiness - self.burstiness) / 0.2)
|
||||
error_hist = np.mean((isi_bins - self.isi_bins) ** 2) / 200
|
||||
error_hist = np.mean((isi_bins - self.isi_bins) ** 2) / 600
|
||||
# print("error hist: {:.2f}".format(error_hist))
|
||||
# print("Burstiness: cell {:.2f}, model: {:.2f}, error: {:.2f}".format(self.burstiness, burstiness, error_bursty))
|
||||
|
||||
@ -214,12 +212,12 @@ class Fitter:
|
||||
|
||||
# error_f_zero_slopes = calculate_list_error(f_zero_slopes, self.f_zero_slopes)
|
||||
error_f_zero_slope_at_straight = abs(self.f_zero_slope_at_straight - f_zero_slope_at_straight) \
|
||||
/ abs(self.f_zero_slope_at_straight+1 / 10)
|
||||
error_f_zero = calculate_list_error(f_zeros, self.f_zero_values) / 25
|
||||
/ abs(self.f_zero_slope_at_straight+1)
|
||||
error_f_zero = calculate_list_error(f_zeros, self.f_zero_values) / 10
|
||||
|
||||
error_f0_curve = self.calculate_f0_curve_error(model, fi_curve_model) / 10
|
||||
|
||||
error_list = [error_bf, error_vs, error_sc, error_cv, error_hist, error_bursty,
|
||||
error_list = [error_vs, error_sc, error_cv, error_hist, error_bursty,
|
||||
error_f_inf, error_f_inf_slope, error_f_zero, error_f_zero_slope_at_straight, error_f0_curve]
|
||||
|
||||
self.errors.append(error_list)
|
||||
|
@ -8,8 +8,8 @@ import matplotlib.pyplot as plt
|
||||
import time
|
||||
|
||||
|
||||
def plot_errors(list_errors):
|
||||
names = ["error_bf", "error_vs", "error_sc", "error_cv", "error_bursty",
|
||||
def plot_errors(list_errors, save_path=None):
|
||||
names = ["error_vs", "error_sc", "error_cv", "rms_isi_hist", "error_bursty",
|
||||
"error_f_inf", "error_f_inf_s", "error_f_zero", "error_f_zero_s_straight", "error_f0_curve"]
|
||||
data = np.array(list_errors)
|
||||
|
||||
@ -23,7 +23,10 @@ def plot_errors(list_errors):
|
||||
axes[row, col].set_title(names[i])
|
||||
axes[row, col].set_yscale('log')
|
||||
|
||||
plt.savefig("figures/error_distributions/error_distribution_{}.png".format(time.strftime("%H:%M:%S")))
|
||||
if save_path is None:
|
||||
plt.show()
|
||||
else:
|
||||
plt.savefig(save_path + "error_distribution.png")
|
||||
plt.close()
|
||||
|
||||
|
||||
|
@ -10,16 +10,17 @@ import time
|
||||
import os
|
||||
import argparse
|
||||
import numpy as np
|
||||
from helperFunctions import plot_errors
|
||||
|
||||
import multiprocessing as mp
|
||||
|
||||
|
||||
# SAVE_DIRECTORY = "./results/invivo_results/"
|
||||
SAVE_DIRECTORY = "./results/test_data_isi_hist_err_added/"
|
||||
SAVE_DIRECTORY = "./results/test_data/"
|
||||
# SAVE_DIRECTORY_BEST = "./results/invivo_best/"
|
||||
SAVE_DIRECTORY_BEST = "./results/test_data_best_isi_hist/"
|
||||
SAVE_DIRECTORY_BEST = "./results/test_data_best/"
|
||||
# [bf, vs, sc, cv, isi_hist, bursty, f_inf, f_inf_slope, f_zero, f_zero_slope, f0_curve]
|
||||
ERROR_WEIGHTS = (0, 2, 2, 1, 1, 0, 1, 1, 1, 0, 1)
|
||||
ERROR_WEIGHTS = (2, 2, 1, 1, 0, 1, 1, 1, 0, 1)
|
||||
|
||||
|
||||
def main():
|
||||
@ -70,6 +71,7 @@ def fit_cell_base(parameters):
|
||||
error = fitter.calculate_errors(model=LifacNoiseModel(res_par))
|
||||
save_path = SAVE_DIRECTORY + "/" + cell_path + "/start_parameter_{:}_err_{:.2f}/".format(parameters[1], sum(error))
|
||||
save_fitting_run_info(parameters[0], res_par, parameters[2], plot=True, save_path=save_path)
|
||||
plot_errors(fitter.errors, save_path)
|
||||
time2 = time.time()
|
||||
|
||||
del fitter
|
||||
@ -117,12 +119,12 @@ def iget_start_parameters():
|
||||
# mem_tau, input_scaling, noise_strength, dend_tau,
|
||||
# expand by tau_a, delta_a ?
|
||||
|
||||
mem_tau_list = [0.01]
|
||||
input_scaling_list = [100]
|
||||
noise_strength_list = [0.03] # [0.02, 0.06]
|
||||
dend_tau_list = [0.0015]
|
||||
delta_a_list = [0.035, 0.065, 0.1]
|
||||
tau_a_list = [0.1, 0.4]
|
||||
mem_tau_list = [0.001]
|
||||
input_scaling_list = [80]
|
||||
noise_strength_list = [0.01]
|
||||
dend_tau_list = [0.002]
|
||||
delta_a_list = [0.01, 0.03, 0.065]
|
||||
tau_a_list = [0.02, 0.04]
|
||||
ref_time_list = [0.00065, 0.0012]
|
||||
|
||||
for mem_tau in mem_tau_list:
|
||||
|
Loading…
Reference in New Issue
Block a user