419 lines
19 KiB
Python
419 lines
19 KiB
Python
|
|
from CellData import CellData
|
|
from models.LIFACnoise import LifacNoiseModel
|
|
from stimuli.SinusoidalStepStimulus import SinusoidalStepStimulus
|
|
import numpy as np
|
|
import matplotlib.pyplot as plt
|
|
from warnings import warn
|
|
import functions as fu
|
|
import helperFunctions as hF
|
|
|
|
|
|
class FICurve:
    """Abstract f-I curve: detected frequencies and their fits as a function of stimulus value.

    Subclasses must implement :meth:`calculate_all_frequency_points`, which fills
    ``f_baseline_frequencies``, ``f_inf_frequencies`` and ``f_zero_frequencies`` with one
    entry per element of ``stimulus_values``. The constructor calls it (via
    :meth:`initialize`) and then fits a clipped line to the f_inf points and a full
    Boltzmann function to the f_zero points.
    """

    def __init__(self, stimulus_values):
        self.stimulus_values = stimulus_values

        # One entry per stimulus value, filled by calculate_all_frequency_points().
        self.f_baseline_frequencies = []
        self.f_inf_frequencies = []
        self.f_zero_frequencies = []

        # increase, offset
        self.f_inf_fit = []
        # f_max, f_min, k, x_zero
        self.f_zero_fit = []

        self.initialize()

    def initialize(self):
        """Calculate all frequency points, then fit f_inf (clipped line) and f_zero (Boltzmann)."""
        self.calculate_all_frequency_points()
        self.f_inf_fit = hF.fit_clipped_line(self.stimulus_values, self.f_inf_frequencies)
        self.f_zero_fit = hF.fit_boltzmann(self.stimulus_values, self.f_zero_frequencies)

    def calculate_all_frequency_points(self):
        """Fill the baseline, f_inf and f_zero frequency lists; must be overridden.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError("NOT YET OVERRIDDEN FROM ABSTRACT CLASS")

    def get_f_baseline_frequencies(self):
        """Return the baseline frequency detected for each stimulus value."""
        return self.f_baseline_frequencies

    def get_f_inf_frequencies(self):
        """Return the f_inf frequency detected for each stimulus value."""
        return self.f_inf_frequencies

    def get_f_zero_frequencies(self):
        """Return the f_zero frequency detected for each stimulus value."""
        return self.f_zero_frequencies

    def get_f_inf_slope(self):
        """Return the first f_inf fit parameter (the line's increase), or None if not fitted."""
        # len() instead of truthiness: the fit may be a numpy array.
        if len(self.f_inf_fit) > 0:
            return self.f_inf_fit[0]
        return None

    def _f_zero_fit_params(self):
        """Return the four f_zero fit parameters (f_max, f_min, k, x_zero) as a tuple."""
        fit_vars = self.f_zero_fit
        return fit_vars[0], fit_vars[1], fit_vars[2], fit_vars[3]

    def get_f_zero_fit_slope_at_straight(self):
        """Return ``fu.full_boltzmann_straight_slope`` evaluated for the f_zero fit."""
        return fu.full_boltzmann_straight_slope(*self._f_zero_fit_params())

    def get_f_zero_fit_slope_at_stimulus_value(self, stimulus_value):
        """Return the derivative of the f_zero Boltzmann fit at ``stimulus_value``."""
        return fu.derivative_full_boltzmann(stimulus_value, *self._f_zero_fit_params())

    def get_f_inf_frequency_at_stimulus_value(self, stimulus_value):
        """Evaluate the f_inf clipped-line fit at ``stimulus_value``."""
        return fu.clipped_line(stimulus_value, self.f_inf_fit[0], self.f_inf_fit[1])

    def get_f_zero_and_f_inf_intersection(self):
        """Return the stimulus value at which the f_zero fit crosses the f_inf fit.

        Both fits are sampled on a 0.0001 grid over [min, max) of the stimulus values.
        If they cross more than once, the crossing whose f_inf value is closest to the
        median baseline frequency is returned.

        Raises:
            ValueError: if the fits do not cross on the sampled grid.
        """
        x_values = np.arange(min(self.stimulus_values), max(self.stimulus_values), 0.0001)
        f_zero = fu.full_boltzmann(x_values, *self._f_zero_fit_params())
        f_inf = np.asarray(fu.clipped_line(x_values, self.f_inf_fit[0], self.f_inf_fit[1]))

        # A sign change of (f_zero - f_inf) between idx and idx + 1 marks a crossing.
        intersection_indices = np.argwhere(np.diff(np.sign(f_zero - f_inf))).flatten()

        if len(intersection_indices) == 0:
            raise ValueError("No intersection found!")
        if len(intersection_indices) == 1:
            return x_values[intersection_indices[0]]

        f_baseline = np.median(self.f_baseline_frequencies)
        # f_inf is already sampled on the grid, so reuse it instead of re-evaluating the fit.
        distances = np.abs(f_inf[intersection_indices] - f_baseline)
        # argmin returns the first minimum, like a strict "<" scan over the candidates.
        return x_values[intersection_indices[np.argmin(distances)]]

    def get_f_zero_fit_slope_at_f_inf_fit_intersection(self):
        """Return the f_zero fit's derivative at its intersection with the f_inf fit.

        Raises:
            ValueError: propagated from :meth:`get_f_zero_and_f_inf_intersection`.
        """
        x = self.get_f_zero_and_f_inf_intersection()
        return fu.derivative_full_boltzmann(x, *self._f_zero_fit_params())

    def plot_fi_curve(self, save_path=None):
        """Plot the detected frequencies and both fits, then show or save the figure.

        Args:
            save_path: None to show the plot interactively; otherwise a path prefix to
                which "fi_curve.png" is appended directly (no separator is inserted).
        """
        min_x = min(self.stimulus_values)
        max_x = max(self.stimulus_values)
        # linspace includes max_x and, unlike arange with a step of (max_x - min_x) / 5000,
        # does not break down when all stimulus values are equal (zero step).
        x_values = np.linspace(min_x, max_x, 5000)
        f_zero_params = self._f_zero_fit_params()

        plt.plot(self.stimulus_values, self.f_baseline_frequencies, color='blue', label='f_base')

        plt.plot(self.stimulus_values, self.f_inf_frequencies, 'o', color='green', label='f_inf')
        plt.plot(x_values, [fu.clipped_line(x, self.f_inf_fit[0], self.f_inf_fit[1]) for x in x_values],
                 color='darkgreen', label='f_inf_fit')

        plt.plot(self.stimulus_values, self.f_zero_frequencies, 'o', color='orange', label='f_zero')
        plt.plot(x_values, [fu.full_boltzmann(x, *f_zero_params) for x in x_values],
                 color='red', label='f_0_fit')

        plt.legend()
        plt.ylabel("Frequency [Hz]")
        plt.xlabel("Stimulus value")

        if save_path is None:
            plt.show()
        else:
            print("save")
            plt.savefig(save_path + "fi_curve.png")
            plt.close()

    @staticmethod
    def plot_fi_curve_comparision(data_fi_curve, model_fi_curve, save_path=None):
        """Plot a cell f-I curve, a model f-I curve and an overlay of both side by side.

        Panel 0 shows the cell, panel 1 the model, and panel 2 the cell fits (and
        median baseline) together with the model's detected points.

        Args:
            data_fi_curve: FICurve of the recorded cell.
            model_fi_curve: FICurve of the model.
            save_path: None to show the plot; otherwise a path prefix to which
                "fi_curve_comparision.png" is appended directly.
        """
        min_x = min(min(data_fi_curve.stimulus_values), min(model_fi_curve.stimulus_values))
        max_x = max(max(data_fi_curve.stimulus_values), max(model_fi_curve.stimulus_values))
        # Evenly spaced grid that ends exactly at max_x (arange(min, max + step, step)
        # could overshoot it through float rounding and fails for a zero step).
        x_values = np.linspace(min_x, max_x, 5001)

        _, axes = plt.subplots(1, 3, sharex="all", sharey='all', figsize=(15, 6))
        data_origin = (data_fi_curve, model_fi_curve)
        # (cell color, model color) per quantity.
        f_base_color = ("blue", "deepskyblue")
        f_inf_color = ("green", "limegreen")
        f_zero_color = ("red", "orange")

        for i, fi_curve in enumerate(data_origin):
            ax = axes[i]
            ax.plot(fi_curve.stimulus_values, fi_curve.get_f_baseline_frequencies(),
                    color=f_base_color[i], label='f_base')

            ax.plot(fi_curve.stimulus_values, fi_curve.get_f_inf_frequencies(),
                    'o', color=f_inf_color[i], label='f_inf')
            y_values = [fu.clipped_line(x, fi_curve.f_inf_fit[0], fi_curve.f_inf_fit[1]) for x in x_values]
            ax.plot(x_values, y_values, color=f_inf_color[i], label='f_inf_fit')

            ax.plot(fi_curve.stimulus_values, fi_curve.get_f_zero_frequencies(),
                    'o', color=f_zero_color[i], label='f_zero')
            popt = fi_curve.f_zero_fit
            ax.plot(x_values, [fu.full_boltzmann(x, popt[0], popt[1], popt[2], popt[3]) for x in x_values],
                    color=f_zero_color[i], label='f_0_fit')

            ax.set_xlabel("Stimulus value - contrast")
            ax.legend()

        axes[0].set_title("cell")
        axes[0].set_ylabel("Frequency [Hz]")
        axes[1].set_title("model")

        # Overlay panel: cell fits as lines, model detections as points.
        median_baseline = np.median(data_fi_curve.get_f_baseline_frequencies())
        axes[2].plot((min_x, max_x), (median_baseline, median_baseline), color=f_base_color[0], label="cell med base")
        axes[2].plot(model_fi_curve.stimulus_values, model_fi_curve.get_f_baseline_frequencies(),
                     'o', color=f_base_color[1], label='model base')

        y_values = [fu.clipped_line(x, data_fi_curve.f_inf_fit[0], data_fi_curve.f_inf_fit[1]) for x in x_values]
        axes[2].plot(x_values, y_values, color=f_inf_color[0], label='f_inf_fit cell')
        axes[2].plot(model_fi_curve.stimulus_values, model_fi_curve.get_f_inf_frequencies(),
                     'o', color=f_inf_color[1], label='f_inf model')

        popt = data_fi_curve.f_zero_fit
        axes[2].plot(x_values, [fu.full_boltzmann(x, popt[0], popt[1], popt[2], popt[3]) for x in x_values],
                     color=f_zero_color[0], label='f_0_fit cell')
        axes[2].plot(model_fi_curve.stimulus_values, model_fi_curve.get_f_zero_frequencies(),
                     'o', color=f_zero_color[1], label='f_zero model')
        axes[2].set_title("cell model comparision")
        axes[2].set_xlabel("Stimulus value - contrast")
        axes[2].legend()

        if save_path is None:
            plt.show()
        else:
            print("save")
            plt.savefig(save_path + "fi_curve_comparision.png")
            plt.close()

    def plot_f_point_detections(self, save_path=None):
        """Plot where f_zero, f_inf and f_base were detected; must be overridden.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError("NOT YET OVERRIDDEN FROM ABSTRACT CLASS")
|
|
|
|
|
|
class FICurveCellData(FICurve):
    """f-I curve computed from the mean frequency traces of a recorded cell (CellData)."""

    def __init__(self, cell_data: CellData, stimulus_values):
        # cell_data must be set before super().__init__, which runs the calculation.
        self.cell_data = cell_data
        super().__init__(stimulus_values)

    def calculate_all_frequency_points(self):
        """Detect f_zero, f_baseline and f_inf in every mean frequency trace of the cell.

        Warns if the cell has no mean frequency traces.

        Raises:
            ValueError: if a trace's time axis starts after the stimulus start, i.e. the
                trace was cut too strongly to contain a baseline.
        """
        mean_frequencies = self.cell_data.get_mean_isi_frequencies()
        time_axes = self.cell_data.get_time_axes_mean_frequencies()
        stimulus_start = self.cell_data.get_stimulus_start()
        stimulus_duration = self.cell_data.get_stimulus_duration()
        sampling_interval = self.cell_data.get_sampling_interval()

        if len(mean_frequencies) == 0:
            warn("FICurve:calculate_all_frequency_points(): mean_frequencies is empty.\n"
                 "Was all_calculate_mean_isi_frequencies already called?")

        for i, frequency in enumerate(mean_frequencies):
            time_axis = time_axes[i]

            if time_axis[0] > stimulus_start:
                # TODO: such traces could be skipped with placeholder values instead of failing.
                raise ValueError("TODO: Deal with to strongly cut frequency traces in cell data! ")

            f_zero = hF.detect_f_zero_in_frequency_trace(time_axis, frequency,
                                                         stimulus_start, sampling_interval)
            self.f_zero_frequencies.append(f_zero)

            f_baseline = hF.detect_f_baseline_in_freq_trace(time_axis, frequency,
                                                            stimulus_start, sampling_interval)
            self.f_baseline_frequencies.append(f_baseline)

            f_infinity = hF.detect_f_infinity_in_freq_trace(time_axis, frequency,
                                                            stimulus_start, stimulus_duration, sampling_interval)
            self.f_inf_frequencies.append(f_infinity)

    def get_f_zero_inverse_at_frequency(self, frequency):
        """Return the stimulus value at which the f_zero Boltzmann fit reaches ``frequency``."""
        # UNUSED
        b_vars = self.f_zero_fit
        return fu.inverse_full_boltzmann(frequency, b_vars[0], b_vars[1], b_vars[2], b_vars[3])

    def get_f_infinity_frequency_at_stimulus_value(self, stimulus_value):
        """Evaluate the f_inf clipped-line fit at ``stimulus_value``."""
        # UNUSED; same as the base class's get_f_inf_frequency_at_stimulus_value.
        return self.get_f_inf_frequency_at_stimulus_value(stimulus_value)

    def plot_f_point_detections(self, save_path=None):
        """Plot each mean frequency trace with horizontal lines at its f_zero, f_inf and f_base.

        Args:
            save_path: None to show each figure; otherwise a path prefix to which
                "detections_contrast_<value>.png" is appended, one file per stimulus value.
        """
        # Kept as a plain sequence: traces of different lengths cannot form a regular
        # numpy array (np.array on ragged input raises in recent numpy versions).
        mean_frequencies = self.cell_data.get_mean_isi_frequencies()
        time_axes = self.cell_data.get_time_axes_mean_frequencies()

        for i, frequency in enumerate(mean_frequencies):
            time_axis = time_axes[i]
            t_span = (time_axis[0], time_axis[-1])

            _, ax = plt.subplots(1, 1)
            ax.plot(time_axis, frequency, label="frequency")
            ax.plot(t_span, (self.f_zero_frequencies[i], self.f_zero_frequencies[i]), label="f_zero")
            ax.plot(t_span, (self.f_inf_frequencies[i], self.f_inf_frequencies[i]), '--', label="f_inf")
            ax.plot(t_span, (self.f_baseline_frequencies[i], self.f_baseline_frequencies[i]), label="f_base")
            ax.set_title(str(self.stimulus_values[i]))
            ax.legend()

            if save_path is None:
                plt.show()
            else:
                # One file per stimulus value so figures do not overwrite each other.
                plt.savefig(save_path + "detections_contrast_{:.2f}.png".format(self.stimulus_values[i]))
            plt.close()
|
|
|
|
|
|
class FICurveModel(FICurve):
    """f-I curve of a model, obtained by simulating sinusoidal step stimuli.

    For every stimulus value (contrast) the model is simulated ``trials`` times with a
    SinusoidalStepStimulus at ``eod_frequency``. The per-trial frequency traces are
    averaged, and f_baseline, f_zero and f_inf are detected in the mean trace.
    """

    # Step timing used for every simulation (units as expected by the model and
    # SinusoidalStepStimulus — presumably seconds; confirm). The simulation covers a
    # pre-stimulus period, the step, and a post-stimulus period of the same length.
    stim_duration = 0.5
    stim_start = 0.5
    total_simulation_time = stim_duration + 2 * stim_start

    def __init__(self, model, stimulus_values, eod_frequency, trials=5):
        self.eod_frequency = eod_frequency
        self.model = model
        self.trials = trials
        # spiketimes_array[i, j]: spike times of trial j for stimulus_values[i].
        self.spiketimes_array = np.zeros((len(stimulus_values), trials), dtype=object)
        # Mean (time, frequency) trace per stimulus value, aligned with stimulus_values.
        self.mean_frequency_traces = []
        self.mean_time_traces = []
        super().__init__(stimulus_values)

    def _trace_covers_stimulus(self, time):
        """Return True if the time axis is non-empty and spans the whole stimulus window."""
        return (len(time) > 0
                and min(time) <= self.stim_start
                and max(time) >= self.stim_start + self.stim_duration)

    def _simulate_mean_trace(self, index, contrast, sampling_interval):
        """Simulate all trials for one contrast; store the spike times, return the mean (time, frequency)."""
        stimulus = SinusoidalStepStimulus(self.eod_frequency, contrast, self.stim_start, self.stim_duration)
        frequency_traces = []
        time_traces = []
        for trial in range(self.trials):
            _, spiketimes = self.model.simulate_fast(stimulus, self.total_simulation_time)
            self.spiketimes_array[index, trial] = spiketimes
            trial_time, trial_frequency = hF.calculate_time_and_frequency_trace(spiketimes, sampling_interval)
            frequency_traces.append(trial_frequency)
            time_traces.append(trial_time)

        return hF.calculate_mean_of_frequency_traces(time_traces, frequency_traces, sampling_interval)

    def calculate_all_frequency_points(self):
        """Simulate every stimulus value and detect f_inf, f_zero and f_baseline.

        Stimulus values whose mean trace does not cover the stimulus window get 0 for
        all three frequencies.
        """
        sampling_interval = self.model.get_sampling_interval()
        self.f_inf_frequencies = []
        self.f_zero_frequencies = []
        self.f_baseline_frequencies = []
        # Reset the mean traces too, so a repeated call keeps them aligned with the
        # stimulus_values indices instead of appending a second set.
        self.mean_frequency_traces = []
        self.mean_time_traces = []

        for i, contrast in enumerate(self.stimulus_values):
            time, frequency = self._simulate_mean_trace(i, contrast, sampling_interval)
            self.mean_frequency_traces.append(frequency)
            self.mean_time_traces.append(time)

            if not self._trace_covers_stimulus(time):
                print("Too few spikes to calculate f_inf, f_0 and f_base")
                self.f_inf_frequencies.append(0)
                self.f_zero_frequencies.append(0)
                self.f_baseline_frequencies.append(0)
                continue

            f_inf = hF.detect_f_infinity_in_freq_trace(time, frequency, self.stim_start, self.stim_duration, sampling_interval)
            self.f_inf_frequencies.append(f_inf)

            f_zero = hF.detect_f_zero_in_frequency_trace(time, frequency, self.stim_start, sampling_interval)
            self.f_zero_frequencies.append(f_zero)

            f_baseline = hF.detect_f_baseline_in_freq_trace(time, frequency, self.stim_start, sampling_interval)
            self.f_baseline_frequencies.append(f_baseline)

    def plot_f_point_detections(self, save_path=None):
        """Plot each mean trace with the detected levels drawn over their detection windows.

        Stimulus values whose mean trace does not cover the stimulus window are skipped.

        Args:
            save_path: None to show each figure; otherwise a directory-like prefix to which
                "/detections_contrast_<value>.png" is appended, one file per stimulus value.
        """
        sampling_interval = self.model.get_sampling_interval()

        for i, contrast in enumerate(self.stimulus_values):
            time = self.mean_time_traces[i]
            frequency = self.mean_frequency_traces[i]

            if not self._trace_covers_stimulus(time):
                continue

            _, ax = plt.subplots(1, 1, figsize=(8, 8))
            ax.plot(time, frequency)

            start_idx, end_idx = hF.time_window_detect_f_baseline(time[0], self.stim_start, sampling_interval)
            ax.plot((time[start_idx], time[end_idx]), (self.f_baseline_frequencies[i], self.f_baseline_frequencies[i]),
                    label="f_base", color="deepskyblue")

            start_idx, end_idx = hF.time_window_detect_f_infinity(time[0], self.stim_start, self.stim_duration, sampling_interval)
            ax.plot((time[start_idx], time[end_idx]), (self.f_inf_frequencies[i], self.f_inf_frequencies[i]),
                    label="f_inf", color="limegreen")

            start_idx, end_idx = hF.time_window_detect_f_zero(time[0], self.stim_start, sampling_interval)
            ax.plot((time[start_idx], time[end_idx]), (self.f_zero_frequencies[i], self.f_zero_frequencies[i]),
                    label="f_zero", color="orange")

            plt.legend()
            if save_path is not None:
                plt.savefig(save_path + "/detections_contrast_{:.2f}.png".format(contrast))
            else:
                plt.show()

            plt.close()
|
|
|
|
|
|
def get_fi_curve_class(data, stimulus_values, eod_freq=None, trials=5) -> FICurve:
    """Create the FICurve subclass that matches the type of ``data``.

    Args:
        data: a CellData (gives FICurveCellData) or a LifacNoiseModel (gives FICurveModel).
        stimulus_values: stimulus values (contrasts) at which the curve is evaluated.
        eod_freq: EOD frequency of the simulated stimulus; required for models.
        trials: number of simulated trials per stimulus value (models only).

    Raises:
        ValueError: if eod_freq is missing for a model, or data has an unsupported type.
    """
    if isinstance(data, CellData):
        return FICurveCellData(data, stimulus_values)

    if isinstance(data, LifacNoiseModel):
        if eod_freq is None:
            raise ValueError("The FiCurveModel needs the eod variable to work")
        return FICurveModel(data, stimulus_values, eod_freq, trials=trials)

    raise ValueError("Unknown type: Cannot find corresponding FICurve class. Data was type:" + str(type(data)))
|