diff --git a/FiCurve.py b/FiCurve.py index e03bc35..1313353 100644 --- a/FiCurve.py +++ b/FiCurve.py @@ -87,6 +87,47 @@ class FICurve: fit_vars = self.f_zero_fit return fu.derivative_full_boltzmann(x, fit_vars[0], fit_vars[1], fit_vars[2], fit_vars[3]) + def get_mean_time_and_freq_traces(self): + raise NotImplementedError("NOT YET OVERRIDDEN FROM ABSTRACT CLASS") + + def get_time_and_freq_traces(self): + raise NotImplementedError("NOT YET OVERRIDDEN FROM ABSTRACT CLASS") + + def get_sampling_interval(self): + raise NotImplementedError("NOT YET OVERRIDDEN FROM ABSTRACT CLASS") + + def get_delay(self): + raise NotImplementedError("NOT YET OVERRIDDEN FROM ABSTRACT CLASS") + + def get_stimulus_start(self): + raise NotImplementedError("NOT YET OVERRIDDEN FROM ABSTRACT CLASS") + + def get_stimulus_end(self): + raise NotImplementedError("NOT YET OVERRIDDEN FROM ABSTRACT CLASS") + + def get_stimulus_duration(self): + return self.get_stimulus_end() - self.get_stimulus_start() + + def plot_mean_frequency_curves(self, save_path=None): + time_traces, freq_traces = self.get_time_and_freq_traces() + mean_times, mean_freqs = self.get_mean_time_and_freq_traces() + for i, sv in enumerate(self.stimulus_values): + + for j in range(len(time_traces[i])): + plt.plot(time_traces[i][j], freq_traces[i][j], color="gray", alpha=0.8) + + plt.plot(mean_times[i], mean_freqs[i], color="black") + plt.xlabel("Time [s]") + plt.ylabel("Frequency [Hz]") + + + plt.title("Mean frequency at contrast {:.2f} ({:} trials)".format(sv, len(time_traces[i]))) + if save_path is None: + plt.show() + else: + plt.savefig(save_path + "mean_frequency_contrast_{:.2f}.png".format(sv)) + plt.close() + def plot_fi_curve(self, save_path=None): min_x = min(self.stimulus_values) max_x = max(self.stimulus_values) @@ -111,7 +152,6 @@ class FICurve: if save_path is None: plt.show() else: - print("save") plt.savefig(save_path + "fi_curve.png") plt.close() @@ -173,7 +213,6 @@ class FICurve: if save_path is None: plt.show() else: - print("save") plt.savefig(save_path + "fi_curve_comparision.png") plt.close() @@ -188,8 +227,8 @@ class FICurveCellData(FICurve): super().__init__(stimulus_values) def calculate_all_frequency_points(self): - mean_frequencies = self.cell_data.get_mean_isi_frequencies() - time_axes = self.cell_data.get_time_axes_mean_frequencies() + mean_frequencies = self.cell_data.get_mean_fi_curve_isi_frequencies() + time_axes = self.cell_data.get_time_axes_fi_curve_mean_frequencies() stimulus_start = self.cell_data.get_stimulus_start() stimulus_duration = self.cell_data.get_stimulus_duration() sampling_interval = self.cell_data.get_sampling_interval() @@ -217,6 +256,39 @@ class FICurveCellData(FICurve): stimulus_start, stimulus_duration, sampling_interval) self.f_inf_frequencies.append(f_infinity) + def get_mean_time_and_freq_traces(self): + return self.cell_data.get_time_axes_fi_curve_mean_frequencies(), self.cell_data.get_mean_fi_curve_isi_frequencies() + + def get_time_and_freq_traces(self): + spiketimes = self.cell_data.get_fi_spiketimes() + time_traces = [] + freq_traces = [] + for i in range(len(spiketimes)): + trial_time_traces = [] + trial_freq_traces = [] + for j in range(len(spiketimes[i])): + time, isi_freq = hF.calculate_time_and_frequency_trace(spiketimes[i][j], self.cell_data.get_sampling_interval()) + + trial_freq_traces.append(isi_freq) + trial_time_traces.append(time) + + time_traces.append(trial_time_traces) + freq_traces.append(trial_freq_traces) + + return time_traces, freq_traces + + def get_sampling_interval(self): + return self.cell_data.get_sampling_interval() + + def get_delay(self): + return self.cell_data.get_delay() + + def get_stimulus_start(self): + return self.cell_data.get_stimulus_start() + + def get_stimulus_end(self): + return self.cell_data.get_stimulus_end() + def get_f_zero_inverse_at_frequency(self, frequency): # UNUSED b_vars = self.f_zero_fit @@ -227,96 +299,42 @@ class FICurveCellData(FICurve): infty_vars = self.f_inf_fit return fu.clipped_line(stimulus_value, infty_vars[0], infty_vars[1]) - # def get_fi_curve_slope_at(self, stimulus_value): - # fit_vars = self.f_zero_fit - # return fu.derivative_full_boltzmann(stimulus_value, fit_vars[0], fit_vars[1], fit_vars[2], fit_vars[3]) - # - # def get_fi_curve_slope_of_straight(self): - # fit_vars = self.f_zero_fit - # return fu.full_boltzmann_straight_slope(fit_vars[0], fit_vars[1], fit_vars[2], fit_vars[3]) - - # def get_f_zero_and_f_inf_intersection(self): - # x_values = np.arange(min(self.stimulus_values), max(self.stimulus_values), 0.0001) - # fit_vars = self.f_zero_fit - # f_zero = fu.full_boltzmann(x_values, fit_vars[0], fit_vars[1], fit_vars[2], fit_vars[3]) - # f_inf = fu.clipped_line(x_values, self.f_inf_fit[0], self.f_inf_fit[1]) - # - # intersection_indicies = np.argwhere(np.diff(np.sign(f_zero - f_inf))).flatten() - # # print("fi-curve calc intersection:", intersection_indicies, x_values[intersection_indicies]) - # if len(intersection_indicies) > 1: - # f_baseline = np.median(self.f_baseline_frequencies) - # best_dist = np.inf - # best_idx = -1 - # for idx in intersection_indicies: - # dist = abs(fu.clipped_line(x_values[idx], self.f_inf_fit[0], self.f_inf_fit[1]) - f_baseline) - # if dist < best_dist: - # best_dist = dist - # best_idx = idx - # - # return x_values[best_idx] - # - # elif len(intersection_indicies) == 0: - # raise ValueError("No intersection found!") - # else: - # return x_values[intersection_indicies[0]] - - # def get_fi_curve_slope_at_f_zero_intersection(self): - # x = self.get_f_zero_and_f_inf_intersection() - # fit_vars = self.f_zero_fit - # return fu.derivative_full_boltzmann(x, fit_vars[0], fit_vars[1], fit_vars[2], fit_vars[3]) - - # def plot_fi_curve(self, savepath: str = None, comp_f_baselines=None, comp_f_zeros=None, comp_f_infs=None): - # min_x = min(self.stimulus_values) - # max_x = max(self.stimulus_values) - # step = (max_x - min_x) / 5000 - # x_values = np.arange(min_x, max_x, step) - # - # plt.plot(self.stimulus_values, self.f_baseline_frequencies, color='blue', label='f_base') - # if comp_f_baselines is not None: - # plt.plot(self.stimulus_values, comp_f_baselines, 'o', color='skyblue', label='comp_values base') - # - # plt.plot(self.stimulus_values, self.f_inf_frequencies, 'o', color='green', label='f_inf') - # plt.plot(x_values, [fu.clipped_line(x, self.f_inf_fit[0], self.f_inf_fit[1]) for x in x_values], - # color='darkgreen', label='f_inf_fit') - # if comp_f_infs is not None: - # plt.plot(self.stimulus_values, comp_f_infs, 'o', color='lime', label='comp values f_inf') - # - # plt.plot(self.stimulus_values, self.f_zero_frequencies, 'o', color='orange', label='f_zero') - # popt = self.f_zero_fit - # plt.plot(x_values, [fu.full_boltzmann(x, popt[0], popt[1], popt[2], popt[3]) for x in x_values], - # color='red', label='f_0_fit') - # if comp_f_zeros is not None: - # plt.plot(self.stimulus_values, comp_f_zeros, 'o', color='wheat', label='comp_values f_zero') - # - # plt.legend() - # plt.ylabel("Frequency [Hz]") - # plt.xlabel("Stimulus value") - # - # if savepath is None: - # plt.show() - # else: - # print("save") - # plt.savefig(savepath + "fi_curve.png") - # plt.close() - def plot_f_point_detections(self, save_path=None): - mean_frequencies = np.array(self.cell_data.get_mean_isi_frequencies()) - time_axes = self.cell_data.get_time_axes_mean_frequencies() + mean_frequencies = np.array(self.cell_data.get_mean_fi_curve_isi_frequencies()) + time_axes = self.cell_data.get_time_axes_fi_curve_mean_frequencies() + sampling_interval = self.cell_data.get_sampling_interval() + stim_start = self.cell_data.get_stimulus_start() + stim_duration = self.cell_data.get_stimulus_duration() + + for i, c in enumerate(self.stimulus_values): + time = time_axes[i] + frequency = mean_frequencies[i] + + if len(time) == 0 or min(time) > stim_start \ + or max(time) < stim_start + stim_duration: + continue + fig, ax = plt.subplots(1, 1, figsize=(8, 8)) + ax.plot(time, frequency) + start_idx, end_idx = hF.time_window_detect_f_baseline(time[0], stim_start, sampling_interval) + ax.plot((time[start_idx], time[end_idx]), (self.f_baseline_frequencies[i], self.f_baseline_frequencies[i]), + label="f_base", color="deepskyblue") + + start_idx, end_idx = hF.time_window_detect_f_infinity(time[0], stim_start, stim_duration, + sampling_interval) + ax.plot((time[start_idx], time[end_idx]), (self.f_inf_frequencies[i], self.f_inf_frequencies[i]), + label="f_inf", color="limegreen") + + start_idx, end_idx = hF.time_window_detect_f_zero(time[0], stim_start, sampling_interval) + ax.plot((time[start_idx], time[end_idx]), (self.f_zero_frequencies[i], self.f_zero_frequencies[i]), + label="f_zero", color="orange") - for i in range(len(mean_frequencies)): - fig, axes = plt.subplots(1, 1, sharex="all") - axes.plot(time_axes[i], mean_frequencies[i], label="voltage") - axes.plot((time_axes[i][0], time_axes[i][-1]), (self.f_zero_frequencies[i], self.f_zero_frequencies[i]), label="f_zero") - axes.plot((time_axes[i][0], time_axes[i][-1]), (self.f_inf_frequencies[i], self.f_inf_frequencies[i]), '--', label="f_inf") - axes.plot((time_axes[i][0], time_axes[i][-1]), (self.f_baseline_frequencies[i], self.f_baseline_frequencies[i]), label="f_base") - axes.set_title(str(self.stimulus_values[i])) plt.legend() + if save_path is not None: + plt.savefig(save_path + "/detections_contrast_{:.2f}.png".format(c)) + else: + plt.show() - if save_path is None: - plt.show() - else: - plt.savefig(save_path + "GENERATE_NAMES.png") - plt.close() + plt.close() class FICurveModel(FICurve): @@ -374,6 +392,37 @@ class FICurveModel(FICurve): f_baseline = hF.detect_f_baseline_in_freq_trace(time, frequency, self.stim_start, sampling_interval) self.f_baseline_frequencies.append(f_baseline) + def get_mean_time_and_freq_traces(self): + return self.mean_time_traces, self.mean_frequency_traces + + def get_sampling_interval(self): + return self.model.get_sampling_interval() + + def get_delay(self): + return 0 + + def get_stimulus_start(self): + return self.stim_start + + def get_stimulus_end(self): + return self.stim_start + self.stim_duration + + def get_time_and_freq_traces(self): + time_traces = [] + freq_traces = [] + for v in range(len(self.stimulus_values)): + times_for_value = [] + freqs_for_value = [] + + for s in self.spiketimes_array[v]: + t, f = hF.calculate_time_and_frequency_trace(s, self.model.get_sampling_interval()) + times_for_value.append(t) + freqs_for_value.append(f) + + time_traces.append(times_for_value) + freq_traces.append(freqs_for_value) + return time_traces, freq_traces + def plot_f_point_detections(self, save_path=None): sampling_interval = self.model.get_sampling_interval()