correct th mean square error for not equal sampling rates

This commit is contained in:
a.ott 2020-07-23 10:32:31 +02:00
parent 365c6ae825
commit 16a8da2dfc

View File

@ -125,7 +125,10 @@ class Fitter:
error_weights = (0, 2, 2, 2, 1, 1, 1, 1, 0, 1)
fmin = minimize(fun=self.cost_function_all,
args=(error_weights,), x0=x0, method="Nelder-Mead",
options={"initial_simplex": initial_simplex, "xatol": 0.001, "maxfev": 600, "maxiter": 1200})
options={"initial_simplex": initial_simplex, "xatol": 0.001, "maxfev": 600, "maxiter": 800})
print("best that was returned: {}".format(fmin.fun))
print("best that was visited: {}".format(self.smallest_error))
return fmin, self.base_model.get_parameters()
@ -187,11 +190,6 @@ class Fitter:
self.base_model.set_variable("dend_tau", X[5])
self.base_model.set_variable("refractory_period", X[6])
# TODO add tests for parameters punish impossible values (immediate high error) but also add a slope towards valid points
base_stimulus = SinusoidalStepStimulus(self.eod_freq, 0)
# find right v-offset
test_model = self.base_model.get_model_copy()
@ -204,7 +202,7 @@ class Fitter:
# [error_bf, error_vs, error_sc, error_f_inf, error_f_inf_slope, error_f_zero, error_f_zero_slope]
error_list = self.calculate_errors(error_weights)
print("sum: {:.2f}, ".format(sum(error_list)))
# print("sum: {:.2f}, ".format(sum(error_list)))
if sum(error_list) < self.smallest_error:
self.smallest_error = sum(error_list)
self.best_parameters_found = X
@ -379,30 +377,43 @@ class Fitter:
/ abs(self.f_zero_slope_at_straight+1 / 10)
error_f_zero = calculate_list_error(f_zeros, self.f_zero_values)
# TODO
if model.get_sampling_interval() != self.data_sampling_interval:
raise ValueError("Sampling intervals not the same!")
times, freqs = fi_curve_model.get_mean_time_and_freq_traces()
freq_prediction = np.array(freqs[self.f_zero_curve_contrast_idx])
time_prediction = np.array(times[self.f_zero_curve_contrast_idx])
stimulus_start = fi_curve_model.get_stimulus_start() - time_prediction[0]
length = fi_curve_model.get_stimulus_duration() / 2
start_idx = int(stimulus_start / fi_curve_model.get_sampling_interval())
end_idx = int((stimulus_start + length) / model.get_sampling_interval())
if len(time_prediction) == 0 or len(time_prediction) < end_idx or time_prediction[0] > fi_curve_model.get_stimulus_start():
if model.get_sampling_interval() == self.data_sampling_interval:
start_idx = int(stimulus_start / fi_curve_model.get_sampling_interval())
end_idx = int((stimulus_start + length) / model.get_sampling_interval())
start_idx_cell = start_idx
start_idx_model = start_idx
end_idx_cell = end_idx
end_idx_model = end_idx
step_cell = 1
step_model = 1
else:
start_idx_cell = int(stimulus_start / self.data_sampling_interval)
start_idx_model = int(stimulus_start / fi_curve_model.get_sampling_interval())
end_idx_cell = int((stimulus_start + length) / self.data_sampling_interval)
end_idx_model = int((stimulus_start + length) / model.get_sampling_interval())
if round(model.get_sampling_interval() % self.data_sampling_interval, 4) == 0:
step_cell = int(model.get_sampling_interval() / self.data_sampling_interval)
step_model = 1
else:
raise ValueError("Model sampling interval is not a multiple of data sampling interval.")
if len(time_prediction) == 0 or len(time_prediction) < end_idx_model or time_prediction[0] > fi_curve_model.get_stimulus_start():
error_f0_curve = 200
else:
error_f0_curve = np.mean((self.f_zero_curve_freq[start_idx:end_idx] - freq_prediction[start_idx:end_idx])**2) / 100
# plt.plot(self.f_zero_curve_freq[start_idx:end_idx])
# plt.plot(freq_prediction[start_idx:end_idx])
# plt.plot((self.f_zero_curve_freq[start_idx:end_idx] - freq_prediction[start_idx:end_idx])**2)
# plt.show()
# plt.close()
data_curve = self.f_zero_curve_freq[start_idx_cell:end_idx_cell:step_cell]
model_curve = freq_prediction[start_idx_model:end_idx_model:step_model]
if len(data_curve) < len(model_curve):
model_curve = model_curve[:len(data_curve)]
elif len(model_curve) < len(data_curve):
data_curve = data_curve[:len(model_curve)]
error_f0_curve = np.mean((model_curve - data_curve)**2) / 100
error_list = [error_bf, error_vs, error_sc, error_cv, error_bursty,
error_f_inf, error_f_inf_slope, error_f_zero, error_f_zero_slope_at_straight, error_f0_curve]