From 16a8da2dfc6ba79f7ca593c8801661b2f093e306 Mon Sep 17 00:00:00 2001 From: "a.ott" Date: Thu, 23 Jul 2020 10:32:31 +0200 Subject: [PATCH] correct th mean square error for not equal sampling rates --- Fitter.py | 57 +++++++++++++++++++++++++++++++++---------------------- 1 file changed, 34 insertions(+), 23 deletions(-) diff --git a/Fitter.py b/Fitter.py index f13d19e..1715da0 100644 --- a/Fitter.py +++ b/Fitter.py @@ -125,7 +125,10 @@ class Fitter: error_weights = (0, 2, 2, 2, 1, 1, 1, 1, 0, 1) fmin = minimize(fun=self.cost_function_all, args=(error_weights,), x0=x0, method="Nelder-Mead", - options={"initial_simplex": initial_simplex, "xatol": 0.001, "maxfev": 600, "maxiter": 1200}) + options={"initial_simplex": initial_simplex, "xatol": 0.001, "maxfev": 600, "maxiter": 800}) + + print("best that was returned: {}".format(fmin.fun)) + print("best that was visited: {}".format(self.smallest_error)) return fmin, self.base_model.get_parameters() @@ -187,11 +190,6 @@ class Fitter: self.base_model.set_variable("dend_tau", X[5]) self.base_model.set_variable("refractory_period", X[6]) - - - - # TODO add tests for parameters punish impossible values (immediate high error) but also add a slope towards valid points - base_stimulus = SinusoidalStepStimulus(self.eod_freq, 0) # find right v-offset test_model = self.base_model.get_model_copy() @@ -204,7 +202,7 @@ class Fitter: # [error_bf, error_vs, error_sc, error_f_inf, error_f_inf_slope, error_f_zero, error_f_zero_slope] error_list = self.calculate_errors(error_weights) - print("sum: {:.2f}, ".format(sum(error_list))) + # print("sum: {:.2f}, ".format(sum(error_list))) if sum(error_list) < self.smallest_error: self.smallest_error = sum(error_list) self.best_parameters_found = X @@ -379,30 +377,43 @@ class Fitter: / abs(self.f_zero_slope_at_straight+1 / 10) error_f_zero = calculate_list_error(f_zeros, self.f_zero_values) - # TODO - - if model.get_sampling_interval() != self.data_sampling_interval: - raise ValueError("Sampling intervals not the same!") - times, freqs = fi_curve_model.get_mean_time_and_freq_traces() freq_prediction = np.array(freqs[self.f_zero_curve_contrast_idx]) time_prediction = np.array(times[self.f_zero_curve_contrast_idx]) stimulus_start = fi_curve_model.get_stimulus_start() - time_prediction[0] length = fi_curve_model.get_stimulus_duration() / 2 - start_idx = int(stimulus_start / fi_curve_model.get_sampling_interval()) - end_idx = int((stimulus_start + length) / model.get_sampling_interval()) - - if len(time_prediction) == 0 or len(time_prediction) < end_idx or time_prediction[0] > fi_curve_model.get_stimulus_start(): + if model.get_sampling_interval() == self.data_sampling_interval: + start_idx = int(stimulus_start / fi_curve_model.get_sampling_interval()) + end_idx = int((stimulus_start + length) / model.get_sampling_interval()) + start_idx_cell = start_idx + start_idx_model = start_idx + end_idx_cell = end_idx + end_idx_model = end_idx + step_cell = 1 + step_model = 1 + else: + start_idx_cell = int(stimulus_start / self.data_sampling_interval) + start_idx_model = int(stimulus_start / fi_curve_model.get_sampling_interval()) + end_idx_cell = int((stimulus_start + length) / self.data_sampling_interval) + end_idx_model = int((stimulus_start + length) / model.get_sampling_interval()) + if round(model.get_sampling_interval() % self.data_sampling_interval, 4) == 0: + step_cell = int(model.get_sampling_interval() / self.data_sampling_interval) + step_model = 1 + else: + raise ValueError("Model sampling interval is not a multiple of data sampling interval.") + + if len(time_prediction) == 0 or len(time_prediction) < end_idx_model or time_prediction[0] > fi_curve_model.get_stimulus_start(): error_f0_curve = 200 else: - error_f0_curve = np.mean((self.f_zero_curve_freq[start_idx:end_idx] - freq_prediction[start_idx:end_idx])**2) / 100 - - # plt.plot(self.f_zero_curve_freq[start_idx:end_idx]) - # plt.plot(freq_prediction[start_idx:end_idx]) - # plt.plot((self.f_zero_curve_freq[start_idx:end_idx] - freq_prediction[start_idx:end_idx])**2) - # plt.show() - # plt.close() + data_curve = self.f_zero_curve_freq[start_idx_cell:end_idx_cell:step_cell] + model_curve = freq_prediction[start_idx_model:end_idx_model:step_model] + if len(data_curve) < len(model_curve): + model_curve = model_curve[:len(data_curve)] + elif len(model_curve) < len(data_curve): + data_curve = data_curve[:len(model_curve)] + + error_f0_curve = np.mean((model_curve - data_curve)**2) / 100 error_list = [error_bf, error_vs, error_sc, error_cv, error_bursty, error_f_inf, error_f_inf_slope, error_f_zero, error_f_zero_slope_at_straight, error_f0_curve]