correct th mean square error for not equal sampling rates
This commit is contained in:
parent
365c6ae825
commit
16a8da2dfc
57
Fitter.py
57
Fitter.py
@ -125,7 +125,10 @@ class Fitter:
|
|||||||
error_weights = (0, 2, 2, 2, 1, 1, 1, 1, 0, 1)
|
error_weights = (0, 2, 2, 2, 1, 1, 1, 1, 0, 1)
|
||||||
fmin = minimize(fun=self.cost_function_all,
|
fmin = minimize(fun=self.cost_function_all,
|
||||||
args=(error_weights,), x0=x0, method="Nelder-Mead",
|
args=(error_weights,), x0=x0, method="Nelder-Mead",
|
||||||
options={"initial_simplex": initial_simplex, "xatol": 0.001, "maxfev": 600, "maxiter": 1200})
|
options={"initial_simplex": initial_simplex, "xatol": 0.001, "maxfev": 600, "maxiter": 800})
|
||||||
|
|
||||||
|
print("best that was returned: {}".format(fmin.fun))
|
||||||
|
print("best that was visited: {}".format(self.smallest_error))
|
||||||
|
|
||||||
return fmin, self.base_model.get_parameters()
|
return fmin, self.base_model.get_parameters()
|
||||||
|
|
||||||
@ -187,11 +190,6 @@ class Fitter:
|
|||||||
self.base_model.set_variable("dend_tau", X[5])
|
self.base_model.set_variable("dend_tau", X[5])
|
||||||
self.base_model.set_variable("refractory_period", X[6])
|
self.base_model.set_variable("refractory_period", X[6])
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# TODO add tests for parameters punish impossible values (immediate high error) but also add a slope towards valid points
|
|
||||||
|
|
||||||
base_stimulus = SinusoidalStepStimulus(self.eod_freq, 0)
|
base_stimulus = SinusoidalStepStimulus(self.eod_freq, 0)
|
||||||
# find right v-offset
|
# find right v-offset
|
||||||
test_model = self.base_model.get_model_copy()
|
test_model = self.base_model.get_model_copy()
|
||||||
@ -204,7 +202,7 @@ class Fitter:
|
|||||||
|
|
||||||
# [error_bf, error_vs, error_sc, error_f_inf, error_f_inf_slope, error_f_zero, error_f_zero_slope]
|
# [error_bf, error_vs, error_sc, error_f_inf, error_f_inf_slope, error_f_zero, error_f_zero_slope]
|
||||||
error_list = self.calculate_errors(error_weights)
|
error_list = self.calculate_errors(error_weights)
|
||||||
print("sum: {:.2f}, ".format(sum(error_list)))
|
# print("sum: {:.2f}, ".format(sum(error_list)))
|
||||||
if sum(error_list) < self.smallest_error:
|
if sum(error_list) < self.smallest_error:
|
||||||
self.smallest_error = sum(error_list)
|
self.smallest_error = sum(error_list)
|
||||||
self.best_parameters_found = X
|
self.best_parameters_found = X
|
||||||
@ -379,30 +377,43 @@ class Fitter:
|
|||||||
/ abs(self.f_zero_slope_at_straight+1 / 10)
|
/ abs(self.f_zero_slope_at_straight+1 / 10)
|
||||||
error_f_zero = calculate_list_error(f_zeros, self.f_zero_values)
|
error_f_zero = calculate_list_error(f_zeros, self.f_zero_values)
|
||||||
|
|
||||||
# TODO
|
|
||||||
|
|
||||||
if model.get_sampling_interval() != self.data_sampling_interval:
|
|
||||||
raise ValueError("Sampling intervals not the same!")
|
|
||||||
|
|
||||||
times, freqs = fi_curve_model.get_mean_time_and_freq_traces()
|
times, freqs = fi_curve_model.get_mean_time_and_freq_traces()
|
||||||
freq_prediction = np.array(freqs[self.f_zero_curve_contrast_idx])
|
freq_prediction = np.array(freqs[self.f_zero_curve_contrast_idx])
|
||||||
time_prediction = np.array(times[self.f_zero_curve_contrast_idx])
|
time_prediction = np.array(times[self.f_zero_curve_contrast_idx])
|
||||||
stimulus_start = fi_curve_model.get_stimulus_start() - time_prediction[0]
|
stimulus_start = fi_curve_model.get_stimulus_start() - time_prediction[0]
|
||||||
length = fi_curve_model.get_stimulus_duration() / 2
|
length = fi_curve_model.get_stimulus_duration() / 2
|
||||||
|
|
||||||
start_idx = int(stimulus_start / fi_curve_model.get_sampling_interval())
|
if model.get_sampling_interval() == self.data_sampling_interval:
|
||||||
end_idx = int((stimulus_start + length) / model.get_sampling_interval())
|
start_idx = int(stimulus_start / fi_curve_model.get_sampling_interval())
|
||||||
|
end_idx = int((stimulus_start + length) / model.get_sampling_interval())
|
||||||
if len(time_prediction) == 0 or len(time_prediction) < end_idx or time_prediction[0] > fi_curve_model.get_stimulus_start():
|
start_idx_cell = start_idx
|
||||||
|
start_idx_model = start_idx
|
||||||
|
end_idx_cell = end_idx
|
||||||
|
end_idx_model = end_idx
|
||||||
|
step_cell = 1
|
||||||
|
step_model = 1
|
||||||
|
else:
|
||||||
|
start_idx_cell = int(stimulus_start / self.data_sampling_interval)
|
||||||
|
start_idx_model = int(stimulus_start / fi_curve_model.get_sampling_interval())
|
||||||
|
end_idx_cell = int((stimulus_start + length) / self.data_sampling_interval)
|
||||||
|
end_idx_model = int((stimulus_start + length) / model.get_sampling_interval())
|
||||||
|
if round(model.get_sampling_interval() % self.data_sampling_interval, 4) == 0:
|
||||||
|
step_cell = int(model.get_sampling_interval() / self.data_sampling_interval)
|
||||||
|
step_model = 1
|
||||||
|
else:
|
||||||
|
raise ValueError("Model sampling interval is not a multiple of data sampling interval.")
|
||||||
|
|
||||||
|
if len(time_prediction) == 0 or len(time_prediction) < end_idx_model or time_prediction[0] > fi_curve_model.get_stimulus_start():
|
||||||
error_f0_curve = 200
|
error_f0_curve = 200
|
||||||
else:
|
else:
|
||||||
error_f0_curve = np.mean((self.f_zero_curve_freq[start_idx:end_idx] - freq_prediction[start_idx:end_idx])**2) / 100
|
data_curve = self.f_zero_curve_freq[start_idx_cell:end_idx_cell:step_cell]
|
||||||
|
model_curve = freq_prediction[start_idx_model:end_idx_model:step_model]
|
||||||
# plt.plot(self.f_zero_curve_freq[start_idx:end_idx])
|
if len(data_curve) < len(model_curve):
|
||||||
# plt.plot(freq_prediction[start_idx:end_idx])
|
model_curve = model_curve[:len(data_curve)]
|
||||||
# plt.plot((self.f_zero_curve_freq[start_idx:end_idx] - freq_prediction[start_idx:end_idx])**2)
|
elif len(model_curve) < len(data_curve):
|
||||||
# plt.show()
|
data_curve = data_curve[:len(model_curve)]
|
||||||
# plt.close()
|
|
||||||
|
error_f0_curve = np.mean((model_curve - data_curve)**2) / 100
|
||||||
|
|
||||||
error_list = [error_bf, error_vs, error_sc, error_cv, error_bursty,
|
error_list = [error_bf, error_vs, error_sc, error_cv, error_bursty,
|
||||||
error_f_inf, error_f_inf_slope, error_f_zero, error_f_zero_slope_at_straight, error_f0_curve]
|
error_f_inf, error_f_inf_slope, error_f_zero, error_f_zero_slope_at_straight, error_f0_curve]
|
||||||
|
Loading…
Reference in New Issue
Block a user