diff --git a/FiCurve.py b/FiCurve.py index c7ac5ea..e03bc35 100644 --- a/FiCurve.py +++ b/FiCurve.py @@ -320,58 +320,91 @@ class FICurveCellData(FICurve): class FICurveModel(FICurve): + stim_duration = 0.5 + stim_start = 0.5 + total_simulation_time = stim_duration + 2 * stim_start def __init__(self, model, stimulus_values, eod_frequency, trials=5): self.eod_frequency = eod_frequency self.model = model self.trials = trials - self.spiketimes = [] + self.spiketimes_array = np.zeros((len(stimulus_values), trials), dtype=list) + self.mean_frequency_traces = [] + self.mean_time_traces = [] super().__init__(stimulus_values) def calculate_all_frequency_points(self): - stim_duration = 0.5 - stim_start = 0.5 - total_simulation_time = stim_duration + 2 * stim_start + sampling_interval = self.model.get_sampling_interval() self.f_inf_frequencies = [] self.f_zero_frequencies = [] self.f_baseline_frequencies = [] - for c in self.stimulus_values: - stimulus = SinusoidalStepStimulus(self.eod_frequency, c, stim_start, stim_duration) - + for i, c in enumerate(self.stimulus_values): + stimulus = SinusoidalStepStimulus(self.eod_frequency, c, self.stim_start, self.stim_duration) frequency_traces = [] time_traces = [] - for i in range(self.trials): - _, spiketimes = self.model.simulate_fast(stimulus, total_simulation_time) + for j in range(self.trials): + + _, spiketimes = self.model.simulate_fast(stimulus, self.total_simulation_time) + self.spiketimes_array[i, j] = spiketimes trial_time, trial_frequency = hF.calculate_time_and_frequency_trace(spiketimes, sampling_interval) frequency_traces.append(trial_frequency) time_traces.append(trial_time) time, frequency = hF.calculate_mean_of_frequency_traces(time_traces, frequency_traces, sampling_interval) + self.mean_frequency_traces.append(frequency) + self.mean_time_traces.append(time) - if len(time) == 0 or min(time) > stim_start \ - or max(time) < stim_start + stim_duration: + if len(time) == 0 or min(time) > self.stim_start \ + or max(time) < self.stim_start + self.stim_duration: print("Too few spikes to calculate f_inf, f_0 and f_base") self.f_inf_frequencies.append(0) self.f_zero_frequencies.append(0) self.f_baseline_frequencies.append(0) continue - f_inf = hF.detect_f_infinity_in_freq_trace(time, frequency, stim_start, stim_duration, sampling_interval) + f_inf = hF.detect_f_infinity_in_freq_trace(time, frequency, self.stim_start, self.stim_duration, sampling_interval) self.f_inf_frequencies.append(f_inf) - f_zero = hF.detect_f_zero_in_frequency_trace(time, frequency, stim_start, sampling_interval) + f_zero = hF.detect_f_zero_in_frequency_trace(time, frequency, self.stim_start, sampling_interval) self.f_zero_frequencies.append(f_zero) - f_baseline = hF.detect_f_baseline_in_freq_trace(time, frequency, stim_start, sampling_interval) + f_baseline = hF.detect_f_baseline_in_freq_trace(time, frequency, self.stim_start, sampling_interval) self.f_baseline_frequencies.append(f_baseline) - def plot_f_point_detections(self, save_path=None): - raise NotImplementedError("TODO sorry... " - "The model version of the FiCurve class is still missing this implementation") + sampling_interval = self.model.get_sampling_interval() + + for i, c in enumerate(self.stimulus_values): + time = self.mean_time_traces[i] + frequency = self.mean_frequency_traces[i] + + if len(time) == 0 or min(time) > self.stim_start \ + or max(time) < self.stim_start + self.stim_duration: + continue + fig, ax = plt.subplots(1, 1, figsize=(8, 8)) + ax.plot(time, frequency) + start_idx, end_idx = hF.time_window_detect_f_baseline(time[0], self.stim_start, sampling_interval) + ax.plot((time[start_idx], time[end_idx]), (self.f_baseline_frequencies[i], self.f_baseline_frequencies[i]), + label="f_base", color="deepskyblue") + + start_idx, end_idx = hF.time_window_detect_f_infinity(time[0], self.stim_start, self.stim_duration, sampling_interval) + ax.plot((time[start_idx], time[end_idx]), (self.f_inf_frequencies[i], self.f_inf_frequencies[i]), + label="f_inf", color="limegreen") + + start_idx, end_idx = hF.time_window_detect_f_zero(time[0], self.stim_start, sampling_interval) + ax.plot((time[start_idx], time[end_idx]), (self.f_zero_frequencies[i], self.f_zero_frequencies[i]), + label="f_zero", color="orange") + + plt.legend() + if save_path is not None: + plt.savefig(save_path + "/detections_contrast_{:.2f}.png".format(c)) + else: + plt.show() + + plt.close() def get_fi_curve_class(data, stimulus_values, eod_freq=None) -> FICurve: