From eeb43fd7fcd5e82a4f0e5694b0fda43da96ce133 Mon Sep 17 00:00:00 2001 From: "a.ott" Date: Sat, 1 Aug 2020 12:05:50 +0200 Subject: [PATCH] add error of mean-square with the isi bins --- Fitter.py | 126 +++++++++++++++++++++++++++++++++++++++++++----------- 1 file changed, 102 insertions(+), 24 deletions(-) diff --git a/Fitter.py b/Fitter.py index 9a2d67a..7d1e325 100644 --- a/Fitter.py +++ b/Fitter.py @@ -32,6 +32,7 @@ class Fitter: self.sc_max_lag = 2 # values to be replicated: + self.isi_bins = np.array(0) self.baseline_freq = 0 self.vector_strength = -1 self.serial_correlation = [] @@ -64,20 +65,26 @@ class Fitter: data_baseline = get_baseline_class(cell_data) data_baseline.load_values(cell_data.get_data_path()) self.baseline_freq = data_baseline.get_baseline_frequency() + self.isi_bins = calculate_histogram_bins(data_baseline.get_interspike_intervals()) + # plt.close() + # plt.plot(self.isi_bins) + # plt.show() + # plt.close() self.vector_strength = data_baseline.get_vector_strength() self.serial_correlation = data_baseline.get_serial_correlation(self.sc_max_lag) self.coefficient_of_variation = data_baseline.get_coefficient_of_variation() self.burstiness = data_baseline.get_burstiness() - fi_curve = get_fi_curve_class(cell_data, cell_data.get_fi_contrasts(), save_dir=cell_data.get_data_path()) - self.f_inf_slope = fi_curve.get_f_inf_slope() contrasts = np.array(cell_data.get_fi_contrasts()) + fi_curve = get_fi_curve_class(cell_data, contrasts, save_dir=cell_data.get_data_path()) + self.f_inf_slope = fi_curve.get_f_inf_slope() + if self.f_inf_slope < 0: contrasts = contrasts * -1 # print("old contrasts:", cell_data.get_fi_contrasts()) # print("new contrasts:", contrasts) - contrasts = sorted(contrasts) - fi_curve = get_fi_curve_class(cell_data, contrasts) + + fi_curve = get_fi_curve_class(cell_data, contrasts, save_dir=cell_data.get_data_path()) self.fi_contrasts = fi_curve.stimulus_values self.f_inf_values = fi_curve.f_inf_frequencies @@ -121,9 +128,6 @@ class Fitter: # error_list = [error_bf, error_vs, error_sc, error_cv, error_bursty, # error_f_inf, error_f_inf_slope, error_f_zero, error_f_zero_slope_at_straight, error_f0_curve] - if error_weights is None: - error_weights = (1, 2, 2, 2, 2, 1, 1, 1, 0, 1) - fmin = minimize(fun=self.cost_function_all, args=(error_weights,), x0=x0, method="Nelder-Mead", options={"initial_simplex": initial_simplex, "xatol": 0.001, "maxfev": 600, "maxiter": 400}) @@ -150,13 +154,13 @@ class Fitter: # find right v-offset test_model = self.base_model.get_model_copy() test_model.set_variable("noise_strength", 0) - time1 = time.time() + + # time1 = time.time() v_offset = test_model.find_v_offset(self.baseline_freq, base_stimulus) self.base_model.set_variable("v_offset", v_offset) - time2 = time.time() + # time2 = time.time() # print("time taken for finding v_offset: {:.2f}s".format(time2-time1)) - # [error_bf, error_vs, error_sc, error_f_inf, error_f_inf_slope, error_f_zero, error_f_zero_slope] error_list = self.calculate_errors(error_weights) # print("sum: {:.2f}, ".format(sum(error_list))) if sum(error_list) < self.smallest_error: @@ -168,18 +172,18 @@ class Fitter: if model is None: model = self.base_model - time1 = time.time() + # time1 = time.time() model_baseline = get_baseline_class(model, self.eod_freq, trials=3) baseline_freq = model_baseline.get_baseline_frequency() vector_strength = model_baseline.get_vector_strength() serial_correlation = model_baseline.get_serial_correlation(self.sc_max_lag) coefficient_of_variation = model_baseline.get_coefficient_of_variation() burstiness = model_baseline.get_burstiness() - time2 = time.time() - + # time2 = time.time() + isi_bins = calculate_histogram_bins(model_baseline.get_interspike_intervals()) # print("Time taken for all baseline parameters: {:.2f}".format(time2-time1)) - time1 = time.time() + # time1 = time.time() fi_curve_model = get_fi_curve_class(model, self.fi_contrasts, self.eod_freq, trials=8) f_zeros = fi_curve_model.get_f_zero_frequencies() f_infinities = fi_curve_model.get_f_inf_frequencies() @@ -187,15 +191,17 @@ class Fitter: # f_zero_slopes = [fi_curve_model.get_f_zero_fit_slope_at_stimulus_value(x) for x in self.fi_contrasts] f_zero_slope_at_straight = fi_curve_model.get_f_zero_fit_slope_at_stimulus_value(self.f_zero_straight_contrast) - time2 = time.time() + # time2 = time.time() # print("Time taken for all fi-curve parameters: {:.2f}".format(time2 - time1)) # calculate errors with reference values error_bf = abs((baseline_freq - self.baseline_freq) / self.baseline_freq) error_vs = abs((vector_strength - self.vector_strength) / 0.1) - error_cv = abs((coefficient_of_variation - self.coefficient_of_variation) / 0.1) + error_cv = abs((coefficient_of_variation - self.coefficient_of_variation) / 0.2) error_bursty = (abs(burstiness - self.burstiness) / 0.2) + error_hist = np.mean((isi_bins - self.isi_bins) ** 2) / 200 + # print("error hist: {:.2f}".format(error_hist)) # print("Burstiness: cell {:.2f}, model: {:.2f}, error: {:.2f}".format(self.burstiness, burstiness, error_bursty)) error_sc = 0 @@ -209,11 +215,11 @@ class Fitter: # error_f_zero_slopes = calculate_list_error(f_zero_slopes, self.f_zero_slopes) error_f_zero_slope_at_straight = abs(self.f_zero_slope_at_straight - f_zero_slope_at_straight) \ / abs(self.f_zero_slope_at_straight+1 / 10) - error_f_zero = calculate_list_error(f_zeros, self.f_zero_values) + error_f_zero = calculate_list_error(f_zeros, self.f_zero_values) / 25 - error_f0_curve = self.calculate_f0_curve_error(model, fi_curve_model) + error_f0_curve = self.calculate_f0_curve_error(model, fi_curve_model) / 10 - error_list = [error_bf, error_vs, error_sc, error_cv, error_bursty, + error_list = [error_bf, error_vs, error_sc, error_cv, error_hist, error_bursty, error_f_inf, error_f_inf_slope, error_f_zero, error_f_zero_slope_at_straight, error_f0_curve] self.errors.append(error_list) @@ -232,18 +238,20 @@ class Fitter: return error_list def calculate_f0_curve_error(self, model, fi_curve_model): - buffer = 0.05 + buffer = 0.00 test_duration = 0.05 - # prepare model frequency curve: times, freqs = fi_curve_model.get_mean_time_and_freq_traces() freq_prediction = np.array(freqs[self.f_zero_curve_contrast_idx]) time_prediction = np.array(times[self.f_zero_curve_contrast_idx]) + + if len(time_prediction) == 0: + return 200 stimulus_start = fi_curve_model.get_stimulus_start() - time_prediction[0] model_start_idx = int((stimulus_start - buffer) / fi_curve_model.get_sampling_interval()) - model_end_idx = int((stimulus_start - buffer + test_duration) / model.get_sampling_interval()) + model_end_idx = int((stimulus_start + buffer + test_duration) / model.get_sampling_interval()) if len(time_prediction) == 0 or len(time_prediction) < model_end_idx \ or time_prediction[0] > fi_curve_model.get_stimulus_start(): @@ -256,7 +264,7 @@ class Fitter: stimulus_start = self.recording_times[1] - self.f_zero_curve_time[0] cell_start_idx = int((stimulus_start - buffer) / self.data_sampling_interval) - cell_end_idx = int((stimulus_start - buffer + test_duration) / self.data_sampling_interval) + cell_end_idx = int((stimulus_start + buffer + test_duration) / self.data_sampling_interval) if round(model.get_sampling_interval() % self.data_sampling_interval, 4) == 0: step_cell = int(round(model.get_sampling_interval() / self.data_sampling_interval)) @@ -264,7 +272,6 @@ class Fitter: raise ValueError("Model sampling interval is not a multiple of data sampling interval.") cell_curve = self.f_zero_curve_freq[cell_start_idx:cell_end_idx:step_cell] - # plt.close() # plt.plot(cell_curve) # plt.plot(model_curve) @@ -280,6 +287,69 @@ class Fitter: return error_f0_curve + def calculate_f0_curve_error_new(self, model, fi_curve_model): + buffer = 0.05 + test_duration = 0.05 + + times, freqs = fi_curve_model.get_mean_time_and_freq_traces() + freq_prediction = np.array(freqs[self.f_zero_curve_contrast_idx]) + time_prediction = np.array(times[self.f_zero_curve_contrast_idx]) + + if len(time_prediction) == 0: + return 200 + stimulus_start = fi_curve_model.get_stimulus_start() - time_prediction[0] + + model_start_idx = int((stimulus_start - buffer) / model.get_sampling_interval()) + model_end_idx = int((stimulus_start + buffer + test_duration) / model.get_sampling_interval()) + + if len(time_prediction) == 0 or len(time_prediction) < model_end_idx \ + or time_prediction[0] > fi_curve_model.get_stimulus_start(): + error_f0_curve = 200 + return error_f0_curve + + model_curve = np.array(freq_prediction[model_start_idx:model_end_idx]) + + # prepare cell frequency_curve: + + stimulus_start = self.recording_times[1] - self.f_zero_curve_time[0] + cell_start_idx = int((stimulus_start - buffer) / self.data_sampling_interval) + cell_end_idx = int((stimulus_start - buffer + test_duration) / self.data_sampling_interval) + + if round(model.get_sampling_interval() % self.data_sampling_interval, 4) == 0: + step_cell = int(round(model.get_sampling_interval() / self.data_sampling_interval)) + else: + raise ValueError("Model sampling interval is not a multiple of data sampling interval.") + + cell_curve = self.f_zero_curve_freq[cell_start_idx:cell_end_idx:step_cell] + cell_time = self.f_zero_curve_time[cell_start_idx:cell_end_idx:step_cell] + cell_curve_std = np.std(self.f_zero_curve_freq) + model_curve_std = np.std(freq_prediction) + + model_limit = self.baseline_freq + model_curve_std + cell_limit = self.baseline_freq + cell_curve_std + + cell_full_precicion = np.array(self.f_zero_curve_freq[cell_start_idx:cell_end_idx]) + cell_points_above = cell_full_precicion > cell_limit + cell_area_above = sum(cell_full_precicion[cell_points_above]) * self.data_sampling_interval + + model_points_above = model_curve > model_limit + model_area_above = sum(model_curve[model_points_above]) * model.get_sampling_interval() + + # plt.close() + # plt.plot(cell_time, cell_curve, color="blue") + # plt.plot((cell_time[0], cell_time[-1]), (cell_limit, cell_limit), + # color="lightblue", label="area: {:.2f}".format(cell_area_above)) + # + # plt.plot(time_prediction[model_start_idx:model_end_idx], model_curve, color="orange") + # plt.plot((time_prediction[model_start_idx], time_prediction[model_end_idx]), (model_limit, model_limit), + # color="red", label="area: {:.2f}".format(model_area_above)) + # plt.legend() + # plt.title("Error: {:.2f}".format(abs(model_area_above - cell_area_above) / 0.02)) + # plt.savefig("./figures/f_zero_curve_error_{}.png".format(time.strftime("%H:%M:%S"))) + # plt.close() + + return abs(model_area_above - cell_area_above) + def calculate_list_error(fit, reference): error = 0 @@ -290,6 +360,14 @@ def calculate_list_error(fit, reference): return norm_error +def calculate_histogram_bins(isis): + isis = np.array(isis) * 1000 + step = 0.1 + bins = np.arange(0, 50, step) + + counts = np.array([np.sum((isis >= b) & (isis < b+0.1)) for b in bins]) + return counts + def normed_quadratic_freq_error(fit, ref, factor=2): return (abs(fit-ref)/factor)**2 / ref