add plotting function for f point detections (model)

This commit is contained in:
a.ott 2020-05-18 12:18:47 +02:00
parent 49abc78983
commit a41de94961

View File

@ -320,58 +320,91 @@ class FICurveCellData(FICurve):
class FICurveModel(FICurve):
stim_duration = 0.5
stim_start = 0.5
total_simulation_time = stim_duration + 2 * stim_start
def __init__(self, model, stimulus_values, eod_frequency, trials=5):
self.eod_frequency = eod_frequency
self.model = model
self.trials = trials
self.spiketimes = []
self.spiketimes_array = np.zeros((len(stimulus_values), trials), dtype=list)
self.mean_frequency_traces = []
self.mean_time_traces = []
super().__init__(stimulus_values)
def calculate_all_frequency_points(self):
stim_duration = 0.5
stim_start = 0.5
total_simulation_time = stim_duration + 2 * stim_start
sampling_interval = self.model.get_sampling_interval()
self.f_inf_frequencies = []
self.f_zero_frequencies = []
self.f_baseline_frequencies = []
for c in self.stimulus_values:
stimulus = SinusoidalStepStimulus(self.eod_frequency, c, stim_start, stim_duration)
for i, c in enumerate(self.stimulus_values):
stimulus = SinusoidalStepStimulus(self.eod_frequency, c, self.stim_start, self.stim_duration)
frequency_traces = []
time_traces = []
for i in range(self.trials):
_, spiketimes = self.model.simulate_fast(stimulus, total_simulation_time)
for j in range(self.trials):
_, spiketimes = self.model.simulate_fast(stimulus, self.total_simulation_time)
self.spiketimes_array[i, j] = spiketimes
trial_time, trial_frequency = hF.calculate_time_and_frequency_trace(spiketimes, sampling_interval)
frequency_traces.append(trial_frequency)
time_traces.append(trial_time)
time, frequency = hF.calculate_mean_of_frequency_traces(time_traces, frequency_traces, sampling_interval)
self.mean_frequency_traces.append(frequency)
self.mean_time_traces.append(time)
if len(time) == 0 or min(time) > stim_start \
or max(time) < stim_start + stim_duration:
if len(time) == 0 or min(time) > self.stim_start \
or max(time) < self.stim_start + self.stim_duration:
print("Too few spikes to calculate f_inf, f_0 and f_base")
self.f_inf_frequencies.append(0)
self.f_zero_frequencies.append(0)
self.f_baseline_frequencies.append(0)
continue
f_inf = hF.detect_f_infinity_in_freq_trace(time, frequency, stim_start, stim_duration, sampling_interval)
f_inf = hF.detect_f_infinity_in_freq_trace(time, frequency, self.stim_start, self.stim_duration, sampling_interval)
self.f_inf_frequencies.append(f_inf)
f_zero = hF.detect_f_zero_in_frequency_trace(time, frequency, stim_start, sampling_interval)
f_zero = hF.detect_f_zero_in_frequency_trace(time, frequency, self.stim_start, sampling_interval)
self.f_zero_frequencies.append(f_zero)
f_baseline = hF.detect_f_baseline_in_freq_trace(time, frequency, stim_start, sampling_interval)
f_baseline = hF.detect_f_baseline_in_freq_trace(time, frequency, self.stim_start, sampling_interval)
self.f_baseline_frequencies.append(f_baseline)
def plot_f_point_detections(self, save_path=None):
raise NotImplementedError("TODO sorry... "
"The model version of the FiCurve class is still missing this implementation")
sampling_interval = self.model.get_sampling_interval()
for i, c in enumerate(self.stimulus_values):
time = self.mean_time_traces[i]
frequency = self.mean_frequency_traces[i]
if len(time) == 0 or min(time) > self.stim_start \
or max(time) < self.stim_start + self.stim_duration:
continue
fig, ax = plt.subplots(1, 1, figsize=(8, 8))
ax.plot(time, frequency)
start_idx, end_idx = hF.time_window_detect_f_baseline(time[0], self.stim_start, sampling_interval)
ax.plot((time[start_idx], time[end_idx]), (self.f_baseline_frequencies[i], self.f_baseline_frequencies[i]),
label="f_base", color="deepskyblue")
start_idx, end_idx = hF.time_window_detect_f_infinity(time[0], self.stim_start, self.stim_duration, sampling_interval)
ax.plot((time[start_idx], time[end_idx]), (self.f_inf_frequencies[i], self.f_inf_frequencies[i]),
label="f_inf", color="limegreen")
start_idx, end_idx = hF.time_window_detect_f_zero(time[0], self.stim_start, sampling_interval)
ax.plot((time[start_idx], time[end_idx]), (self.f_zero_frequencies[i], self.f_zero_frequencies[i]),
label="f_zero", color="orange")
plt.legend()
if save_path is not None:
plt.savefig(save_path + "/detections_contrast_{:.2f}.png".format(c))
else:
plt.show()
plt.close()
def get_fi_curve_class(data, stimulus_values, eod_freq=None) -> FICurve: