P-unit_model/FiCurve.py

419 lines
19 KiB
Python

from CellData import CellData
from models.LIFACnoise import LifacNoiseModel
from stimuli.SinusoidalStepStimulus import SinusoidalStepStimulus
import numpy as np
import matplotlib.pyplot as plt
from warnings import warn
import functions as fu
import helperFunctions as hF
class FICurve:
    """Abstract base for f-I curves built from per-stimulus frequency points.

    Subclasses implement :meth:`calculate_all_frequency_points`, which must fill
    ``f_baseline_frequencies``, ``f_inf_frequencies`` and ``f_zero_frequencies``
    (one entry per value in ``stimulus_values``). The constructor then fits a
    clipped line to the f_inf points and a Boltzmann function to the f_zero points.
    """

    def __init__(self, stimulus_values):
        self.stimulus_values = stimulus_values
        self.f_baseline_frequencies = []
        self.f_inf_frequencies = []
        self.f_zero_frequencies = []
        # increase, offset
        self.f_inf_fit = []
        # f_max, f_min, k, x_zero
        self.f_zero_fit = []
        self.initialize()

    def initialize(self):
        """Compute all frequency points, then fit the f_inf line and the f_zero Boltzmann curve."""
        self.calculate_all_frequency_points()
        self.f_inf_fit = hF.fit_clipped_line(self.stimulus_values, self.f_inf_frequencies)
        self.f_zero_fit = hF.fit_boltzmann(self.stimulus_values, self.f_zero_frequencies)

    def calculate_all_frequency_points(self):
        """Fill the three frequency lists; must be overridden by subclasses."""
        raise NotImplementedError("NOT YET OVERRIDDEN FROM ABSTRACT CLASS")

    # ----- private fit helpers --------------------------------------------

    def _f_zero_fit_at(self, x):
        """Evaluate the fitted Boltzmann (f_zero) curve at x (scalar or array)."""
        fit = self.f_zero_fit
        return fu.full_boltzmann(x, fit[0], fit[1], fit[2], fit[3])

    def _f_inf_fit_at(self, x):
        """Evaluate the fitted clipped line (f_inf) at x (scalar or array)."""
        return fu.clipped_line(x, self.f_inf_fit[0], self.f_inf_fit[1])

    @staticmethod
    def _plot_x_values(min_x, max_x, samples=5000):
        """Return `samples` evenly spaced points covering [min_x, max_x] inclusive.

        np.linspace is used instead of np.arange(min_x, max_x, (max_x - min_x) / n)
        because that step is zero when all stimulus values are equal, and
        np.arange rejects a zero step.
        """
        return np.linspace(min_x, max_x, samples)

    # ----- accessors ---------------------------------------------------------

    def get_f_baseline_frequencies(self):
        return self.f_baseline_frequencies

    def get_f_inf_frequencies(self):
        return self.f_inf_frequencies

    def get_f_zero_frequencies(self):
        return self.f_zero_frequencies

    def get_f_inf_slope(self):
        """Return the slope of the f_inf clipped-line fit, or None if no fit exists."""
        if len(self.f_inf_fit) > 0:
            return self.f_inf_fit[0]
        return None

    def get_f_zero_fit_slope_at_straight(self):
        """Return fu.full_boltzmann_straight_slope for the f_zero fit parameters."""
        fit_vars = self.f_zero_fit
        return fu.full_boltzmann_straight_slope(fit_vars[0], fit_vars[1], fit_vars[2], fit_vars[3])

    def get_f_zero_fit_slope_at_stimulus_value(self, stimulus_value):
        """Return the derivative of the f_zero Boltzmann fit at `stimulus_value`."""
        fit_vars = self.f_zero_fit
        return fu.derivative_full_boltzmann(stimulus_value, fit_vars[0], fit_vars[1], fit_vars[2], fit_vars[3])

    def get_f_inf_frequency_at_stimulus_value(self, stimulus_value):
        """Return the f_inf fit evaluated at `stimulus_value`."""
        return self._f_inf_fit_at(stimulus_value)

    def get_f_zero_and_f_inf_intersection(self, resolution=0.0001):
        """Return the stimulus value where the f_zero fit crosses the f_inf fit.

        Both fits are sampled on a grid with spacing `resolution` over
        [min(stimulus_values), max(stimulus_values)); a sign change of their
        difference marks a crossing. If there are several crossings, the one whose
        f_inf value is closest to the median baseline frequency is returned.

        Raises:
            ValueError: if the fits do not cross on the sampled grid.
        """
        x_values = np.arange(min(self.stimulus_values), max(self.stimulus_values), resolution)
        difference = self._f_zero_fit_at(x_values) - self._f_inf_fit_at(x_values)
        intersection_indices = np.argwhere(np.diff(np.sign(difference))).flatten()

        if len(intersection_indices) == 0:
            raise ValueError("No intersection found!")
        if len(intersection_indices) == 1:
            return x_values[intersection_indices[0]]

        # several crossings: prefer the one whose f_inf is closest to the median baseline
        f_baseline = np.median(self.f_baseline_frequencies)
        best_dist = np.inf
        best_idx = -1
        for idx in intersection_indices:
            dist = abs(self._f_inf_fit_at(x_values[idx]) - f_baseline)
            if dist < best_dist:
                best_dist = dist
                best_idx = idx
        return x_values[best_idx]

    def get_f_zero_fit_slope_at_f_inf_fit_intersection(self):
        """Return the f_zero fit's derivative at its intersection with the f_inf fit.

        Raises:
            ValueError: propagated from get_f_zero_and_f_inf_intersection.
        """
        x = self.get_f_zero_and_f_inf_intersection()
        fit_vars = self.f_zero_fit
        return fu.derivative_full_boltzmann(x, fit_vars[0], fit_vars[1], fit_vars[2], fit_vars[3])

    # ----- plotting ----------------------------------------------------------

    def plot_fi_curve(self, save_path=None):
        """Plot the frequency points and both fits of this f-I curve.

        Shows the figure, or saves it as ``save_path + "fi_curve.png"`` when
        save_path is given (save_path is used as a plain string prefix).
        """
        x_values = self._plot_x_values(min(self.stimulus_values), max(self.stimulus_values))

        plt.plot(self.stimulus_values, self.f_baseline_frequencies, color='blue', label='f_base')
        plt.plot(self.stimulus_values, self.f_inf_frequencies, 'o', color='green', label='f_inf')
        plt.plot(x_values, self._f_inf_fit_at(x_values), color='darkgreen', label='f_inf_fit')
        plt.plot(self.stimulus_values, self.f_zero_frequencies, 'o', color='orange', label='f_zero')
        plt.plot(x_values, self._f_zero_fit_at(x_values), color='red', label='f_0_fit')

        plt.legend()
        plt.ylabel("Frequency [Hz]")
        plt.xlabel("Stimulus value")

        if save_path is None:
            plt.show()
        else:
            plt.savefig(save_path + "fi_curve.png")
        plt.close()

    @staticmethod
    def plot_fi_curve_comparision(data_fi_curve, model_fi_curve, save_path=None):
        """Plot a cell's f-I curve, a model's f-I curve and an overlay of both.

        Panel 0 shows the cell, panel 1 the model, and panel 2 the cell's fits
        with the model's points on top. Shows the figure, or saves it as
        ``save_path + "fi_curve_comparision.png"`` when save_path is given.
        """
        min_x = min(min(data_fi_curve.stimulus_values), min(model_fi_curve.stimulus_values))
        max_x = max(max(data_fi_curve.stimulus_values), max(model_fi_curve.stimulus_values))
        x_values = FICurve._plot_x_values(min_x, max_x)

        _, axes = plt.subplots(1, 3, sharex="all", sharey='all', figsize=(15, 6))

        data_origin = (data_fi_curve, model_fi_curve)
        # (cell color, model color) per quantity
        f_base_color = ("blue", "deepskyblue")
        f_inf_color = ("green", "limegreen")
        f_zero_color = ("red", "orange")

        # panels 0 and 1: each curve on its own
        for ax, fi_curve, base_color, inf_color, zero_color in zip(
                axes[:2], data_origin, f_base_color, f_inf_color, f_zero_color):
            ax.plot(fi_curve.stimulus_values, fi_curve.get_f_baseline_frequencies(),
                    color=base_color, label='f_base')
            ax.plot(fi_curve.stimulus_values, fi_curve.get_f_inf_frequencies(),
                    'o', color=inf_color, label='f_inf')
            ax.plot(x_values, fi_curve._f_inf_fit_at(x_values), color=inf_color, label='f_inf_fit')
            ax.plot(fi_curve.stimulus_values, fi_curve.get_f_zero_frequencies(),
                    'o', color=zero_color, label='f_zero')
            ax.plot(x_values, fi_curve._f_zero_fit_at(x_values), color=zero_color, label='f_0_fit')
            ax.set_xlabel("Stimulus value - contrast")
            ax.legend()

        axes[0].set_title("cell")
        axes[0].set_ylabel("Frequency [Hz]")
        axes[1].set_title("model")

        # panel 2: cell fits with model points overlaid
        median_baseline = np.median(data_fi_curve.get_f_baseline_frequencies())
        axes[2].plot((min_x, max_x), (median_baseline, median_baseline), color=f_base_color[0], label="cell med base")
        axes[2].plot(model_fi_curve.stimulus_values, model_fi_curve.get_f_baseline_frequencies(),
                     'o', color=f_base_color[1], label='model base')
        axes[2].plot(x_values, data_fi_curve._f_inf_fit_at(x_values), color=f_inf_color[0], label='f_inf_fit cell')
        axes[2].plot(model_fi_curve.stimulus_values, model_fi_curve.get_f_inf_frequencies(),
                     'o', color=f_inf_color[1], label='f_inf model')
        axes[2].plot(x_values, data_fi_curve._f_zero_fit_at(x_values), color=f_zero_color[0], label='f_0_fit cell')
        axes[2].plot(model_fi_curve.stimulus_values, model_fi_curve.get_f_zero_frequencies(),
                     'o', color=f_zero_color[1], label='f_zero model')
        axes[2].set_title("cell model comparision")
        axes[2].set_xlabel("Stimulus value - contrast")
        axes[2].legend()

        if save_path is None:
            plt.show()
        else:
            plt.savefig(save_path + "fi_curve_comparision.png")
        plt.close()

    def plot_f_point_detections(self, save_path=None):
        """Plot the detected frequency points per stimulus; must be overridden."""
        raise NotImplementedError("NOT YET OVERRIDDEN FROM ABSTRACT CLASS")
class FICurveCellData(FICurve):
    """f-I curve whose frequency points are detected in recorded cell data."""

    def __init__(self, cell_data: CellData, stimulus_values):
        self.cell_data = cell_data
        super().__init__(stimulus_values)

    def calculate_all_frequency_points(self):
        """Detect f_zero, f_baseline and f_inf in each mean ISI frequency trace of the cell.

        The lists are rebuilt from scratch, so calling this again does not
        append duplicate entries.

        Raises:
            ValueError: if a frequency trace starts after the stimulus start.
        """
        mean_frequencies = self.cell_data.get_mean_isi_frequencies()
        time_axes = self.cell_data.get_time_axes_mean_frequencies()
        stimulus_start = self.cell_data.get_stimulus_start()
        stimulus_duration = self.cell_data.get_stimulus_duration()
        sampling_interval = self.cell_data.get_sampling_interval()

        self.f_zero_frequencies = []
        self.f_baseline_frequencies = []
        self.f_inf_frequencies = []

        if len(mean_frequencies) == 0:
            warn("FICurveCellData.calculate_all_frequency_points(): mean_frequencies is empty.\n"
                 "Was all_calculate_mean_isi_frequencies already called?")

        for i, frequency in enumerate(mean_frequencies):
            time = time_axes[i]
            if time[0] > stimulus_start:
                raise ValueError("TODO: Deal with to strongly cut frequency traces in cell data! ")

            self.f_zero_frequencies.append(
                hF.detect_f_zero_in_frequency_trace(time, frequency, stimulus_start, sampling_interval))
            self.f_baseline_frequencies.append(
                hF.detect_f_baseline_in_freq_trace(time, frequency, stimulus_start, sampling_interval))
            self.f_inf_frequencies.append(
                hF.detect_f_infinity_in_freq_trace(time, frequency, stimulus_start,
                                                   stimulus_duration, sampling_interval))

    def get_f_zero_inverse_at_frequency(self, frequency):
        """Return fu.inverse_full_boltzmann of `frequency` under the f_zero fit."""
        # UNUSED
        b_vars = self.f_zero_fit
        return fu.inverse_full_boltzmann(frequency, b_vars[0], b_vars[1], b_vars[2], b_vars[3])

    def get_f_infinity_frequency_at_stimulus_value(self, stimulus_value):
        """Return the f_inf clipped-line fit evaluated at `stimulus_value`."""
        # UNUSED
        infty_vars = self.f_inf_fit
        return fu.clipped_line(stimulus_value, infty_vars[0], infty_vars[1])

    def plot_f_point_detections(self, save_path=None):
        """Plot each mean frequency trace with its detected f_zero, f_inf and f_base levels.

        One figure per trace. Shown interactively, or saved as
        ``save_path + "/detections_contrast_{:.2f}.png"`` formatted with the
        trace's stimulus value, so figures no longer overwrite each other.
        """
        # Kept as a plain sequence: the traces may differ in length, and
        # np.array() on ragged nested sequences raises ValueError in NumPy >= 1.24.
        mean_frequencies = self.cell_data.get_mean_isi_frequencies()
        time_axes = self.cell_data.get_time_axes_mean_frequencies()

        for i, frequency in enumerate(mean_frequencies):
            time = time_axes[i]
            stimulus_value = self.stimulus_values[i]
            time_span = (time[0], time[-1])

            fig, ax = plt.subplots(1, 1)
            ax.plot(time, frequency, label="frequency")
            ax.plot(time_span, (self.f_zero_frequencies[i], self.f_zero_frequencies[i]), label="f_zero")
            ax.plot(time_span, (self.f_inf_frequencies[i], self.f_inf_frequencies[i]), '--', label="f_inf")
            ax.plot(time_span, (self.f_baseline_frequencies[i], self.f_baseline_frequencies[i]), label="f_base")
            ax.set_title(str(stimulus_value))
            ax.legend()

            if save_path is None:
                plt.show()
            else:
                plt.savefig(save_path + "/detections_contrast_{:.2f}.png".format(stimulus_value))
            plt.close(fig)
class FICurveModel(FICurve):
    """f-I curve measured by simulating a model with sinusoidal step stimuli.

    For each stimulus value the model is simulated `trials` times with a
    SinusoidalStepStimulus. The trial frequency traces are combined with
    hF.calculate_mean_of_frequency_traces, and f_inf, f_zero and f_baseline are
    detected in the resulting trace.
    """

    stim_duration = 0.5
    stim_start = 0.5
    # pre-stimulus period, stimulus, and a post-stimulus period as long as the pre-stimulus one
    total_simulation_time = stim_duration + 2 * stim_start

    def __init__(self, model, stimulus_values, eod_frequency, trials=5):
        self.eod_frequency = eod_frequency
        self.model = model
        self.trials = trials
        # spike times of every simulation, indexed [stimulus index, trial index]
        self.spiketimes_array = np.zeros((len(stimulus_values), trials), dtype=object)
        self.mean_frequency_traces = []
        self.mean_time_traces = []
        super().__init__(stimulus_values)

    def _trace_misses_stimulus(self, time):
        """Return True if `time` is empty or does not span the whole stimulus window."""
        return len(time) == 0 or min(time) > self.stim_start \
            or max(time) < self.stim_start + self.stim_duration

    def calculate_all_frequency_points(self):
        """Simulate every stimulus value and detect the f-I frequency points.

        All per-stimulus lists, including the mean traces used for plotting, are
        rebuilt, so a recomputation keeps them aligned with stimulus_values.
        Stimuli whose mean trace does not cover the stimulus window get 0 for
        all three frequencies.
        """
        sampling_interval = self.model.get_sampling_interval()
        self.f_inf_frequencies = []
        self.f_zero_frequencies = []
        self.f_baseline_frequencies = []
        self.mean_frequency_traces = []
        self.mean_time_traces = []

        for i, contrast in enumerate(self.stimulus_values):
            stimulus = SinusoidalStepStimulus(self.eod_frequency, contrast, self.stim_start, self.stim_duration)
            frequency_traces = []
            time_traces = []
            for j in range(self.trials):
                _, spiketimes = self.model.simulate_fast(stimulus, self.total_simulation_time)
                self.spiketimes_array[i, j] = spiketimes
                trial_time, trial_frequency = hF.calculate_time_and_frequency_trace(spiketimes, sampling_interval)
                frequency_traces.append(trial_frequency)
                time_traces.append(trial_time)

            time, frequency = hF.calculate_mean_of_frequency_traces(time_traces, frequency_traces, sampling_interval)
            self.mean_frequency_traces.append(frequency)
            self.mean_time_traces.append(time)

            if self._trace_misses_stimulus(time):
                print("Too few spikes to calculate f_inf, f_0 and f_base")
                self.f_inf_frequencies.append(0)
                self.f_zero_frequencies.append(0)
                self.f_baseline_frequencies.append(0)
                continue

            self.f_inf_frequencies.append(
                hF.detect_f_infinity_in_freq_trace(time, frequency, self.stim_start,
                                                   self.stim_duration, sampling_interval))
            self.f_zero_frequencies.append(
                hF.detect_f_zero_in_frequency_trace(time, frequency, self.stim_start, sampling_interval))
            self.f_baseline_frequencies.append(
                hF.detect_f_baseline_in_freq_trace(time, frequency, self.stim_start, sampling_interval))

    def plot_f_point_detections(self, save_path=None):
        """Plot each mean trace with the detection windows of f_base, f_inf and f_zero.

        Stimuli whose trace does not cover the stimulus window are skipped.
        Figures are shown, or saved as
        ``save_path + "/detections_contrast_{:.2f}.png"`` when save_path is given.
        """
        sampling_interval = self.model.get_sampling_interval()

        for i, contrast in enumerate(self.stimulus_values):
            time = self.mean_time_traces[i]
            frequency = self.mean_frequency_traces[i]
            if self._trace_misses_stimulus(time):
                continue

            fig, ax = plt.subplots(1, 1, figsize=(8, 8))
            ax.plot(time, frequency)

            start_idx, end_idx = hF.time_window_detect_f_baseline(time[0], self.stim_start, sampling_interval)
            ax.plot((time[start_idx], time[end_idx]), (self.f_baseline_frequencies[i], self.f_baseline_frequencies[i]),
                    label="f_base", color="deepskyblue")

            start_idx, end_idx = hF.time_window_detect_f_infinity(time[0], self.stim_start, self.stim_duration,
                                                                  sampling_interval)
            ax.plot((time[start_idx], time[end_idx]), (self.f_inf_frequencies[i], self.f_inf_frequencies[i]),
                    label="f_inf", color="limegreen")

            start_idx, end_idx = hF.time_window_detect_f_zero(time[0], self.stim_start, sampling_interval)
            ax.plot((time[start_idx], time[end_idx]), (self.f_zero_frequencies[i], self.f_zero_frequencies[i]),
                    label="f_zero", color="orange")

            ax.legend()
            if save_path is not None:
                plt.savefig(save_path + "/detections_contrast_{:.2f}.png".format(contrast))
            else:
                plt.show()
            plt.close(fig)
def get_fi_curve_class(data, stimulus_values, eod_freq=None, trials=5) -> FICurve:
    """Build the FICurve subclass that matches the type of `data`.

    Args:
        data: a CellData recording or a LifacNoiseModel.
        stimulus_values: stimulus values of the f-I curve.
        eod_freq: EOD frequency; required when `data` is a LifacNoiseModel.
        trials: simulated trials per stimulus value (LifacNoiseModel only).

    Returns:
        A FICurveCellData or FICurveModel instance.

    Raises:
        ValueError: if `data` is a model and eod_freq is None, or if `data`
            has an unsupported type.
    """
    if isinstance(data, CellData):
        return FICurveCellData(data, stimulus_values)
    if isinstance(data, LifacNoiseModel):
        if eod_freq is None:
            raise ValueError("The FiCurveModel needs the eod variable to work")
        return FICurveModel(data, stimulus_values, eod_freq, trials=trials)
    raise ValueError("Unknown type: Cannot find corresponding FICurve class. Data was type:" + str(type(data)))