from CellData import CellData
from models.LIFACnoise import LifacNoiseModel
from stimuli.SinusoidalStepStimulus import SinusoidalStepStimulus
import numpy as np
import matplotlib.pyplot as plt
from warnings import warn
import functions as fu
import helperFunctions as hF
from os.path import join, exists
import pickle
from sys import stderr


class FICurve:
    """Abstract base for f-I curves.

    Holds, per stimulus value, the detected baseline frequency (f_baseline), the
    steady-state frequency (f_inf) and the onset frequency (f_zero), together with a
    clipped-line fit of f_inf and a Boltzmann fit of f_zero. Subclasses must implement
    ``calculate_all_frequency_points`` and the trace/timing getters that raise
    ``NotImplementedError`` here.
    """

    def __init__(self, stimulus_values, save_dir=None):
        """Compute (or load from ``save_dir``) the f-points and fits for ``stimulus_values``.

        If ``save_dir`` is given and holds a cache file, the values are loaded from it;
        otherwise they are computed and, when ``save_dir`` is given, saved there.
        """
        self.save_file_name = "fi_curve_values.pkl"
        self.stimulus_values = stimulus_values

        self.f_baseline_frequencies = []
        self.f_inf_frequencies = []
        self.f_zero_frequencies = []

        # increase, offset
        self.f_inf_fit = []
        # f_max, f_min, k, x_zero
        self.f_zero_fit = []

        if save_dir is None:
            self.initialize()
        elif not self.load_values(save_dir):
            self.initialize()
            self.save_values(save_dir)

    def initialize(self):
        """Detect all frequency points, then fit f_inf (clipped line) and f_zero (Boltzmann)."""
        self.calculate_all_frequency_points()
        self.f_inf_fit = hF.fit_clipped_line(self.stimulus_values, self.f_inf_frequencies)
        self.f_zero_fit = hF.fit_boltzmann(self.stimulus_values, self.f_zero_frequencies)

    def calculate_all_frequency_points(self):
        raise NotImplementedError("NOT YET OVERRIDDEN FROM ABSTRACT CLASS")

    def get_f_baseline_frequencies(self):
        return self.f_baseline_frequencies

    def get_f_inf_frequencies(self):
        return self.f_inf_frequencies

    def get_f_zero_frequencies(self):
        return self.f_zero_frequencies

    def get_f_inf_slope(self):
        """Return the first f_inf fit parameter (the slope), or None if no fit exists."""
        if len(self.f_inf_fit) > 0:
            return self.f_inf_fit[0]
        return None

    def get_f_zero_fit_slope_at_straight(self):
        fit_vars = self.f_zero_fit
        return fu.full_boltzmann_straight_slope(fit_vars[0], fit_vars[1], fit_vars[2], fit_vars[3])

    def get_f_zero_fit_slope_at_stimulus_value(self, stimulus_value):
        """Return the derivative of the f_zero Boltzmann fit at ``stimulus_value``."""
        fit_vars = self.f_zero_fit
        return fu.derivative_full_boltzmann(stimulus_value, fit_vars[0], fit_vars[1], fit_vars[2], fit_vars[3])

    def get_f_inf_frequency_at_stimulus_value(self, stimulus_value):
        """Evaluate the f_inf clipped-line fit at ``stimulus_value``."""
        return fu.clipped_line(stimulus_value, self.f_inf_fit[0], self.f_inf_fit[1])

    def get_f_zero_and_f_inf_intersection(self):
        """Return the stimulus value where the f_zero and f_inf fits cross.

        Both fits are sampled on a 0.0001 grid between the smallest and largest stimulus
        value; crossings are sign changes of their difference. If several crossings
        exist, the one whose f_inf value is closest to the median baseline frequency
        is returned.

        Raises:
            ValueError: if the fits do not cross inside the stimulus range.
        """
        x_values = np.arange(min(self.stimulus_values), max(self.stimulus_values), 0.0001)
        fit_vars = self.f_zero_fit
        f_zero = fu.full_boltzmann(x_values, fit_vars[0], fit_vars[1], fit_vars[2], fit_vars[3])
        f_inf = fu.clipped_line(x_values, self.f_inf_fit[0], self.f_inf_fit[1])

        intersection_indices = np.argwhere(np.diff(np.sign(f_zero - f_inf))).flatten()

        if len(intersection_indices) == 0:
            raise ValueError("No intersection found!")
        if len(intersection_indices) == 1:
            return x_values[intersection_indices[0]]

        # Several crossings: prefer the one whose f_inf value is nearest the median baseline.
        # Reuse the already sampled f_inf instead of re-evaluating the fit per candidate.
        f_baseline = np.median(self.f_baseline_frequencies)
        distances = np.abs(f_inf[intersection_indices] - f_baseline)
        return x_values[intersection_indices[np.argmin(distances)]]

    def get_f_zero_fit_slope_at_f_inf_fit_intersection(self):
        """Return the f_zero fit derivative at the f_zero/f_inf fit intersection."""
        x = self.get_f_zero_and_f_inf_intersection()
        fit_vars = self.f_zero_fit
        return fu.derivative_full_boltzmann(x, fit_vars[0], fit_vars[1], fit_vars[2], fit_vars[3])

    def get_mean_time_and_freq_traces(self):
        raise NotImplementedError("NOT YET OVERRIDDEN FROM ABSTRACT CLASS")

    def get_time_and_freq_traces(self):
        raise NotImplementedError("NOT YET OVERRIDDEN FROM ABSTRACT CLASS")

    def get_sampling_interval(self):
        raise NotImplementedError("NOT YET OVERRIDDEN FROM ABSTRACT CLASS")

    def get_delay(self):
        raise NotImplementedError("NOT YET OVERRIDDEN FROM ABSTRACT CLASS")

    def get_stimulus_start(self):
        raise NotImplementedError("NOT YET OVERRIDDEN FROM ABSTRACT CLASS")

    def get_stimulus_end(self):
        raise NotImplementedError("NOT YET OVERRIDDEN FROM ABSTRACT CLASS")

    def get_stimulus_duration(self):
        return self.get_stimulus_end() - self.get_stimulus_start()

    @staticmethod
    def _trace_covers_stimulus(time, stim_start, stim_duration):
        """Return True if ``time`` is non-empty and spans the whole stimulus window."""
        return not (len(time) == 0
                    or min(time) > stim_start
                    or max(time) < stim_start + stim_duration)

    def plot_mean_frequency_curves(self, save_path=None):
        """Plot every single-trial frequency trace (gray) and the mean trace (black) per stimulus value."""
        time_traces, freq_traces = self.get_time_and_freq_traces()
        mean_times, mean_freqs = self.get_mean_time_and_freq_traces()
        for i, sv in enumerate(self.stimulus_values):
            for trial_time, trial_freq in zip(time_traces[i], freq_traces[i]):
                plt.plot(trial_time, trial_freq, color="gray", alpha=0.8)

            plt.plot(mean_times[i], mean_freqs[i], color="black")
            plt.xlabel("Time [s]")
            plt.ylabel("Frequency [Hz]")
            plt.title("Mean frequency at contrast {:.2f} ({:} trials)".format(sv, len(time_traces[i])))
            if save_path is None:
                plt.show()
            else:
                plt.savefig(save_path + "mean_frequency_contrast_{:.2f}.png".format(sv))
            plt.close()

    def plot_fi_curve(self, save_path=None):
        """Plot f_base, f_inf (+ fit) and f_zero (+ fit) against the stimulus values."""
        min_x = min(self.stimulus_values)
        max_x = max(self.stimulus_values)
        step = (max_x - min_x) / 5000
        x_values = np.arange(min_x, max_x, step)

        plt.plot(self.stimulus_values, self.f_baseline_frequencies, color='blue', label='f_base')

        plt.plot(self.stimulus_values, self.f_inf_frequencies, 'o', color='green', label='f_inf')
        plt.plot(x_values, [fu.clipped_line(x, self.f_inf_fit[0], self.f_inf_fit[1]) for x in x_values],
                 color='darkgreen', label='f_inf_fit')

        plt.plot(self.stimulus_values, self.f_zero_frequencies, 'o', color='orange', label='f_zero')
        popt = self.f_zero_fit
        plt.plot(x_values, [fu.full_boltzmann(x, popt[0], popt[1], popt[2], popt[3]) for x in x_values],
                 color='red', label='f_0_fit')

        plt.legend()
        plt.ylabel("Frequency [Hz]")
        plt.xlabel("Stimulus value")

        if save_path is None:
            plt.show()
        else:
            plt.savefig(save_path + "fi_curve.png")
        plt.close()

    @staticmethod
    def plot_fi_curve_comparision(data_fi_curve, model_fi_curve, save_path=None):
        """Plot cell and model f-I curves side by side, plus an overlay panel comparing them."""
        min_x = min(min(data_fi_curve.stimulus_values), min(model_fi_curve.stimulus_values))
        max_x = max(max(data_fi_curve.stimulus_values), max(model_fi_curve.stimulus_values))
        step = (max_x - min_x) / 5000
        x_values = np.arange(min_x, max_x + step, step)

        _, axes = plt.subplots(1, 3, sharex="all", sharey='all', figsize=(15, 6))

        data_origin = (data_fi_curve, model_fi_curve)
        f_base_color = ("blue", "deepskyblue")
        f_inf_color = ("green", "limegreen")
        f_zero_color = ("red", "orange")

        for i in range(2):
            curve = data_origin[i]
            axes[i].plot(curve.stimulus_values, curve.get_f_baseline_frequencies(),
                         color=f_base_color[i], label='f_base')

            axes[i].plot(curve.stimulus_values, curve.get_f_inf_frequencies(), 'o',
                         color=f_inf_color[i], label='f_inf')
            y_values = [fu.clipped_line(x, curve.f_inf_fit[0], curve.f_inf_fit[1]) for x in x_values]
            axes[i].plot(x_values, y_values, color=f_inf_color[i], label='f_inf_fit')

            axes[i].plot(curve.stimulus_values, curve.get_f_zero_frequencies(), 'o',
                         color=f_zero_color[i], label='f_zero')
            popt = curve.f_zero_fit
            axes[i].plot(x_values, [fu.full_boltzmann(x, popt[0], popt[1], popt[2], popt[3]) for x in x_values],
                         color=f_zero_color[i], label='f_0_fit')

            axes[i].set_xlabel("Stimulus value - contrast")
            axes[i].legend()

        axes[0].set_title("cell")
        axes[0].set_ylabel("Frequency [Hz]")
        axes[1].set_title("model")

        median_baseline = np.median(data_fi_curve.get_f_baseline_frequencies())
        axes[2].plot((min_x, max_x), (median_baseline, median_baseline), color=f_base_color[0], label="cell med base")
        axes[2].plot(model_fi_curve.stimulus_values, model_fi_curve.get_f_baseline_frequencies(), 'o',
                     color=f_base_color[1], label='model base')

        y_values = [fu.clipped_line(x, data_fi_curve.f_inf_fit[0], data_fi_curve.f_inf_fit[1]) for x in x_values]
        axes[2].plot(x_values, y_values, color=f_inf_color[0], label='f_inf_fit cell')
        axes[2].plot(model_fi_curve.stimulus_values, model_fi_curve.get_f_inf_frequencies(), 'o',
                     color=f_inf_color[1], label='f_inf model')

        popt = data_fi_curve.f_zero_fit
        axes[2].plot(x_values, [fu.full_boltzmann(x, popt[0], popt[1], popt[2], popt[3]) for x in x_values],
                     color=f_zero_color[0], label='f_0_fit cell')
        axes[2].plot(model_fi_curve.stimulus_values, model_fi_curve.get_f_zero_frequencies(), 'o',
                     color=f_zero_color[1], label='f_zero model')

        axes[2].set_title("cell model comparision")
        axes[2].set_xlabel("Stimulus value - contrast")
        axes[2].legend()

        if save_path is None:
            plt.show()
        else:
            plt.savefig(save_path + "fi_curve_comparision.png")
        plt.close()

    def plot_f_point_detections(self, save_path=None):
        raise NotImplementedError("NOT YET OVERRIDDEN FROM ABSTRACT CLASS")

    def _plot_detections_for_traces(self, time_axes, frequency_traces, stim_start, stim_duration,
                                    sampling_interval, save_path=None):
        """Plot each trace with horizontal markers at the detected f_base, f_inf and f_zero.

        Each marker spans the time window returned by the matching
        ``hF.time_window_detect_*`` helper. Traces that do not cover the full
        stimulus window are skipped. Shared by the subclasses' ``plot_f_point_detections``.
        """
        for i, contrast in enumerate(self.stimulus_values):
            time = time_axes[i]
            frequency = frequency_traces[i]

            if not self._trace_covers_stimulus(time, stim_start, stim_duration):
                continue

            _, ax = plt.subplots(1, 1, figsize=(8, 8))
            ax.plot(time, frequency)

            start_idx, end_idx = hF.time_window_detect_f_baseline(time[0], stim_start, sampling_interval)
            ax.plot((time[start_idx], time[end_idx]),
                    (self.f_baseline_frequencies[i], self.f_baseline_frequencies[i]),
                    label="f_base", color="deepskyblue")

            start_idx, end_idx = hF.time_window_detect_f_infinity(time[0], stim_start, stim_duration,
                                                                  sampling_interval)
            ax.plot((time[start_idx], time[end_idx]),
                    (self.f_inf_frequencies[i], self.f_inf_frequencies[i]),
                    label="f_inf", color="limegreen")

            start_idx, end_idx = hF.time_window_detect_f_zero(time[0], stim_start, sampling_interval)
            ax.plot((time[start_idx], time[end_idx]),
                    (self.f_zero_frequencies[i], self.f_zero_frequencies[i]),
                    label="f_zero", color="orange")

            plt.legend()
            if save_path is not None:
                plt.savefig(save_path + "/detections_contrast_{:.2f}.png".format(contrast))
            else:
                plt.show()
            plt.close()

    def save_values(self, save_directory):
        """Pickle the stimulus values, detected frequencies and fits into ``save_directory``."""
        values = {
            "stimulus_values": self.stimulus_values,
            "f_baseline_frequencies": self.f_baseline_frequencies,
            "f_inf_frequencies": self.f_inf_frequencies,
            "f_zero_frequencies": self.f_zero_frequencies,
            "f_inf_fit": self.f_inf_fit,
            "f_zero_fit": self.f_zero_fit,
        }

        with open(join(save_directory, self.save_file_name), "wb") as file:
            pickle.dump(values, file)
        print("Fi-Curve: Values saved!")

    def load_values(self, save_directory):
        """Load previously saved values from ``save_directory``.

        Returns:
            True if a cache file existed and was loaded, False if there was none.

        If the stored stimulus values differ from the given ones, a warning is written
        to stderr and the stored values (including ``stimulus_values``) are used anyway.
        """
        file_path = join(save_directory, self.save_file_name)
        if not exists(file_path):
            print("Fi-Curve: No file to load")
            return False

        # NOTE: pickle.load can execute arbitrary code; only load cache files you created yourself.
        with open(file_path, "rb") as file:
            values = pickle.load(file)

        if set(values["stimulus_values"]) != set(self.stimulus_values):
            stderr.write("Fi-Curve:load_values() - Given stimulus values are different to the loaded ones!:\n "
                         "given: {}\n loaded: {}\n".format(str(self.stimulus_values), str(values["stimulus_values"])))

        self.stimulus_values = values["stimulus_values"]
        self.f_baseline_frequencies = values["f_baseline_frequencies"]
        self.f_inf_frequencies = values["f_inf_frequencies"]
        self.f_zero_frequencies = values["f_zero_frequencies"]
        self.f_inf_fit = values["f_inf_fit"]
        self.f_zero_fit = values["f_zero_fit"]
        print("Fi-Curve: Values loaded!")
        return True


class FICurveCellData(FICurve):
    """f-I curve computed from recorded cell data (``CellData``)."""

    def __init__(self, cell_data: CellData, stimulus_values, save_dir=None):
        self.cell_data = cell_data
        super().__init__(stimulus_values, save_dir)

    def calculate_all_frequency_points(self):
        """Detect f_zero, f_baseline and f_inf in each mean frequency trace of the cell.

        Raises:
            ValueError: if a trace starts after the stimulus onset (not handled yet).
        """
        mean_frequencies = self.cell_data.get_mean_fi_curve_isi_frequencies()
        time_axes = self.cell_data.get_time_axes_fi_curve_mean_frequencies()
        stimulus_start = self.cell_data.get_stimulus_start()
        stimulus_duration = self.cell_data.get_stimulus_duration()
        sampling_interval = self.cell_data.get_sampling_interval()

        if len(mean_frequencies) == 0:
            warn("FICurveCellData:calculate_all_frequency_points(): mean_frequencies is empty.\n"
                 "Was all_calculate_mean_isi_frequencies already called?")

        # Start from empty lists so that indices always match the mean traces.
        self.f_zero_frequencies = []
        self.f_baseline_frequencies = []
        self.f_inf_frequencies = []

        for i, frequency in enumerate(mean_frequencies):
            time = time_axes[i]
            if time[0] > stimulus_start:
                # TODO: handle traces that are cut off before the stimulus onset.
                raise ValueError("TODO: Deal with to strongly cut frequency traces in cell data! ")

            self.f_zero_frequencies.append(
                hF.detect_f_zero_in_frequency_trace(time, frequency, stimulus_start, sampling_interval))
            self.f_baseline_frequencies.append(
                hF.detect_f_baseline_in_freq_trace(time, frequency, stimulus_start, sampling_interval))
            self.f_inf_frequencies.append(
                hF.detect_f_infinity_in_freq_trace(time, frequency, stimulus_start, stimulus_duration,
                                                   sampling_interval))

    def get_mean_time_and_freq_traces(self):
        return self.cell_data.get_time_axes_fi_curve_mean_frequencies(), \
            self.cell_data.get_mean_fi_curve_isi_frequencies()

    def get_time_and_freq_traces(self):
        """Return nested lists ``[stimulus][trial]`` of time and frequency traces from the spike times."""
        spiketimes = self.cell_data.get_fi_spiketimes()
        sampling_interval = self.cell_data.get_sampling_interval()

        time_traces = []
        freq_traces = []
        for spiketimes_of_value in spiketimes:
            traces = [hF.calculate_time_and_frequency_trace(trial_spikes, sampling_interval)
                      for trial_spikes in spiketimes_of_value]
            time_traces.append([time for time, _ in traces])
            freq_traces.append([freq for _, freq in traces])

        return time_traces, freq_traces

    def get_sampling_interval(self):
        return self.cell_data.get_sampling_interval()

    def get_delay(self):
        return self.cell_data.get_delay()

    def get_stimulus_start(self):
        return self.cell_data.get_stimulus_start()

    def get_stimulus_end(self):
        return self.cell_data.get_stimulus_end()

    def get_f_zero_inverse_at_frequency(self, frequency):
        # UNUSED
        b_vars = self.f_zero_fit
        return fu.inverse_full_boltzmann(frequency, b_vars[0], b_vars[1], b_vars[2], b_vars[3])

    def get_f_infinity_frequency_at_stimulus_value(self, stimulus_value):
        # UNUSED
        infty_vars = self.f_inf_fit
        return fu.clipped_line(stimulus_value, infty_vars[0], infty_vars[1])

    def plot_f_point_detections(self, save_path=None):
        """Plot the detected f-points on each mean frequency trace of the cell."""
        # The traces are kept as a plain list: they may differ in length, and np.array()
        # on ragged input raises in recent NumPy versions.
        mean_frequencies = self.cell_data.get_mean_fi_curve_isi_frequencies()
        time_axes = self.cell_data.get_time_axes_fi_curve_mean_frequencies()
        sampling_interval = self.cell_data.get_sampling_interval()
        stim_start = self.cell_data.get_stimulus_start()
        stim_duration = self.cell_data.get_stimulus_duration()

        self._plot_detections_for_traces(time_axes, mean_frequencies, stim_start, stim_duration,
                                         sampling_interval, save_path)


class FICurveModel(FICurve):
    """f-I curve computed by simulating a model with sinusoidal step stimuli."""

    stim_duration = 0.5
    stim_start = 0.5
    total_simulation_time = stim_duration + 2 * stim_start

    def __init__(self, model, stimulus_values, eod_frequency, trials=5):
        self.eod_frequency = eod_frequency
        self.model = model
        self.trials = trials
        # Object array [stimulus index, trial] holding the simulated spike times.
        self.spiketimes_array = np.zeros((len(stimulus_values), trials), dtype=list)
        self.mean_frequency_traces = []
        self.mean_time_traces = []
        super().__init__(stimulus_values)

    def calculate_all_frequency_points(self):
        """Simulate ``trials`` runs per stimulus value and detect f_inf, f_zero and f_baseline.

        Stimulus values whose mean trace does not cover the whole stimulus window get 0
        for all three frequencies.
        """
        sampling_interval = self.model.get_sampling_interval()
        self.f_inf_frequencies = []
        self.f_zero_frequencies = []
        self.f_baseline_frequencies = []
        # Reset the mean traces too, so their indices stay aligned with stimulus_values on re-runs.
        self.mean_frequency_traces = []
        self.mean_time_traces = []

        for i, c in enumerate(self.stimulus_values):
            stimulus = SinusoidalStepStimulus(self.eod_frequency, c, self.stim_start, self.stim_duration)
            frequency_traces = []
            time_traces = []
            for j in range(self.trials):
                _, spiketimes = self.model.simulate_fast(stimulus, self.total_simulation_time)
                self.spiketimes_array[i, j] = spiketimes
                trial_time, trial_frequency = hF.calculate_time_and_frequency_trace(spiketimes, sampling_interval)
                frequency_traces.append(trial_frequency)
                time_traces.append(trial_time)

            time, frequency = hF.calculate_mean_of_frequency_traces(time_traces, frequency_traces, sampling_interval)
            self.mean_frequency_traces.append(frequency)
            self.mean_time_traces.append(time)

            if not self._trace_covers_stimulus(time, self.stim_start, self.stim_duration):
                # Too few spikes to calculate f_inf, f_0 and f_base.
                self.f_inf_frequencies.append(0)
                self.f_zero_frequencies.append(0)
                self.f_baseline_frequencies.append(0)
                continue

            self.f_inf_frequencies.append(
                hF.detect_f_infinity_in_freq_trace(time, frequency, self.stim_start, self.stim_duration,
                                                   sampling_interval))
            self.f_zero_frequencies.append(
                hF.detect_f_zero_in_frequency_trace(time, frequency, self.stim_start, sampling_interval))
            self.f_baseline_frequencies.append(
                hF.detect_f_baseline_in_freq_trace(time, frequency, self.stim_start, sampling_interval))

    def get_mean_time_and_freq_traces(self):
        return self.mean_time_traces, self.mean_frequency_traces

    def get_sampling_interval(self):
        return self.model.get_sampling_interval()

    def get_delay(self):
        return 0

    def get_stimulus_start(self):
        return self.stim_start

    def get_stimulus_end(self):
        return self.stim_start + self.stim_duration

    def get_time_and_freq_traces(self):
        """Return nested lists ``[stimulus][trial]`` of time and frequency traces from the simulated spikes."""
        sampling_interval = self.model.get_sampling_interval()
        time_traces = []
        freq_traces = []
        for v in range(len(self.stimulus_values)):
            traces = [hF.calculate_time_and_frequency_trace(spikes, sampling_interval)
                      for spikes in self.spiketimes_array[v]]
            time_traces.append([time for time, _ in traces])
            freq_traces.append([freq for _, freq in traces])

        return time_traces, freq_traces

    def plot_f_point_detections(self, save_path=None):
        """Plot the detected f-points on each simulated mean frequency trace."""
        self._plot_detections_for_traces(self.mean_time_traces, self.mean_frequency_traces,
                                         self.stim_start, self.stim_duration,
                                         self.model.get_sampling_interval(), save_path)


def get_fi_curve_class(data, stimulus_values, eod_freq=None, trials=5, save_dir=None) -> FICurve:
    """Build the FICurve subclass matching the type of ``data``.

    Args:
        data: a ``CellData`` (recorded cell) or a ``LifacNoiseModel`` (simulated model).
        stimulus_values: stimulus values (contrasts) of the f-I curve.
        eod_freq: EOD frequency, required for models.
        trials: number of simulated trials per stimulus value (models only).
        save_dir: cache directory for the computed values (cell data only).

    Raises:
        ValueError: if ``eod_freq`` is missing for a model, or ``data`` has an unsupported type.
    """
    if isinstance(data, CellData):
        return FICurveCellData(data, stimulus_values, save_dir)
    if isinstance(data, LifacNoiseModel):
        if eod_freq is None:
            raise ValueError("The FiCurveModel needs the eod variable to work")
        return FICurveModel(data, stimulus_values, eod_freq, trials=trials)

    raise ValueError("Unknown type: Cannot find corresponding FiCurve class. Data was type:" + str(type(data)))