restructure project
This commit is contained in:
219
experiments/AdaptionCurrent.py
Normal file
219
experiments/AdaptionCurrent.py
Normal file
@@ -0,0 +1,219 @@
|
||||
|
||||
from FiCurve import FICurve
|
||||
import matplotlib.pyplot as plt
|
||||
from scipy.optimize import curve_fit
|
||||
import os
|
||||
import numpy as np
|
||||
from my_util import functions as fu
|
||||
|
||||
|
||||
class Adaption:
|
||||
|
||||
def __init__(self, fi_curve: FICurve):
|
||||
|
||||
self.fi_curve = fi_curve
|
||||
|
||||
# [[a, tau_eff, c], [], [a, tau_eff, c], ...]
|
||||
self.exponential_fit_vars = []
|
||||
self.tau_real = []
|
||||
|
||||
self.fit_exponential()
|
||||
self.calculate_tau_from_tau_eff()
|
||||
|
||||
def fit_exponential(self, length_of_fit=0.1):
|
||||
time_axes, mean_frequencies = self.fi_curve.get_mean_time_and_freq_traces()
|
||||
f_baselines = self.fi_curve.get_f_baseline_frequencies()
|
||||
f_infinities = self.fi_curve.get_f_inf_frequencies()
|
||||
f_zeros = self.fi_curve.get_f_zero_frequencies()
|
||||
for i in range(len(mean_frequencies)):
|
||||
|
||||
if abs(f_zeros[i] - f_infinities[i]) < 20:
|
||||
self.exponential_fit_vars.append([])
|
||||
continue
|
||||
|
||||
start_idx = self.__find_start_idx_for_exponential_fit(time_axes[i], mean_frequencies[i],
|
||||
f_baselines[i], f_infinities[i], f_zeros[i])
|
||||
|
||||
if start_idx == -1:
|
||||
# print("start index negative")
|
||||
self.exponential_fit_vars.append([])
|
||||
continue
|
||||
|
||||
# shorten length of fit to stay in stimulus region if given length is too long
|
||||
sampling_interval = self.fi_curve.get_sampling_interval()
|
||||
used_length_of_fit = length_of_fit
|
||||
if (start_idx * sampling_interval) - self.fi_curve.get_delay() + length_of_fit > self.fi_curve.get_stimulus_end():
|
||||
print(start_idx * sampling_interval, "start - end", start_idx * sampling_interval + length_of_fit)
|
||||
print("Shortened length of fit to keep it in the stimulus region!")
|
||||
used_length_of_fit = self.fi_curve.get_stimulus_end() - (start_idx * sampling_interval)
|
||||
|
||||
|
||||
|
||||
end_idx = start_idx + int(used_length_of_fit/sampling_interval)
|
||||
y_values = mean_frequencies[i][start_idx:end_idx+1]
|
||||
x_values = time_axes[i][start_idx:end_idx+1]
|
||||
plt.title("f_zero {:.2f}, f_inf {:.2f}".format(f_zeros[i], f_infinities[i]))
|
||||
plt.plot(time_axes[i], mean_frequencies[i])
|
||||
plt.plot(x_values, y_values)
|
||||
plt.show()
|
||||
plt.close()
|
||||
|
||||
tau = self.__approximate_tau_for_exponential_fit(x_values, y_values, i)
|
||||
|
||||
# start the actual fit:
|
||||
try:
|
||||
p0 = (self.fi_curve.f_zero_frequencies[i], tau, self.fi_curve.f_inf_frequencies[i])
|
||||
popt, pcov = curve_fit(fu.exponential_function, x_values, y_values,
|
||||
p0=p0, maxfev=10000, bounds=([-np.inf, 0, -np.inf], [np.inf, np.inf, np.inf]))
|
||||
|
||||
# plt.plot(time_axes[i], mean_frequencies[i])
|
||||
# plt.plot(x_values, [fu.exponential_function(x, popt[0], popt[1], popt[2]) for x in x_values])
|
||||
# plt.show()
|
||||
# plt.close()
|
||||
|
||||
except RuntimeError:
|
||||
print("RuntimeError happened in fit_exponential.")
|
||||
self.exponential_fit_vars.append([])
|
||||
continue
|
||||
|
||||
# Obviously a bad fit - time constant, expected in range 3-10ms, has value over 1 second or is negative
|
||||
if abs(popt[1] > 1) or popt[1] < 0:
|
||||
print("detected an obviously bad fit")
|
||||
self.exponential_fit_vars.append([])
|
||||
else:
|
||||
self.exponential_fit_vars.append(popt)
|
||||
|
||||
def __approximate_tau_for_exponential_fit(self, x_values, y_values, mean_freq_idx):
|
||||
if self.fi_curve.f_inf_frequencies[mean_freq_idx] < self.fi_curve.f_baseline_frequencies[mean_freq_idx] * 0.95:
|
||||
test_val = [y > 0.65 * self.fi_curve.f_inf_frequencies[mean_freq_idx] for y in y_values]
|
||||
else:
|
||||
test_val = [y < 0.65 * self.fi_curve.f_zero_frequencies[mean_freq_idx] for y in y_values]
|
||||
|
||||
try:
|
||||
idx = test_val.index(True)
|
||||
if idx == 0:
|
||||
idx = 1
|
||||
tau = x_values[idx] - x_values[0]
|
||||
except ValueError:
|
||||
tau = x_values[-1] - x_values[0]
|
||||
|
||||
return tau
|
||||
|
||||
def __find_start_idx_for_exponential_fit(self, time, frequency, f_base, f_inf, f_zero):
|
||||
|
||||
# plt.plot(time, frequency)
|
||||
# plt.plot((time[0], time[-1]), (f_base, f_base), "-.")
|
||||
# plt.plot((time[0], time[-1]), (f_inf, f_inf), "-")
|
||||
# plt.plot((time[0], time[-1]), (f_zero, f_zero))
|
||||
|
||||
stimulus_start_idx = int((self.fi_curve.get_stimulus_start() - time[0]) / self.fi_curve.get_sampling_interval())
|
||||
|
||||
# plt.plot((time[stimulus_start_idx], ), (0, ), 'o')
|
||||
#
|
||||
# plt.show()
|
||||
# plt.close()
|
||||
|
||||
if f_inf > f_base * 1.1:
|
||||
# start setting starting variables for the fit
|
||||
# search for the start_index by searching for the max
|
||||
j = 0
|
||||
while True:
|
||||
try:
|
||||
if frequency[stimulus_start_idx + j] == f_zero:
|
||||
start_idx = stimulus_start_idx + j
|
||||
break
|
||||
except IndexError as e:
|
||||
return -1
|
||||
|
||||
j += 1
|
||||
|
||||
elif f_inf < f_base * 0.9:
|
||||
# start setting starting variables for the fit
|
||||
# search for start by finding the end of the minimum
|
||||
found_min = False
|
||||
j = int(0.05 / self.fi_curve.get_sampling_interval())
|
||||
nothing_to_fit = False
|
||||
while True:
|
||||
if not found_min:
|
||||
if frequency[stimulus_start_idx + j] == f_zero:
|
||||
found_min = True
|
||||
else:
|
||||
if frequency[stimulus_start_idx + j + 1] > f_zero:
|
||||
start_idx = stimulus_start_idx + j
|
||||
break
|
||||
if j > 0.1 / self.fi_curve.get_sampling_interval():
|
||||
# no rise in freq until to close to the end of the stimulus (to little place to fit)
|
||||
return -1
|
||||
j += 1
|
||||
|
||||
if nothing_to_fit:
|
||||
return -1
|
||||
else:
|
||||
# there is nothing to fit to:
|
||||
return -1
|
||||
|
||||
# plt.plot(time, frequency)
|
||||
# plt.plot(time[start_idx], frequency[start_idx], 'o')
|
||||
# plt.show()
|
||||
# plt.close()
|
||||
|
||||
return start_idx
|
||||
|
||||
def calculate_tau_from_tau_eff(self):
|
||||
tau_effs = []
|
||||
indices = []
|
||||
for i in range(len(self.exponential_fit_vars)):
|
||||
if len(self.exponential_fit_vars[i]) == 0:
|
||||
continue
|
||||
indices.append(i)
|
||||
tau_effs.append(self.exponential_fit_vars[i][1])
|
||||
|
||||
f_infinity_slope = self.fi_curve.get_f_inf_slope()
|
||||
approx_tau_reals = []
|
||||
for i, idx in enumerate(indices):
|
||||
factor = self.fi_curve.get_f_zero_fit_slope_at_stimulus_value(self.fi_curve.stimulus_values[idx]) / f_infinity_slope
|
||||
approx_tau_reals.append(tau_effs[i] * factor)
|
||||
|
||||
self.tau_real = np.median(approx_tau_reals)
|
||||
|
||||
def get_tau_real(self):
|
||||
return np.median(self.tau_real)
|
||||
|
||||
def get_tau_effs(self):
|
||||
return [ex_vars[1] for ex_vars in self.exponential_fit_vars if ex_vars != []]
|
||||
|
||||
def get_delta_a(self):
|
||||
return self.fi_curve.get_f_zero_fit_slope_at_straight() / self.fi_curve.get_f_inf_slope() / 100
|
||||
|
||||
def plot_exponential_fits(self, save_path: str = None, indices: list = None, delete_previous: bool = False):
|
||||
if delete_previous:
|
||||
for val in self.fi_curve.stimulus_values():
|
||||
|
||||
prev_path = save_path + "mean_freq_exp_fit_contrast:" + str(round(val, 3)) + ".png"
|
||||
|
||||
if os.path.exists(prev_path):
|
||||
os.remove(prev_path)
|
||||
|
||||
time_axes, mean_freqs = self.fi_curve.get_mean_time_and_freq_traces()
|
||||
for i in range(len(self.fi_curve.stimulus_values)):
|
||||
if indices is not None and i not in indices:
|
||||
continue
|
||||
|
||||
if self.exponential_fit_vars[i] == []:
|
||||
print("no fit vars for index {}!".format(i))
|
||||
continue
|
||||
|
||||
plt.plot(time_axes[i], mean_freqs[i])
|
||||
vars = self.exponential_fit_vars[i]
|
||||
fit_x = np.arange(0, 0.4, self.fi_curve.get_sampling_interval())
|
||||
plt.plot(fit_x, [fu.exponential_function(x, vars[0], vars[1], vars[2]) for x in fit_x])
|
||||
plt.ylim([0, max(self.fi_curve.f_zero_frequencies[i], self.fi_curve.f_baseline_frequencies[i])*1.1])
|
||||
plt.xlabel("Time [s]")
|
||||
plt.ylabel("Frequency [Hz]")
|
||||
|
||||
if save_path is None:
|
||||
plt.show()
|
||||
else:
|
||||
plt.savefig(save_path + "mean_freq_exp_fit_contrast:" + str(round(self.fi_curve.stimulus_values[i], 3)) + ".png")
|
||||
|
||||
plt.close()
|
||||
410
experiments/Baseline.py
Normal file
410
experiments/Baseline.py
Normal file
@@ -0,0 +1,410 @@
|
||||
|
||||
from parser.CellData import CellData
|
||||
from models.LIFACnoise import LifacNoiseModel
|
||||
from stimuli.SinusoidalStepStimulus import SinusoidalStepStimulus
|
||||
from my_util import helperFunctions as hF
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
import pickle
|
||||
from os.path import join, exists
|
||||
|
||||
|
||||
class Baseline:
|
||||
|
||||
def __init__(self):
|
||||
self.save_file_name = "baseline_values.pkl"
|
||||
self.baseline_frequency = -1
|
||||
self.serial_correlation = []
|
||||
self.vector_strength = -1
|
||||
self.coefficient_of_variation = -1
|
||||
self.burstiness = -1
|
||||
|
||||
def get_baseline_frequency(self):
|
||||
raise NotImplementedError("NOT YET OVERRIDDEN FROM ABSTRACT CLASS")
|
||||
|
||||
def get_serial_correlation(self, max_lag):
|
||||
raise NotImplementedError("NOT YET OVERRIDDEN FROM ABSTRACT CLASS")
|
||||
|
||||
def get_vector_strength(self):
|
||||
raise NotImplementedError("NOT YET OVERRIDDEN FROM ABSTRACT CLASS")
|
||||
|
||||
def get_coefficient_of_variation(self):
|
||||
raise NotImplementedError("NOT YET OVERRIDDEN FROM ABSTRACT CLASS")
|
||||
|
||||
def get_burstiness(self):
|
||||
raise NotImplementedError("NOT YET OVERRIDDEN FROM ABSTRACT CLASS")
|
||||
|
||||
def __get_burstiness__(self, eod_freq):
|
||||
isis = np.array(self.get_interspike_intervals())
|
||||
if len(isis) == 0:
|
||||
return 0
|
||||
|
||||
fullfilled = isis < (2.5 / eod_freq)
|
||||
perc_bursts = np.sum(fullfilled) / len(fullfilled)
|
||||
|
||||
return perc_bursts * (np.mean(isis)*1000)
|
||||
|
||||
def get_interspike_intervals(self):
|
||||
raise NotImplementedError("NOT YET OVERRIDDEN FROM ABSTRACT CLASS")
|
||||
|
||||
def get_spiketime_phases(self):
|
||||
raise NotImplementedError("NOT YET OVERRIDDEN FROM ABSTRACT CLASS")
|
||||
|
||||
def plot_baseline(self, save_path=None, time_length=0.2):
|
||||
raise NotImplementedError("NOT YET OVERRIDDEN FROM ABSTRACT CLASS")
|
||||
|
||||
@staticmethod
|
||||
def _get_baseline_frequency_given_data(spiketimes):
|
||||
base_freqs = []
|
||||
for st in spiketimes:
|
||||
base_freqs.append(hF.calculate_mean_isi_freq(st))
|
||||
|
||||
return np.median(base_freqs)
|
||||
|
||||
@staticmethod
|
||||
def _get_serial_correlation_given_data(max_lag, spikestimes):
|
||||
serial_cors = []
|
||||
|
||||
for st in spikestimes:
|
||||
sc = hF.calculate_serial_correlation(st, max_lag)
|
||||
serial_cors.append(sc)
|
||||
serial_cors = np.array(serial_cors)
|
||||
|
||||
res = np.mean(serial_cors, axis=0)
|
||||
return res
|
||||
|
||||
@staticmethod
|
||||
def _get_vector_strength_given_data(times, eods, spiketimes, sampling_interval):
|
||||
vs_per_trial = []
|
||||
for i in range(len(spiketimes)):
|
||||
vs = hF.calculate_vector_strength_from_spiketimes(times[i], eods[i], spiketimes[i], sampling_interval)
|
||||
vs_per_trial.append(vs)
|
||||
|
||||
return np.mean(vs_per_trial)
|
||||
|
||||
@staticmethod
|
||||
def _get_coefficient_of_variation_given_data(spiketimes):
|
||||
# CV (stddev of ISI divided by mean ISI (np.diff(spiketimes))
|
||||
cvs = []
|
||||
for st in spiketimes:
|
||||
st = np.array(st)
|
||||
cvs.append(hF.calculate_coefficient_of_variation(st))
|
||||
|
||||
return np.mean(cvs)
|
||||
|
||||
@staticmethod
|
||||
def _get_interspike_intervals_given_data(spiketimes):
|
||||
isis = []
|
||||
for st in spiketimes:
|
||||
st = np.array(st)
|
||||
isis.extend(np.diff(st))
|
||||
|
||||
return isis
|
||||
|
||||
@staticmethod
|
||||
def _plot_baseline_given_data(time, eod, v1, spiketimes, sampling_interval, eod_freq="", save_path=None, position=0.5, time_length=0.2):
|
||||
"""
|
||||
plots the stimulus / eod, together with the v1, spiketimes and frequency
|
||||
:return:
|
||||
"""
|
||||
length_data_points = int(time_length / sampling_interval)
|
||||
|
||||
start_idx = int(len(time) * position)
|
||||
start_idx = start_idx if start_idx >= 0 else 0
|
||||
end_idx = int(len(time) * position + length_data_points) + 1
|
||||
end_idx = end_idx if end_idx <= len(time) else len(time)
|
||||
|
||||
spiketimes = np.array(spiketimes)
|
||||
spiketimes_part = spiketimes[(spiketimes >= time[start_idx]) & (spiketimes < time[end_idx])]
|
||||
|
||||
fig, axes = plt.subplots(3, 1, sharex="col", figsize=(12, 8))
|
||||
fig.suptitle("Baseline middle part ({:.2f} seconds)".format(time_length))
|
||||
axes[0].plot(time[start_idx:end_idx], eod[start_idx:end_idx])
|
||||
axes[0].set_ylabel("Stimulus [mV] - Freq:" + eod_freq)
|
||||
|
||||
max_v1 = max(v1[start_idx:end_idx])
|
||||
axes[1].plot(time[start_idx:end_idx], v1[start_idx:end_idx])
|
||||
axes[1].plot(spiketimes_part, [max_v1 for _ in range(len(spiketimes_part))],
|
||||
'o', color='orange')
|
||||
axes[1].set_ylabel("V1-Trace [mV]")
|
||||
|
||||
t, f = hF.calculate_time_and_frequency_trace(spiketimes_part, sampling_interval)
|
||||
axes[2].plot(t, f)
|
||||
axes[2].set_ylabel("ISI-Frequency [Hz]")
|
||||
axes[2].set_xlabel("Time [s]")
|
||||
|
||||
if save_path is not None:
|
||||
plt.savefig(save_path + "baseline.png")
|
||||
else:
|
||||
plt.show()
|
||||
|
||||
plt.close()
|
||||
|
||||
@staticmethod
|
||||
def plot_isi_histogram_comparision(cell_isis, model_isis, save_path=None):
|
||||
cell_isis = np.array(cell_isis) * 1000
|
||||
model_isis = np.array(model_isis) * 1000
|
||||
maximum = max(max(cell_isis), max(model_isis))
|
||||
bins = np.arange(0, maximum * 1.01, 0.1)
|
||||
|
||||
plt.title('Baseline ISIs')
|
||||
plt.xlabel('ISI in ms')
|
||||
plt.ylabel('Count')
|
||||
plt.hist(cell_isis, bins=bins, label="cell", alpha=0.5, density=True)
|
||||
plt.hist(model_isis, bins=bins, label="model", alpha=0.5, density=True)
|
||||
plt.legend()
|
||||
if save_path is not None:
|
||||
plt.savefig(save_path + "isi-histogram_comparision.png")
|
||||
else:
|
||||
plt.show()
|
||||
|
||||
plt.close()
|
||||
|
||||
def plot_polar_vector_strength(self, save_path=None):
|
||||
phases = self.get_spiketime_phases()
|
||||
fig = plt.figure()
|
||||
ax = fig.add_subplot(111, polar=True)
|
||||
# r = np.arange(0, 1, 0.001)
|
||||
# theta = 2 * 2 * np.pi * r
|
||||
# line, = ax.plot(theta, r, color='#ee8d18', lw=3)
|
||||
bins = np.arange(0, np.pi * 2, 0.1)
|
||||
ax.hist(phases, bins=bins)
|
||||
|
||||
if save_path is not None:
|
||||
plt.savefig(save_path + "vector_strength_polar_plot.png")
|
||||
else:
|
||||
plt.show()
|
||||
|
||||
plt.close()
|
||||
|
||||
def plot_interspike_interval_histogram(self, save_path=None):
|
||||
|
||||
isi = np.array(self.get_interspike_intervals()) * 1000 # change unit to milliseconds
|
||||
if len(isi) == 0:
|
||||
print("NON SPIKES IN BASELINE OF CELL/MODEL")
|
||||
plt.title('Baseline ISIs - NO SPIKES!')
|
||||
plt.xlabel('ISI in ms')
|
||||
plt.ylabel('Count')
|
||||
plt.hist(isi, bins=np.arange(0, 1, 0.1))
|
||||
|
||||
if save_path is not None:
|
||||
plt.savefig(save_path + "isi-histogram.png")
|
||||
else:
|
||||
plt.show()
|
||||
|
||||
plt.close()
|
||||
return
|
||||
maximum = max(isi)
|
||||
bins = np.arange(0, maximum * 1.01, 0.1)
|
||||
|
||||
plt.title('Baseline ISIs')
|
||||
plt.xlabel('ISI in ms')
|
||||
plt.ylabel('Count')
|
||||
plt.hist(isi, bins=bins)
|
||||
|
||||
if save_path is not None:
|
||||
plt.savefig(save_path + "isi-histogram.png")
|
||||
else:
|
||||
plt.show()
|
||||
|
||||
plt.close()
|
||||
|
||||
def plot_serial_correlation(self, max_lag, save_path=None):
|
||||
plt.title("Baseline Serial correlation")
|
||||
plt.xlabel("Lag")
|
||||
plt.ylabel("Correlation")
|
||||
plt.ylim((-1, 1))
|
||||
plt.plot(np.arange(1, max_lag+1, 1), self.get_serial_correlation(max_lag))
|
||||
|
||||
if save_path is not None:
|
||||
plt.savefig(save_path + "serial_correlation.png")
|
||||
else:
|
||||
plt.show()
|
||||
|
||||
plt.close()
|
||||
|
||||
def save_values(self, save_directory):
|
||||
values = {}
|
||||
values["baseline_frequency"] = self.get_baseline_frequency()
|
||||
values["serial correlation"] = self.get_serial_correlation(max_lag=10)
|
||||
values["vector strength"] = self.get_vector_strength()
|
||||
values["coefficient of variation"] = self.get_coefficient_of_variation()
|
||||
values["burstiness"] = self.get_burstiness()
|
||||
|
||||
with open(join(save_directory, self.save_file_name), "wb") as file:
|
||||
pickle.dump(values, file)
|
||||
print("Baseline: Values saved!")
|
||||
|
||||
def load_values(self, save_directory):
|
||||
file_path = join(save_directory, self.save_file_name)
|
||||
if not exists(file_path):
|
||||
print("Baseline: No file to load")
|
||||
return False
|
||||
|
||||
file = open(file_path, "rb")
|
||||
values = pickle.load(file)
|
||||
self.baseline_frequency = values["baseline_frequency"]
|
||||
self.serial_correlation = values["serial correlation"]
|
||||
self.vector_strength = values["vector strength"]
|
||||
self.coefficient_of_variation = values["coefficient of variation"]
|
||||
self.burstiness = values["burstiness"]
|
||||
print("Baseline: Values loaded!")
|
||||
return True
|
||||
|
||||
|
||||
class BaselineCellData(Baseline):
|
||||
|
||||
def __init__(self, cell_data: CellData):
|
||||
super().__init__()
|
||||
self.data = cell_data
|
||||
|
||||
def get_baseline_frequency(self):
|
||||
if self.baseline_frequency == -1:
|
||||
spiketimes = self.data.get_base_spikes()
|
||||
self.baseline_frequency = self._get_baseline_frequency_given_data(spiketimes)
|
||||
|
||||
return self.baseline_frequency
|
||||
|
||||
def get_vector_strength(self):
|
||||
if self.vector_strength == -1:
|
||||
times = self.data.get_base_traces(self.data.TIME)
|
||||
eods = self.data.get_base_traces(self.data.EOD)
|
||||
spiketimes = self.data.get_base_spikes()
|
||||
sampling_interval = self.data.get_sampling_interval()
|
||||
self.vector_strength = self._get_vector_strength_given_data(times, eods, spiketimes, sampling_interval)
|
||||
return self.vector_strength
|
||||
|
||||
def get_serial_correlation(self, max_lag):
|
||||
if len(self.serial_correlation) < max_lag:
|
||||
self.serial_correlation = self._get_serial_correlation_given_data(max_lag, self.data.get_base_spikes())
|
||||
return self.serial_correlation[:max_lag]
|
||||
|
||||
def get_coefficient_of_variation(self):
|
||||
if self.coefficient_of_variation == -1:
|
||||
self.coefficient_of_variation = self._get_coefficient_of_variation_given_data(self.data.get_base_spikes())
|
||||
return self.coefficient_of_variation
|
||||
|
||||
def get_interspike_intervals(self):
|
||||
return self._get_interspike_intervals_given_data(self.data.get_base_spikes())
|
||||
|
||||
def get_spiketime_phases(self):
|
||||
times = self.data.get_base_traces(self.data.TIME)
|
||||
spiketimes = self.data.get_base_spikes()
|
||||
eods = self.data.get_base_traces(self.data.EOD)
|
||||
sampling_interval = self.data.get_sampling_interval()
|
||||
|
||||
phase_list = []
|
||||
for i in range(len(times)):
|
||||
spiketime_indices = np.array(np.around((np.array(spiketimes[i]) + times[i][0]) / sampling_interval), dtype=int)
|
||||
rel_spikes, eod_durs = hF.eods_around_spikes(times[i], eods[i], spiketime_indices)
|
||||
|
||||
phase_times = (rel_spikes / eod_durs) * 2 * np.pi
|
||||
phase_list.extend(phase_times)
|
||||
|
||||
return phase_list
|
||||
|
||||
def get_burstiness(self):
|
||||
if self.burstiness == -1:
|
||||
self.burstiness = self.__get_burstiness__(self.data.get_eod_frequency())
|
||||
return self.burstiness
|
||||
|
||||
def plot_baseline(self, save_path=None, position=0.5, time_length=0.2):
|
||||
# eod, v1, spiketimes, frequency
|
||||
|
||||
time = self.data.get_base_traces(self.data.TIME)[0]
|
||||
eod = self.data.get_base_traces(self.data.EOD)[0]
|
||||
v1_trace = self.data.get_base_traces(self.data.V1)[0]
|
||||
spiketimes = self.data.get_base_spikes()[0]
|
||||
|
||||
self._plot_baseline_given_data(time, eod, v1_trace, spiketimes,
|
||||
self.data.get_sampling_interval(), "{:.0f}".format(self.data.get_eod_frequency()), save_path, position, time_length)
|
||||
|
||||
|
||||
class BaselineModel(Baseline):
|
||||
|
||||
simulation_time = 30
|
||||
|
||||
def __init__(self, model: LifacNoiseModel, eod_frequency, trials=1):
|
||||
super().__init__()
|
||||
self.model = model
|
||||
self.eod_frequency = eod_frequency
|
||||
self.set_model_adaption_to_baseline()
|
||||
|
||||
self.stimulus = SinusoidalStepStimulus(eod_frequency, 0)
|
||||
self.eod = self.stimulus.as_array(0, self.simulation_time, model.get_sampling_interval())
|
||||
self.time = np.arange(0, self.simulation_time, model.get_sampling_interval())
|
||||
|
||||
self.v1_traces = []
|
||||
self.spiketimes = []
|
||||
for i in range(trials):
|
||||
v, st = model.simulate(self.stimulus, self.simulation_time)
|
||||
self.v1_traces.append(v)
|
||||
self.spiketimes.append(st)
|
||||
|
||||
def set_model_adaption_to_baseline(self):
|
||||
stimulus = SinusoidalStepStimulus(self.eod_frequency, 0, 0, 0)
|
||||
self.model.simulate(stimulus, 1)
|
||||
adaption = self.model.get_adaption_trace()
|
||||
self.model.set_variable("a_zero", adaption[-1])
|
||||
# print("Baseline: model a_zero set to", adaption[-1])
|
||||
|
||||
def get_baseline_frequency(self):
|
||||
if self.baseline_frequency == -1:
|
||||
self.baseline_frequency = self._get_baseline_frequency_given_data(self.spiketimes)
|
||||
return self.baseline_frequency
|
||||
|
||||
def get_vector_strength(self):
|
||||
if self.vector_strength == -1:
|
||||
times = [self.time] * len(self.spiketimes)
|
||||
eods = [self.eod] * len(self.spiketimes)
|
||||
sampling_interval = self.model.get_sampling_interval()
|
||||
self.vector_strength = self._get_vector_strength_given_data(times, eods, self.spiketimes, sampling_interval)
|
||||
|
||||
return self.vector_strength
|
||||
|
||||
def get_serial_correlation(self, max_lag):
|
||||
if len(self.serial_correlation) != max_lag:
|
||||
self.serial_correlation = self._get_serial_correlation_given_data(max_lag, self.spiketimes)
|
||||
return self.serial_correlation
|
||||
|
||||
def get_coefficient_of_variation(self):
|
||||
if self.coefficient_of_variation == -1:
|
||||
self.coefficient_of_variation = self._get_coefficient_of_variation_given_data(self.spiketimes)
|
||||
return self.coefficient_of_variation
|
||||
|
||||
def get_interspike_intervals(self):
|
||||
return self._get_interspike_intervals_given_data(self.spiketimes)
|
||||
|
||||
def get_burstiness(self):
|
||||
if self.burstiness == -1:
|
||||
self.burstiness = self.__get_burstiness__(self.eod_frequency)
|
||||
return self.burstiness
|
||||
|
||||
def get_spiketime_phases(self):
|
||||
sampling_interval = self.model.get_sampling_interval()
|
||||
|
||||
phase_list = []
|
||||
for i in range(len(self.spiketimes)):
|
||||
spiketime_indices = np.array(np.around((np.array(self.spiketimes[i]) + self.time[0]) / sampling_interval), dtype=int)
|
||||
rel_spikes, eod_durs = hF.eods_around_spikes(self.time, self.eod, spiketime_indices)
|
||||
|
||||
phase_times = (rel_spikes / eod_durs) * 2 * np.pi
|
||||
phase_list.extend(phase_times)
|
||||
|
||||
return phase_list
|
||||
|
||||
def plot_baseline(self, save_path=None, position=0.5, time_length=0.2):
|
||||
self._plot_baseline_given_data(self.time, self.eod, self.v1_traces[0], self.spiketimes[0],
|
||||
self.model.get_sampling_interval(), "{:.0f}".format(self.eod_frequency),
|
||||
save_path, position, time_length)
|
||||
|
||||
|
||||
def get_baseline_class(data, eod_freq=None, trials=1) -> Baseline:
|
||||
if isinstance(data, CellData):
|
||||
return BaselineCellData(data)
|
||||
if isinstance(data, LifacNoiseModel):
|
||||
if eod_freq is None:
|
||||
raise ValueError("The EOD frequency is needed for the BaselineModel Class.")
|
||||
return BaselineModel(data, eod_freq, trials=trials)
|
||||
|
||||
raise ValueError("Unknown type: Cannot find corresponding Baseline class. data was type:" + str(type(data)))
|
||||
533
experiments/FiCurve.py
Normal file
533
experiments/FiCurve.py
Normal file
@@ -0,0 +1,533 @@
|
||||
|
||||
from parser.CellData import CellData
|
||||
from models.LIFACnoise import LifacNoiseModel
|
||||
from stimuli.SinusoidalStepStimulus import SinusoidalStepStimulus
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
from warnings import warn
|
||||
from my_util import helperFunctions as hF, functions as fu
|
||||
from os.path import join, exists
|
||||
import pickle
|
||||
from sys import stderr
|
||||
|
||||
|
||||
class FICurve:
|
||||
|
||||
def __init__(self, stimulus_values, save_dir=None, recalculate=False):
|
||||
self.save_file_name = "fi_curve_values.pkl"
|
||||
self.stimulus_values = stimulus_values
|
||||
|
||||
self.indices_f_baseline = []
|
||||
self.f_baseline_frequencies = []
|
||||
self.indices_f_inf = []
|
||||
self.f_inf_frequencies = []
|
||||
self.indices_f_zero = []
|
||||
self.f_zero_frequencies = []
|
||||
|
||||
# increase, offset
|
||||
self.f_inf_fit = []
|
||||
# f_max, f_min, k, x_zero
|
||||
self.f_zero_fit = []
|
||||
|
||||
if save_dir is None:
|
||||
self.initialize()
|
||||
else:
|
||||
if recalculate:
|
||||
self.initialize()
|
||||
self.save_values(save_dir)
|
||||
else:
|
||||
if not self.load_values(save_dir):
|
||||
self.initialize()
|
||||
self.save_values(save_dir)
|
||||
|
||||
def initialize(self):
|
||||
self.calculate_all_frequency_points()
|
||||
self.f_inf_fit = hF.fit_clipped_line(self.stimulus_values, self.f_inf_frequencies)
|
||||
self.f_zero_fit = hF.fit_boltzmann(self.stimulus_values, self.f_zero_frequencies)
|
||||
|
||||
def calculate_all_frequency_points(self):
|
||||
raise NotImplementedError("NOT YET OVERRIDDEN FROM ABSTRACT CLASS")
|
||||
|
||||
def get_f_baseline_frequencies(self):
|
||||
return self.f_baseline_frequencies
|
||||
|
||||
def get_f_inf_frequencies(self):
|
||||
return self.f_inf_frequencies
|
||||
|
||||
def get_f_zero_frequencies(self):
|
||||
return self.f_zero_frequencies
|
||||
|
||||
def get_f_inf_slope(self):
|
||||
if len(self.f_inf_fit) > 0:
|
||||
return self.f_inf_fit[0]
|
||||
|
||||
def get_f_zero_fit_slope_at_straight(self):
|
||||
fit_vars = self.f_zero_fit
|
||||
return fu.full_boltzmann_straight_slope(fit_vars[0], fit_vars[1], fit_vars[2], fit_vars[3])
|
||||
|
||||
def get_f_zero_fit_slope_at_stimulus_value(self, stimulus_value):
|
||||
fit_vars = self.f_zero_fit
|
||||
return fu.derivative_full_boltzmann(stimulus_value, fit_vars[0], fit_vars[1], fit_vars[2], fit_vars[3])
|
||||
|
||||
def get_f_inf_frequency_at_stimulus_value(self, stimulus_value):
|
||||
return fu.clipped_line(stimulus_value, self.f_inf_fit[0], self.f_inf_fit[1])
|
||||
|
||||
def get_f_zero_and_f_inf_intersection(self):
|
||||
x_values = np.arange(min(self.stimulus_values), max(self.stimulus_values), 0.0001)
|
||||
fit_vars = self.f_zero_fit
|
||||
f_zero = fu.full_boltzmann(x_values, fit_vars[0], fit_vars[1], fit_vars[2], fit_vars[3])
|
||||
f_inf = fu.clipped_line(x_values, self.f_inf_fit[0], self.f_inf_fit[1])
|
||||
|
||||
intersection_indicies = np.argwhere(np.diff(np.sign(f_zero - f_inf))).flatten()
|
||||
# print("fi-curve calc intersection:", intersection_indicies, x_values[intersection_indicies])
|
||||
if len(intersection_indicies) > 1:
|
||||
f_baseline = np.median(self.f_baseline_frequencies)
|
||||
best_dist = np.inf
|
||||
best_idx = -1
|
||||
for idx in intersection_indicies:
|
||||
dist = abs(fu.clipped_line(x_values[idx], self.f_inf_fit[0], self.f_inf_fit[1]) - f_baseline)
|
||||
if dist < best_dist:
|
||||
best_dist = dist
|
||||
best_idx = idx
|
||||
|
||||
return x_values[best_idx]
|
||||
|
||||
elif len(intersection_indicies) == 0:
|
||||
raise ValueError("No intersection found!")
|
||||
else:
|
||||
return x_values[intersection_indicies[0]]
|
||||
|
||||
def get_f_zero_fit_slope_at_f_inf_fit_intersection(self):
|
||||
x = self.get_f_zero_and_f_inf_intersection()
|
||||
fit_vars = self.f_zero_fit
|
||||
return fu.derivative_full_boltzmann(x, fit_vars[0], fit_vars[1], fit_vars[2], fit_vars[3])
|
||||
|
||||
def get_mean_time_and_freq_traces(self):
|
||||
raise NotImplementedError("NOT YET OVERRIDDEN FROM ABSTRACT CLASS")
|
||||
|
||||
def get_time_and_freq_traces(self):
|
||||
raise NotImplementedError("NOT YET OVERRIDDEN FROM ABSTRACT CLASS")
|
||||
|
||||
def get_sampling_interval(self):
|
||||
raise NotImplementedError("NOT YET OVERRIDDEN FROM ABSTRACT CLASS")
|
||||
|
||||
def get_delay(self):
|
||||
raise NotImplementedError("NOT YET OVERRIDDEN FROM ABSTRACT CLASS")
|
||||
|
||||
def get_stimulus_start(self):
|
||||
raise NotImplementedError("NOT YET OVERRIDDEN FROM ABSTRACT CLASS")
|
||||
|
||||
def get_stimulus_end(self):
|
||||
raise NotImplementedError("NOT YET OVERRIDDEN FROM ABSTRACT CLASS")
|
||||
|
||||
def get_stimulus_duration(self):
|
||||
return self.get_stimulus_end() - self.get_stimulus_start()
|
||||
|
||||
def plot_mean_frequency_curves(self, save_path=None):
|
||||
time_traces, freq_traces = self.get_time_and_freq_traces()
|
||||
mean_times, mean_freqs = self.get_mean_time_and_freq_traces()
|
||||
for i, sv in enumerate(self.stimulus_values):
|
||||
|
||||
for j in range(len(time_traces[i])):
|
||||
plt.plot(time_traces[i][j], freq_traces[i][j], color="gray", alpha=0.8)
|
||||
|
||||
plt.plot(mean_times[i], mean_freqs[i], color="black")
|
||||
plt.xlabel("Time [s]")
|
||||
plt.ylabel("Frequency [Hz]")
|
||||
|
||||
|
||||
plt.title("Mean frequency at contrast {:.2f} ({:} trials)".format(sv, len(time_traces[i])))
|
||||
if save_path is None:
|
||||
plt.show()
|
||||
else:
|
||||
plt.savefig(save_path + "mean_frequency_contrast_{:.2f}.png".format(sv))
|
||||
plt.close()
|
||||
|
||||
def plot_fi_curve(self, save_path=None):
|
||||
min_x = min(self.stimulus_values)
|
||||
max_x = max(self.stimulus_values)
|
||||
step = (max_x - min_x) / 5000
|
||||
x_values = np.arange(min_x, max_x, step)
|
||||
|
||||
plt.plot(self.stimulus_values, self.f_baseline_frequencies, color='blue', label='f_base')
|
||||
|
||||
plt.plot(self.stimulus_values, self.f_inf_frequencies, 'o', color='green', label='f_inf')
|
||||
plt.plot(x_values, [fu.clipped_line(x, self.f_inf_fit[0], self.f_inf_fit[1]) for x in x_values],
|
||||
color='darkgreen', label='f_inf_fit')
|
||||
|
||||
plt.plot(self.stimulus_values, self.f_zero_frequencies, 'o', color='orange', label='f_zero')
|
||||
popt = self.f_zero_fit
|
||||
plt.plot(x_values, [fu.full_boltzmann(x, popt[0], popt[1], popt[2], popt[3]) for x in x_values],
|
||||
color='red', label='f_0_fit')
|
||||
|
||||
plt.legend()
|
||||
plt.ylabel("Frequency [Hz]")
|
||||
plt.xlabel("Stimulus value")
|
||||
|
||||
if save_path is None:
|
||||
plt.show()
|
||||
else:
|
||||
plt.savefig(save_path + "fi_curve.png")
|
||||
plt.close()
|
||||
|
||||
@staticmethod
|
||||
def plot_fi_curve_comparision(data_fi_curve, model_fi_curve, save_path=None):
|
||||
min_x = min(min(data_fi_curve.stimulus_values), min(model_fi_curve.stimulus_values))
|
||||
max_x = max(max(data_fi_curve.stimulus_values), max(model_fi_curve.stimulus_values))
|
||||
step = (max_x - min_x) / 5000
|
||||
x_values = np.arange(min_x, max_x+step, step)
|
||||
|
||||
fig, axes = plt.subplots(1, 3, sharex="all", sharey='all', figsize=(15, 6))
|
||||
# plot baseline
|
||||
data_origin = (data_fi_curve, model_fi_curve)
|
||||
f_base_color = ("blue", "deepskyblue")
|
||||
f_inf_color = ("green", "limegreen")
|
||||
f_zero_color = ("red", "orange")
|
||||
for i in range(2):
|
||||
|
||||
axes[i].plot(data_origin[i].stimulus_values, data_origin[i].get_f_baseline_frequencies(),
|
||||
color=f_base_color[i], label='f_base')
|
||||
|
||||
axes[i].plot(data_origin[i].stimulus_values, data_origin[i].get_f_inf_frequencies(),
|
||||
'o', color=f_inf_color[i], label='f_inf')
|
||||
y_values = [fu.clipped_line(x, data_origin[i].f_inf_fit[0], data_origin[i].f_inf_fit[1]) for x in x_values]
|
||||
axes[i].plot(x_values, y_values, color=f_inf_color[i], label='f_inf_fit')
|
||||
|
||||
axes[i].plot(data_origin[i].stimulus_values, data_origin[i].get_f_zero_frequencies(),
|
||||
'o', color=f_zero_color[i], label='f_zero')
|
||||
popt = data_origin[i].f_zero_fit
|
||||
axes[i].plot(x_values, [fu.full_boltzmann(x, popt[0], popt[1], popt[2], popt[3]) for x in x_values],
|
||||
color=f_zero_color[i], label='f_0_fit')
|
||||
|
||||
axes[i].set_xlabel("Stimulus value - contrast")
|
||||
axes[i].legend()
|
||||
|
||||
axes[0].set_title("cell")
|
||||
axes[0].set_ylabel("Frequency [Hz]")
|
||||
axes[1].set_title("model")
|
||||
|
||||
median_baseline = np.median(data_fi_curve.get_f_baseline_frequencies())
|
||||
axes[2].plot((min_x, max_x), (median_baseline, median_baseline), color=f_base_color[0], label="cell med base")
|
||||
axes[2].plot(model_fi_curve.stimulus_values, model_fi_curve.get_f_baseline_frequencies(),
|
||||
'o', color=f_base_color[1], label='model base')
|
||||
|
||||
y_values = [fu.clipped_line(x, data_fi_curve.f_inf_fit[0], data_fi_curve.f_inf_fit[1]) for x in x_values]
|
||||
axes[2].plot(x_values, y_values, color=f_inf_color[0], label='f_inf_fit cell')
|
||||
axes[2].plot(model_fi_curve.stimulus_values, model_fi_curve.get_f_inf_frequencies(),
|
||||
'o', color=f_inf_color[1], label='f_inf model')
|
||||
|
||||
popt = data_fi_curve.f_zero_fit
|
||||
axes[2].plot(x_values, [fu.full_boltzmann(x, popt[0], popt[1], popt[2], popt[3]) for x in x_values],
|
||||
color=f_zero_color[0], label='f_0_fit cell')
|
||||
axes[2].plot(model_fi_curve.stimulus_values, model_fi_curve.get_f_zero_frequencies(),
|
||||
'o', color=f_zero_color[1], label='f_zero model')
|
||||
axes[2].set_title("cell model comparision")
|
||||
axes[2].set_xlabel("Stimulus value - contrast")
|
||||
axes[2].legend()
|
||||
|
||||
if save_path is None:
|
||||
plt.show()
|
||||
else:
|
||||
plt.savefig(save_path + "fi_curve_comparision.png")
|
||||
plt.close()
|
||||
|
||||
def plot_f_point_detections(self, save_path=None):
|
||||
raise NotImplementedError("NOT YET OVERRIDDEN FROM ABSTRACT CLASS")
|
||||
|
||||
def save_values(self, save_directory):
|
||||
|
||||
values = {}
|
||||
values["stimulus_values"] = self.stimulus_values
|
||||
values["f_baseline_frequencies"] = self.f_baseline_frequencies
|
||||
values["f_inf_frequencies"] = self.f_inf_frequencies
|
||||
values["f_zero_frequencies"] = self.f_zero_frequencies
|
||||
values["f_inf_fit"] = self.f_inf_fit
|
||||
values["f_zero_fit"] = self.f_zero_fit
|
||||
|
||||
with open(join(save_directory, self.save_file_name), "wb") as file:
|
||||
pickle.dump(values, file)
|
||||
|
||||
print("Fi-Curve: Values saved!")
|
||||
|
||||
def load_values(self, save_directory):
|
||||
file_path = join(save_directory, self.save_file_name)
|
||||
if not exists(file_path):
|
||||
print("Fi-Curve: No file to load")
|
||||
return False
|
||||
|
||||
file = open(file_path, "rb")
|
||||
values = pickle.load(file)
|
||||
if set(values["stimulus_values"]) != set(self.stimulus_values):
|
||||
stderr.write("Fi-Curve:load_values() - Given stimulus values are different to the loaded ones!:\n "
|
||||
"given: {}\n loaded: {}\n".format(str(self.stimulus_values), str(values["stimulus_values"])))
|
||||
|
||||
self.stimulus_values = values["stimulus_values"]
|
||||
self.f_baseline_frequencies = values["f_baseline_frequencies"]
|
||||
self.f_inf_frequencies = values["f_inf_frequencies"]
|
||||
self.f_zero_frequencies = values["f_zero_frequencies"]
|
||||
self.f_inf_fit = values["f_inf_fit"]
|
||||
self.f_zero_fit = values["f_zero_fit"]
|
||||
print("Fi-Curve: Values loaded!")
|
||||
return True
|
||||
|
||||
|
||||
class FICurveCellData(FICurve):
|
||||
|
||||
def __init__(self, cell_data: CellData, stimulus_values, save_dir=None, recalculate=False):
|
||||
self.cell_data = cell_data
|
||||
super().__init__(stimulus_values, save_dir, recalculate)
|
||||
|
||||
def calculate_all_frequency_points(self):
|
||||
mean_frequencies = self.cell_data.get_mean_fi_curve_isi_frequencies()
|
||||
time_axes = self.cell_data.get_time_axes_fi_curve_mean_frequencies()
|
||||
stimulus_start = self.cell_data.get_stimulus_start()
|
||||
stimulus_duration = self.cell_data.get_stimulus_duration()
|
||||
sampling_interval = self.cell_data.get_sampling_interval()
|
||||
|
||||
if len(mean_frequencies) == 0:
|
||||
warn("FICurve:all_calculate_frequency_points(): mean_frequencies is empty.\n"
|
||||
"Was all_calculate_mean_isi_frequencies already called?")
|
||||
|
||||
for i in range(len(mean_frequencies)):
|
||||
|
||||
if time_axes[i][0] > self.cell_data.get_stimulus_start():
|
||||
raise ValueError("TODO: Deal with to strongly cut frequency traces in cell data! ")
|
||||
# self.f_zero_frequencies.append(-1)
|
||||
# self.f_baseline_frequencies.append(-1)
|
||||
# self.f_inf_frequencies.append(-1)
|
||||
# continue
|
||||
|
||||
f_zero, f_zero_idx = hF.detect_f_zero_in_frequency_trace(time_axes[i], mean_frequencies[i],
|
||||
stimulus_start, sampling_interval)
|
||||
self.f_zero_frequencies.append(f_zero)
|
||||
self.indices_f_zero.append(f_zero_idx)
|
||||
|
||||
f_baseline, f_base_idx = hF.detect_f_baseline_in_freq_trace(time_axes[i], mean_frequencies[i],
|
||||
stimulus_start, sampling_interval)
|
||||
self.f_baseline_frequencies.append(f_baseline)
|
||||
self.indices_f_baseline.append(f_base_idx)
|
||||
f_infinity, f_inf_idx = hF.detect_f_infinity_in_freq_trace(time_axes[i], mean_frequencies[i],
|
||||
stimulus_start, stimulus_duration, sampling_interval)
|
||||
self.f_inf_frequencies.append(f_infinity)
|
||||
self.indices_f_inf.append(f_inf_idx)
|
||||
|
||||
def get_mean_time_and_freq_traces(self):
|
||||
return self.cell_data.get_time_axes_fi_curve_mean_frequencies(), self.cell_data.get_mean_fi_curve_isi_frequencies()
|
||||
|
||||
def get_time_and_freq_traces(self):
|
||||
spiketimes = self.cell_data.get_fi_spiketimes()
|
||||
time_traces = []
|
||||
freq_traces = []
|
||||
for i in range(len(spiketimes)):
|
||||
trial_time_traces = []
|
||||
trial_freq_traces = []
|
||||
for j in range(len(spiketimes[i])):
|
||||
time, isi_freq = hF.calculate_time_and_frequency_trace(spiketimes[i][j], self.cell_data.get_sampling_interval())
|
||||
|
||||
trial_freq_traces.append(isi_freq)
|
||||
trial_time_traces.append(time)
|
||||
|
||||
time_traces.append(trial_time_traces)
|
||||
freq_traces.append(trial_freq_traces)
|
||||
|
||||
return time_traces, freq_traces
|
||||
|
||||
def get_sampling_interval(self):
|
||||
return self.cell_data.get_sampling_interval()
|
||||
|
||||
def get_delay(self):
|
||||
return self.cell_data.get_delay()
|
||||
|
||||
def get_stimulus_start(self):
|
||||
return self.cell_data.get_stimulus_start()
|
||||
|
||||
def get_stimulus_end(self):
|
||||
return self.cell_data.get_stimulus_end()
|
||||
|
||||
def get_f_zero_inverse_at_frequency(self, frequency):
|
||||
# UNUSED
|
||||
b_vars = self.f_zero_fit
|
||||
return fu.inverse_full_boltzmann(frequency, b_vars[0], b_vars[1], b_vars[2], b_vars[3])
|
||||
|
||||
def get_f_infinity_frequency_at_stimulus_value(self, stimulus_value):
|
||||
# UNUSED
|
||||
infty_vars = self.f_inf_fit
|
||||
return fu.clipped_line(stimulus_value, infty_vars[0], infty_vars[1])
|
||||
|
||||
def plot_f_point_detections(self, save_path=None):
|
||||
mean_frequencies = np.array(self.cell_data.get_mean_fi_curve_isi_frequencies())
|
||||
time_axes = self.cell_data.get_time_axes_fi_curve_mean_frequencies()
|
||||
sampling_interval = self.cell_data.get_sampling_interval()
|
||||
stim_start = self.cell_data.get_stimulus_start()
|
||||
stim_duration = self.cell_data.get_stimulus_duration()
|
||||
|
||||
for i, c in enumerate(self.stimulus_values):
|
||||
time = time_axes[i]
|
||||
frequency = mean_frequencies[i]
|
||||
|
||||
if len(time) == 0 or min(time) > stim_start \
|
||||
or max(time) < stim_start + stim_duration:
|
||||
continue
|
||||
fig, ax = plt.subplots(1, 1, figsize=(8, 8))
|
||||
ax.set_title("Stimulus value: {:.2f}".format(c))
|
||||
ax.plot(time, frequency)
|
||||
start_idx, end_idx = hF.time_window_detect_f_baseline(time[0], stim_start, sampling_interval)
|
||||
ax.plot((time[start_idx], time[end_idx]), (self.f_baseline_frequencies[i], self.f_baseline_frequencies[i]),
|
||||
label="f_base", color="deepskyblue")
|
||||
|
||||
start_idx, end_idx = hF.time_window_detect_f_infinity(time[0], stim_start, stim_duration,
|
||||
sampling_interval)
|
||||
ax.plot((time[start_idx], time[end_idx]), (self.f_inf_frequencies[i], self.f_inf_frequencies[i]),
|
||||
label="f_inf", color="limegreen")
|
||||
|
||||
start_idx, end_idx = hF.time_window_detect_f_zero(time[0], stim_start, sampling_interval)
|
||||
ax.plot((time[start_idx], time[end_idx]), (self.f_zero_frequencies[i], self.f_zero_frequencies[i]),
|
||||
label="f_zero", color="orange")
|
||||
|
||||
plt.legend()
|
||||
if save_path is not None:
|
||||
plt.savefig(save_path + "/detections_contrast_{:.2f}.png".format(c))
|
||||
else:
|
||||
plt.show()
|
||||
|
||||
plt.close()
|
||||
|
||||
|
||||
class FICurveModel(FICurve):
|
||||
stim_duration = 0.5
|
||||
stim_start = 0.5
|
||||
total_simulation_time = stim_duration + 2 * stim_start
|
||||
|
||||
def __init__(self, model, stimulus_values, eod_frequency, trials=5):
|
||||
self.eod_frequency = eod_frequency
|
||||
self.model = model
|
||||
self.trials = trials
|
||||
self.spiketimes_array = np.zeros((len(stimulus_values), trials), dtype=list)
|
||||
self.mean_frequency_traces = []
|
||||
self.mean_time_traces = []
|
||||
self.set_model_adaption_to_baseline()
|
||||
super().__init__(stimulus_values)
|
||||
|
||||
def set_model_adaption_to_baseline(self):
|
||||
stimulus = SinusoidalStepStimulus(self.eod_frequency, 0, 0, 0)
|
||||
self.model.simulate(stimulus, 1)
|
||||
adaption = self.model.get_adaption_trace()
|
||||
self.model.set_variable("a_zero", adaption[-1])
|
||||
# print("FiCurve: model a_zero set to", adaption[-1])
|
||||
|
||||
def calculate_all_frequency_points(self):
|
||||
|
||||
sampling_interval = self.model.get_sampling_interval()
|
||||
self.f_inf_frequencies = []
|
||||
self.f_zero_frequencies = []
|
||||
self.f_baseline_frequencies = []
|
||||
|
||||
for i, c in enumerate(self.stimulus_values):
|
||||
stimulus = SinusoidalStepStimulus(self.eod_frequency, c, self.stim_start, self.stim_duration)
|
||||
frequency_traces = []
|
||||
time_traces = []
|
||||
for j in range(self.trials):
|
||||
|
||||
_, spiketimes = self.model.simulate(stimulus, self.total_simulation_time)
|
||||
self.spiketimes_array[i, j] = spiketimes
|
||||
trial_time, trial_frequency = hF.calculate_time_and_frequency_trace(spiketimes, sampling_interval)
|
||||
frequency_traces.append(trial_frequency)
|
||||
time_traces.append(trial_time)
|
||||
|
||||
time, frequency = hF.calculate_mean_of_frequency_traces(time_traces, frequency_traces, sampling_interval)
|
||||
self.mean_frequency_traces.append(frequency)
|
||||
self.mean_time_traces.append(time)
|
||||
|
||||
if len(time) == 0 or min(time) > self.stim_start \
|
||||
or max(time) < self.stim_start + self.stim_duration:
|
||||
# print("Too few spikes to calculate f_inf, f_0 and f_base")
|
||||
self.f_inf_frequencies.append(0)
|
||||
self.f_zero_frequencies.append(0)
|
||||
self.f_baseline_frequencies.append(0)
|
||||
continue
|
||||
|
||||
f_inf, f_inf_idx = hF.detect_f_infinity_in_freq_trace(time, frequency, self.stim_start, self.stim_duration, sampling_interval)
|
||||
self.f_inf_frequencies.append(f_inf)
|
||||
self.indices_f_inf.append(f_inf_idx)
|
||||
|
||||
f_zero, f_zero_idx = hF.detect_f_zero_in_frequency_trace(time, frequency, self.stim_start, sampling_interval)
|
||||
self.f_zero_frequencies.append(f_zero)
|
||||
self.indices_f_zero.append(f_zero_idx)
|
||||
|
||||
f_baseline, f_base_idx = hF.detect_f_baseline_in_freq_trace(time, frequency, self.stim_start, sampling_interval)
|
||||
self.f_baseline_frequencies.append(f_baseline)
|
||||
self.indices_f_baseline.append(f_base_idx)
|
||||
|
||||
def get_mean_time_and_freq_traces(self):
|
||||
return self.mean_time_traces, self.mean_frequency_traces
|
||||
|
||||
def get_sampling_interval(self):
|
||||
return self.model.get_sampling_interval()
|
||||
|
||||
def get_delay(self):
|
||||
return 0
|
||||
|
||||
def get_stimulus_start(self):
|
||||
return self.stim_start
|
||||
|
||||
def get_stimulus_end(self):
|
||||
return self.stim_start + self.stim_duration
|
||||
|
||||
def get_time_and_freq_traces(self):
|
||||
time_traces = []
|
||||
freq_traces = []
|
||||
for v in range(len(self.stimulus_values)):
|
||||
times_for_value = []
|
||||
freqs_for_value = []
|
||||
|
||||
for s in self.spiketimes_array[v]:
|
||||
t, f = hF.calculate_time_and_frequency_trace(s, self.model.get_sampling_interval())
|
||||
times_for_value.append(t)
|
||||
freqs_for_value.append(f)
|
||||
|
||||
time_traces.append(times_for_value)
|
||||
freq_traces.append(freqs_for_value)
|
||||
return time_traces, freq_traces
|
||||
|
||||
def plot_f_point_detections(self, save_path=None):
|
||||
sampling_interval = self.model.get_sampling_interval()
|
||||
|
||||
for i, c in enumerate(self.stimulus_values):
|
||||
time = self.mean_time_traces[i]
|
||||
frequency = self.mean_frequency_traces[i]
|
||||
|
||||
if len(time) == 0 or min(time) > self.stim_start \
|
||||
or max(time) < self.stim_start + self.stim_duration:
|
||||
continue
|
||||
fig, ax = plt.subplots(1, 1, figsize=(8, 8))
|
||||
ax.plot(time, frequency)
|
||||
start_idx, end_idx = hF.time_window_detect_f_baseline(time[0], self.stim_start, sampling_interval)
|
||||
ax.plot((time[start_idx], time[end_idx]), (self.f_baseline_frequencies[i], self.f_baseline_frequencies[i]),
|
||||
label="f_base", color="deepskyblue")
|
||||
|
||||
start_idx, end_idx = hF.time_window_detect_f_infinity(time[0], self.stim_start, self.stim_duration, sampling_interval)
|
||||
ax.plot((time[start_idx], time[end_idx]), (self.f_inf_frequencies[i], self.f_inf_frequencies[i]),
|
||||
label="f_inf", color="limegreen")
|
||||
|
||||
start_idx, end_idx = hF.time_window_detect_f_zero(time[0], self.stim_start, sampling_interval)
|
||||
ax.plot((time[start_idx], time[end_idx]), (self.f_zero_frequencies[i], self.f_zero_frequencies[i]),
|
||||
label="f_zero", color="orange")
|
||||
|
||||
plt.legend()
|
||||
if save_path is not None:
|
||||
plt.savefig(save_path + "/detections_contrast_{:.2f}.png".format(c))
|
||||
else:
|
||||
plt.show()
|
||||
|
||||
plt.close()
|
||||
|
||||
|
||||
def get_fi_curve_class(data, stimulus_values, eod_freq=None, trials=5, save_dir=None, recalculate=False) -> FICurve:
|
||||
if isinstance(data, CellData):
|
||||
return FICurveCellData(data, stimulus_values, save_dir, recalculate)
|
||||
if isinstance(data, LifacNoiseModel):
|
||||
if eod_freq is None:
|
||||
raise ValueError("The FiCurveModel needs the eod variable to work")
|
||||
return FICurveModel(data, stimulus_values, eod_freq, trials=trials)
|
||||
|
||||
raise ValueError("Unknown type: Cannot find corresponding Baseline class. Data was type:" + str(type(data)))
|
||||
33
experiments/Sam.py
Normal file
33
experiments/Sam.py
Normal file
@@ -0,0 +1,33 @@
|
||||
|
||||
from parser.CellData import CellData
|
||||
from models.LIFACnoise import LifacNoiseModel
|
||||
|
||||
|
||||
class SamAnalysis:
|
||||
pass
|
||||
|
||||
|
||||
class SamAnalysisData(SamAnalysis):
|
||||
|
||||
def __init__(self, cell_data):
|
||||
self.cell_data = cell_data
|
||||
|
||||
self.mean_mod_freq_responses = []
|
||||
|
||||
|
||||
class SamAnalysisModel(SamAnalysis):
|
||||
|
||||
def __init__(self, model):
|
||||
pass
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
def get_sam_class(data) -> SamAnalysis:
|
||||
if isinstance(data, CellData):
|
||||
return SamAnalysisData(data)
|
||||
if isinstance(data, LifacNoiseModel):
|
||||
return SamAnalysisModel(data)
|
||||
|
||||
raise ValueError("Unknown type: Cannot find corresponding SamAnalysis class. data was type:" + str(type(data)))
|
||||
0
experiments/__init__.py
Normal file
0
experiments/__init__.py
Normal file
Reference in New Issue
Block a user