diff --git a/AdaptionCurrent.py b/AdaptionCurrent.py new file mode 100644 index 0000000..97e54d2 --- /dev/null +++ b/AdaptionCurrent.py @@ -0,0 +1,168 @@ + +from FiCurve import FICurve +from CellData import CellData +import matplotlib.pyplot as plt +from scipy.optimize import curve_fit +import os +import numpy as np +import functions as fu + + +class Adaption: + + def __init__(self, cell_data: CellData, fi_curve: FICurve = None): + self.cell_data = cell_data + if fi_curve is None: + self.fi_curve = FICurve(cell_data) + else: + self.fi_curve = fi_curve + + # [[a, tau_eff, c], [], [a, tau_eff, c], ...] + self.exponential_fit_vars = [] + self.tau_real = [] + + self.fit_exponential() + self.calculate_tau_from_tau_eff() + + def fit_exponential(self, length_of_fit=0.05): + mean_frequencies = self.cell_data.get_mean_isi_frequencies() + time_axes = self.cell_data.get_time_axes_mean_frequencies() + for i in range(len(mean_frequencies)): + start_idx = self.__find_start_idx_for_exponential_fit(i) + + if start_idx == -1: + self.exponential_fit_vars.append([]) + continue + + # shorten length of fit to stay in stimulus region if given length is too long + sampling_interval = self.cell_data.get_sampling_interval() + used_length_of_fit = length_of_fit + if (start_idx * sampling_interval) - self.cell_data.get_delay() + length_of_fit > self.cell_data.get_stimulus_end(): + print(start_idx * sampling_interval, "start - end", start_idx * sampling_interval + length_of_fit) + print("Shortened length of fit to keep it in the stimulus region!") + used_length_of_fit = self.cell_data.get_stimulus_end() - (start_idx * sampling_interval) + + end_idx = start_idx + int(used_length_of_fit/sampling_interval) + y_values = mean_frequencies[i][start_idx:end_idx+1] + x_values = time_axes[i][start_idx:end_idx+1] + + tau = self.__approximate_tau_for_exponential_fit(x_values, y_values, i) + + # start the actual fit: + try: + p0 = (self.fi_curve.f_zeros[i], tau, self.fi_curve.f_infinities[i]) + popt, pcov = curve_fit(fu.exponential_function, x_values, y_values, + p0=p0, maxfev=10000, bounds=([-np.inf, 0, -np.inf], [np.inf, np.inf, np.inf])) + except RuntimeError: + print("RuntimeError happened in fit_exponential.") + self.exponential_fit_vars.append([]) + continue + + # Obviously a bad fit - time constant, expected in range 3-10ms, has value over 1 second or is negative + if abs(popt[1] > 1) or popt[1] < 0: + self.exponential_fit_vars.append([]) + else: + self.exponential_fit_vars.append(popt) + + def __approximate_tau_for_exponential_fit(self, x_values, y_values, mean_freq_idx): + if self.fi_curve.f_infinities[mean_freq_idx] < self.fi_curve.f_baselines[mean_freq_idx] * 0.95: + test_val = [y > 0.65 * self.fi_curve.f_infinities[mean_freq_idx] for y in y_values] + else: + test_val = [y < 0.65 * self.fi_curve.f_zeros[mean_freq_idx] for y in y_values] + + try: + idx = test_val.index(True) + if idx == 0: + idx = 1 + tau = x_values[idx] - x_values[0] + except ValueError: + tau = x_values[-1] - x_values[0] + + return tau + + def __find_start_idx_for_exponential_fit(self, mean_freq_idx): + stimulus_start_idx = int((self.cell_data.get_delay() + self.cell_data.get_stimulus_start()) / self.cell_data.get_sampling_interval()) + if self.fi_curve.f_infinities[mean_freq_idx] > self.fi_curve.f_baselines[mean_freq_idx] * 1.1: + # start setting starting variables for the fit + # search for the start_index by searching for the max + j = 0 + while True: + try: + if self.cell_data.get_mean_isi_frequencies()[mean_freq_idx][stimulus_start_idx + j] == self.fi_curve.f_zeros[mean_freq_idx]: + start_idx = stimulus_start_idx + j + break + except IndexError as e: + return -1 + + j += 1 + + elif self.fi_curve.f_infinities[mean_freq_idx] < self.fi_curve.f_baselines[mean_freq_idx] * 0.9: + # start setting starting variables for the fit + # search for start by finding the end of the minimum + found_min = False + j = int(0.05 / self.cell_data.get_sampling_interval()) + nothing_to_fit = False + while True: + if not found_min: + if self.cell_data.get_mean_isi_frequencies()[mean_freq_idx][stimulus_start_idx + j] == self.fi_curve.f_zeros[mean_freq_idx]: + found_min = True + else: + if self.cell_data.get_mean_isi_frequencies()[mean_freq_idx][stimulus_start_idx + j + 1] > self.fi_curve.f_zeros[mean_freq_idx]: + start_idx = stimulus_start_idx + j + break + if j > 0.1 / self.cell_data.get_sampling_interval(): + # no rise in freq until to close to the end of the stimulus (to little place to fit) + return -1 + j += 1 + + if nothing_to_fit: + return -1 + else: + # there is nothing to fit to: + return -1 + + return start_idx + + def calculate_tau_from_tau_eff(self): + taus = [] + for i in range(len(self.exponential_fit_vars)): + if len(self.exponential_fit_vars[i]) == 0: + continue + tau_eff = self.exponential_fit_vars[i][1]*1000 # tau_eff in ms + # intensity = self.fi_curve.stimulus_value[i] + f_infinity_slope = self.fi_curve.get_f_infinity_slope() + fi_curve_slope = self.fi_curve.get_fi_curve_slope_of_straight() + + taus.append(tau_eff*(fi_curve_slope/f_infinity_slope)) + # print((fi_curve_slope/f_infinity_slope)) + # print(tau_eff*(fi_curve_slope/f_infinity_slope), "=", tau_eff, "*", (fi_curve_slope/f_infinity_slope)) + + self.tau_real = np.median(taus) + + def plot_exponential_fits(self, save_path: str = None, indices: list = None, delete_previous: bool = False): + if delete_previous: + for val in self.cell_data.get_fi_contrasts(): + + prev_path = save_path + "mean_freq_exp_fit_contrast:" + str(round(val, 3)) + ".png" + + if os.path.exists(prev_path): + os.remove(prev_path) + + for i in range(len(self.cell_data.get_fi_contrasts())): + if self.exponential_fit_vars[i] == []: + continue + + plt.plot(self.cell_data.get_time_axes_mean_frequencies()[i], self.cell_data.get_mean_isi_frequencies()[i]) + vars = self.exponential_fit_vars[i] + fit_x = np.arange(0, 0.4, self.cell_data.get_sampling_interval()) + plt.plot(fit_x, [fu.exponential_function(x, vars[0], vars[1], vars[2]) for x in fit_x]) + plt.ylim([0, max(self.fi_curve.f_zeros[i], self.fi_curve.f_baselines[i])*1.1]) + plt.xlabel("Time [s]") + plt.ylabel("Frequency [Hz]") + + if save_path is None: + plt.show() + else: + plt.savefig(save_path + "mean_freq_exp_fit_contrast:" + str(round(self.cell_data.get_fi_contrasts()[i], 3)) + ".png") + + plt.close() \ No newline at end of file diff --git a/CellData.py b/CellData.py new file mode 100644 index 0000000..1eb7661 --- /dev/null +++ b/CellData.py @@ -0,0 +1,156 @@ + +import DataParserFactory as dpf +from warnings import warn +from os import listdir +import helperFunctions as hf +import numpy as np + + +def icelldata_of_dir(base_path): + for item in sorted(listdir(base_path)): + item_path = base_path + item + + try: + yield CellData(item_path) + except TypeError as e: + warn_msg = str(e) + warn(warn_msg) + + +class CellData: + # Class to capture all the data of a single cell across all experiments (base rate, FI-curve, .?.) + # should be abstract from the way the data is saved in the background .dat vs .nix + + # traces list of lists with traces: [[time], [voltage (v1)], [EOD], [local eod], [stimulus]] + TIME = 0 + V1 = 1 + EOD = 2 + LOCAL_EOD = 3 + STIMULUS = 4 + + def __init__(self, data_path): + self.data_path = data_path + self.base_traces = None + # self.fi_traces = None + self.fi_intensities = None + self.fi_spiketimes = None + self.fi_trans_amplitudes = None + self.mean_isi_frequencies = None + self.time_axes = None + # self.metadata = None + self.parser = dpf.get_parser(data_path) + + self.sampling_interval = self.parser.get_sampling_interval() + self.recording_times = self.parser.get_recording_times() + + def get_data_path(self): + return self.data_path + + def get_base_traces(self, trace_type=None): + if self.base_traces is None: + self.base_traces = self.parser.get_baseline_traces() + + if trace_type is None: + return self.base_traces + else: + return self.base_traces[trace_type] + + def get_fi_traces(self): + raise NotImplementedError("CellData:get_fi_traces():\n" + + "Getting the Fi-Traces currently overflows the RAM and causes swapping! Reimplement if really needed!") + # if self.fi_traces is None: + # self.fi_traces = self.parser.get_fi_curve_traces() + # return self.fi_traces + + def get_fi_spiketimes(self): + self.__read_fi_spiketimes_info__() + return self.fi_spiketimes + + def get_fi_intensities(self): + self.__read_fi_spiketimes_info__() + return self.fi_intensities + + def get_fi_contrasts(self): + self.__read_fi_spiketimes_info__() + contrast = [] + for i in range(len(self.fi_intensities)): + contrast.append((self.fi_intensities[i] - self.fi_trans_amplitudes[i]) / self.fi_trans_amplitudes[i]) + + return contrast + + def get_mean_isi_frequencies(self): + if self.mean_isi_frequencies is None: + self.time_axes, self.mean_isi_frequencies = hf.all_calculate_mean_isi_frequencies(self.get_fi_spiketimes(), + self.get_time_start(), + self.get_sampling_interval()) + return self.mean_isi_frequencies + + def get_time_axes_mean_frequencies(self): + if self.time_axes is None: + self.time_axes, self.mean_isi_frequencies = hf.all_calculate_mean_isi_frequencies(self.get_fi_spiketimes(), + self.get_time_start(), + self.get_sampling_interval()) + return self.time_axes + + def get_base_frequency(self): + base_freqs = [] + for freq in self.get_mean_isi_frequencies(): + delay = self.get_delay() + sampling_interval = self.get_sampling_interval() + if delay < 0.1: + warn("FICurve:__calculate_f_baseline__(): Quite short delay at the start.") + + idx_start = int(0.025 / sampling_interval) + idx_end = int((delay - 0.025) / sampling_interval) + base_freqs.append(np.mean(freq[idx_start:idx_end])) + + return np.median(base_freqs) + + def get_sampling_interval(self) -> float: + return self.sampling_interval + + def get_recording_times(self) -> list: + return self.recording_times + + def get_time_start(self) -> float: + return self.recording_times[0] + + def get_delay(self) -> float: + return abs(self.recording_times[0]) + + def get_time_end(self) -> float: + return self.recording_times[2] + self.recording_times[3] + + def get_stimulus_start(self) -> float: + return self.recording_times[1] + + def get_stimulus_duration(self) -> float: + return self.recording_times[2] + + def get_stimulus_end(self) -> float: + return self.get_stimulus_start() + self.get_stimulus_duration() + + def get_after_stimulus_duration(self) -> float: + return self.recording_times[3] + + def __read_fi_spiketimes_info__(self): + if self.fi_spiketimes is None: + trans_amplitudes, intensities, spiketimes = self.parser.get_fi_curve_spiketimes() + + self.fi_intensities, self.fi_spiketimes, self.fi_trans_amplitudes = hf.merge_similar_intensities(intensities, spiketimes, trans_amplitudes) + + # def get_metadata(self): + # self.__read_metadata__() + # return self.metadata + # + # def get_metadata_item(self, item): + # self.__read_metadata__() + # if item in self.metadata.keys(): + # return self.metadata[item] + # else: + # raise KeyError("CellData:get_metadata_item: Item not found in metadata! - " + str(item)) + # + # def __read_metadata__(self): + # if self.metadata is None: + # # TODO!! + # pass diff --git a/DataParserFactory.py b/DataParserFactory.py new file mode 100644 index 0000000..3ae7cea --- /dev/null +++ b/DataParserFactory.py @@ -0,0 +1,237 @@ + +from os.path import isdir, exists +from warnings import warn +import pyrelacs.DataLoader as Dl + +UNKNOWN = -1 +DAT_FORMAT = 0 +NIX_FORMAT = 1 + + +class AbstractParser: + + def cell_get_metadata(self): + raise NotImplementedError("NOT YET OVERRIDDEN FROM ABSTRACT CLASS") + + def get_baseline_traces(self): + raise NotImplementedError("NOT YET OVERRIDDEN FROM ABSTRACT CLASS") + + def get_fi_curve_traces(self): + raise NotImplementedError("NOT YET OVERRIDDEN FROM ABSTRACT CLASS") + + def get_fi_curve_spiketimes(self): + raise NotImplementedError("NOT YET OVERRIDDEN FROM ABSTRACT CLASS") + + def get_sampling_interval(self): + raise NotImplementedError("NOT YET OVERRIDDEN FROM ABSTRACT CLASS") + + def get_recording_times(self): + raise NotImplementedError("NOT YET OVERRIDDEN FROM ABSTRACT CLASS") + + +class DatParser(AbstractParser): + + def __init__(self, dir_path): + self.base_path = dir_path + self.fi_file = self.base_path + "/fispikes1.dat" + self.stimuli_file = self.base_path + "/stimuli.dat" + self.__test_data_file_existence__() + + self.fi_recording_times = [] + self.sampling_interval = -1 + + def cell_get_metadata(self): + pass + + def get_sampling_interval(self): + if self.sampling_interval == -1: + self.__read_sampling_interval__() + + return self.sampling_interval + + def get_recording_times(self): + if len(self.fi_recording_times) == 0: + self.__read_fi_recording_times__() + return self.fi_recording_times + + def get_baseline_traces(self): + return self.__get_traces__("BaselineActivity") + + def get_fi_curve_traces(self): + return self.__get_traces__("FICurve") + + # TODO clean up/ rewrite + def get_fi_curve_spiketimes(self): + spiketimes = [] + pre_intensities = [] + pre_durations = [] + intensities = [] + trans_amplitudes = [] + pre_duration = -1 + index = -1 + skip = False + trans_amplitude = float('nan') + for metadata, key, data in Dl.iload(self.fi_file): + if len(metadata) != 0: + + metadata_index = 0 + + if '----- Control --------------------------------------------------------' in metadata[0].keys(): + metadata_index = 1 + pre_duration = float(metadata[0]["----- Pre-Intensities ------------------------------------------------"]["preduration"][:-2]) + trans_amplitude = float(metadata[0]["trans. amplitude"][:-2]) + if pre_duration == 0: + skip = False + else: + skip = True + continue + + if skip: + continue + + intensity = float(metadata[metadata_index]['intensity'][:-2]) + pre_intensity = float(metadata[metadata_index]['preintensity'][:-2]) + + intensities.append(intensity) + pre_durations.append(pre_duration) + pre_intensities.append(pre_intensity) + trans_amplitudes.append(trans_amplitude) + spiketimes.append([]) + index += 1 + + if skip: + continue + + if data.shape[1] != 1: + raise RuntimeError("DatParser:get_fi_curve_spiketimes():\n read data has more than one dimension!") + + spike_time_data = data[:, 0]/1000 + if len(spike_time_data) < 10: + continue + if spike_time_data[-1] < 1: + print("# ignoring spike-train that ends before one second.") + continue + + spiketimes[index].append(spike_time_data) + + # TODO add merging for similar intensities? hf.merge_similar_intensities() + trans_amplitudes + + return trans_amplitudes, intensities, spiketimes + + def __get_traces__(self, repro): + time_traces = [] + v1_traces = [] + eod_traces = [] + local_eod_traces = [] + stimulus_traces = [] + + nothing = True + + for info, key, time, x in Dl.iload_traces(self.base_path, repro=repro): + nothing = False + time_traces.append(time) + v1_traces.append(x[0]) + eod_traces.append(x[1]) + local_eod_traces.append(x[2]) + stimulus_traces.append(x[3]) + + traces = [time_traces, v1_traces, eod_traces, local_eod_traces, stimulus_traces] + + if nothing: + warn_msg = "pyrelacs: iload_traces found nothing for the " + str(repro) + " repro!" + warn(warn_msg) + + return traces + + def __read_fi_recording_times__(self): + + delays = [] + stim_duration = [] + pause = [] + + for metadata, key, data in Dl.iload(self.fi_file): + if len(metadata) != 0: + control_key = '----- Control --------------------------------------------------------' + if control_key in metadata[0].keys(): + delays.append(float(metadata[0][control_key]["delay"][:-2])/1000) + pause.append(float(metadata[0][control_key]["pause"][:-2])/1000) + stim_key = "----- Test-Intensities -----------------------------------------------" + stim_duration.append(float(metadata[0][stim_key]["duration"][:-2])/1000) + + for l in [delays, stim_duration, pause]: + if len(l) == 0: + raise RuntimeError("DatParser:__read_fi_recording_times__:\n" + + "Couldn't find any delay, stimulus duration and or pause in the metadata.\n" + + "In file:" + self.base_path) + elif len(set(l)) != 1: + raise RuntimeError("DatParser:__read_fi_recording_times__:\n" + + "Found multiple different delay, stimulus duration and or pause in the metadata.\n" + + "In file:" + self.base_path) + else: + self.fi_recording_times = [-delays[0], 0, stim_duration[0], pause[0] - delays[0]] + + def __read_sampling_interval__(self): + stop = False + sampling_intervals = [] + for metadata, key, data in Dl.iload(self.stimuli_file): + for md in metadata: + for i in range(4): + key = "sample interval" + str(i+1) + if key in md.keys(): + + sampling_intervals.append(float(md[key][:-2]) / 1000) + stop = True + else: + break + + if stop: + break + + if len(sampling_intervals) == 0: + raise RuntimeError("DatParser:__read_sampling_interval__:\n" + + "Sampling intervals not found in stimuli.dat this is not handled!\n" + + "with File:" + self.base_path) + + if len(set(sampling_intervals)) != 1: + raise RuntimeError("DatParser:__read_sampling_interval__:\n" + + "Sampling intervals not the same for all traces this is not handled!\n" + + "with File:" + self.base_path) + else: + self.sampling_interval = sampling_intervals[0] + + def __test_data_file_existence__(self): + if not exists(self.stimuli_file): + raise RuntimeError(self.stimuli_file + " file doesn't exist!") + if not exists(self.fi_file): + raise RuntimeError(self.fi_file + " file doesn't exist!") + + +# TODO #################################### +class NixParser(AbstractParser): + + def __init__(self, nix_file_path): + self.file_path = nix_file_path + warn("NIX PARSER: NOT YET IMPLEMENTED!") +# TODO #################################### + + +def get_parser(data_path: str) -> AbstractParser: + data_format = __test_for_format__(data_path) + + if data_format == DAT_FORMAT: + return DatParser(data_path) + elif data_format == NIX_FORMAT: + return NixParser(data_path) + elif UNKNOWN: + raise TypeError("DataParserFactory:get_parser(data_path):\nCannot determine type of data for:" + data_path) + + +def __test_for_format__(data_path): + if isdir(data_path): + if exists(data_path + "/fispikes1.dat"): + return DAT_FORMAT + + elif data_path.endswith(".nix"): + return NIX_FORMAT + else: + return UNKNOWN diff --git a/FiCurve.py b/FiCurve.py new file mode 100644 index 0000000..3903518 --- /dev/null +++ b/FiCurve.py @@ -0,0 +1,153 @@ + +from CellData import CellData +import numpy as np +from scipy.optimize import curve_fit +import matplotlib.pyplot as plt +from warnings import warn +import functions as fu + + +class FICurve: + + def __init__(self, cell_data: CellData, contrast: bool = True): + self.cell_data = cell_data + self.using_contrast = contrast + + if contrast: + self.stimulus_value = cell_data.get_fi_contrasts() + else: + self.stimulus_value = cell_data.get_fi_intensities() + + self.f_zeros = [] + self.f_infinities = [] + self.f_baselines = [] + + # f_max, f_min, k, x_zero + self.boltzmann_fit_vars = [] + # offset increase + self.f_infinity_fit = [] + + self.all_calculate_frequency_points() + self.fit_line() + self.fit_boltzmann() + + def all_calculate_frequency_points(self): + mean_frequencies = self.cell_data.get_mean_isi_frequencies() + if len(mean_frequencies) == 0: + warn("FICurve:all_calculate_frequency_points(): mean_frequencies is empty.\n" + "Was all_calculate_mean_isi_frequencies already called?") + + for freq in mean_frequencies: + self.f_zeros.append(self.__calculate_f_zero__(freq)) + self.f_baselines.append(self.__calculate_f_baseline__(freq)) + self.f_infinities.append(self.__calculate_f_infinity__(freq)) + + def fit_line(self): + popt, pcov = curve_fit(fu.clipped_line, self.stimulus_value, self.f_infinities) + self.f_infinity_fit = popt + + def fit_boltzmann(self): + max_f0 = float(max(self.f_zeros)) + min_f0 = float(min(self.f_zeros)) + mean_int = float(np.mean(self.stimulus_value)) + + total_increase = max_f0 - min_f0 + total_change_int = max(self.stimulus_value) - min(self.stimulus_value) + start_k = float((total_increase / total_change_int * 4) / max_f0) + + popt, pcov = curve_fit(fu.full_boltzmann, self.stimulus_value, self.f_zeros, + p0=(max_f0, min_f0, start_k, mean_int), + maxfev=10000, bounds=([0, 0, -np.inf, -np.inf], [3000, 3000, np.inf, np.inf])) + + self.boltzmann_fit_vars = popt + + def plot_fi_curve(self, savepath: str = None): + min_x = min(self.stimulus_value) + max_x = max(self.stimulus_value) + step = (max_x - min_x) / 5000 + x_values = np.arange(min_x, max_x, step) + + plt.plot(self.stimulus_value, self.f_baselines, color='blue', label='f_base') + + plt.plot(self.stimulus_value, self.f_infinities, 'o', color='lime', label='f_inf') + plt.plot(x_values, [fu.clipped_line(x, self.f_infinity_fit[0], self.f_infinity_fit[1]) for x in x_values], + color='darkgreen', label='f_inf_fit') + + plt.plot(self.stimulus_value, self.f_zeros, 'o', color='orange', label='f_zero') + popt = self.boltzmann_fit_vars + plt.plot(x_values, [fu.full_boltzmann(x, popt[0], popt[1], popt[2], popt[3]) for x in x_values], + color='red', label='f_0_fit') + + plt.legend() + plt.ylabel("Frequency [Hz]") + if self.using_contrast: + plt.xlabel("Stimulus contrast") + else: + plt.xlabel("Stimulus intensity [mv]") + if savepath is None: + plt.show() + else: + plt.savefig(savepath + "fi_curve.png") + plt.close() + + def __calculate_f_baseline__(self, frequency, buffer=0.025): + delay = self.cell_data.get_delay() + sampling_interval = self.cell_data.get_sampling_interval() + if delay < 0.1: + warn("FICurve:__calculate_f_baseline__(): Quite short delay at the start.") + + idx_start = int(buffer/sampling_interval) + idx_end = int((delay-buffer)/sampling_interval) + return np.mean(frequency[idx_start:idx_end]) + + def __calculate_f_zero__(self, frequency, length_of_mean=0.1, buffer=0.025): + stimulus_start = self.cell_data.get_delay() + self.cell_data.get_stimulus_start() + sampling_interval = self.cell_data.get_sampling_interval() + + start_idx = int((stimulus_start - buffer) / sampling_interval) + end_idx = int((stimulus_start + buffer*2) / sampling_interval) + + freq_before = frequency[start_idx-(int(length_of_mean/sampling_interval)):start_idx] + fb_mean = np.mean(freq_before) + fb_std = np.std(freq_before) + + peak_frequency = fb_mean + count = 0 + for i in range(start_idx + 1, end_idx): + if fb_mean-3*fb_std <= frequency[i] <= fb_mean+3*fb_std: + continue + + if abs(frequency[i] - fb_mean) > abs(peak_frequency - fb_mean): + peak_frequency = frequency[i] + count += 1 + + return peak_frequency + + def __calculate_f_infinity__(self, frequency, length=0.2, buffer=0.025): + stimulus_end_time = \ + self.cell_data.get_delay() + self.cell_data.get_stimulus_start() + self.cell_data.get_stimulus_duration() + + start_idx = int((stimulus_end_time - length - buffer) / self.cell_data.get_sampling_interval()) + end_idx = int((stimulus_end_time - buffer) / self.cell_data.get_sampling_interval()) + + return np.mean(frequency[start_idx:end_idx]) + + def get_f_zero_inverse_at_frequency(self, frequency): + b_vars = self.boltzmann_fit_vars + return fu.inverse_full_boltzmann(frequency, b_vars[0], b_vars[1], b_vars[2], b_vars[3]) + + def get_f_infinity_frequency_at_stimulus_value(self, stimulus_value): + infty_vars = self.f_infinity_fit + return fu.clipped_line(stimulus_value, infty_vars[0], infty_vars[1]) + + + def get_f_infinity_slope(self): + return self.f_infinity_fit[1] + + def get_fi_curve_slope_at(self, stimulus_value): + fit_vars = self.boltzmann_fit_vars + return fu.derivative_full_boltzmann(stimulus_value, fit_vars[0], fit_vars[1], fit_vars[2], fit_vars[3]) + + def get_fi_curve_slope_of_straight(self): + fit_vars = self.boltzmann_fit_vars + return fu.full_boltzmann_straight_slope(fit_vars[0], fit_vars[1], fit_vars[2], fit_vars[3]) diff --git a/functionalityTests.py b/functionalityTests.py new file mode 100644 index 0000000..25b14fd --- /dev/null +++ b/functionalityTests.py @@ -0,0 +1,44 @@ + +from models.LIFAC import LIFACModel +from stimuli.StepStimulus import StepStimulus +import numpy as np +import matplotlib.pyplot as plt +import functions as fu + + +def test_lifac(): + model = LIFACModel() + stimulus = StepStimulus(0.5, 1, 15) + + for step_size in [0.001, 0.1]: + model.set_variable("step_size", step_size) + + v, spiketimes = model(stimulus, 2) + + plt.plot(np.arange(0, 2, step_size/1000), v, label="step_size:" + str(step_size)) + + plt.xlabel("time in seconds") + plt.ylabel("Voltage") + plt.title("Voltage in the LIFAC-Model with different step sizes") + plt.show() + plt.close() + + +def test_plot_inverses(ficurve): + var = ficurve.boltzmann_fit_vars + + fig, ax1 = plt.subplots(1, 1, figsize=(4.5, 4.5), tight_layout=True) + start = min(ficurve.stimulus_value) + end = max(ficurve.stimulus_value) + x_values = np.arange(start, end, (end-start)/5000) + ax1.plot(x_values, [fu.full_boltzmann(x, var[0], var[1], var[2], var[3]) for x in x_values], label="fit") + ax1.set_ylabel('freq') + ax1.set_xlabel('stimulus') + + start = var[1] + end = var[0] + x_values = np.arange(start, end, (end - start) / 50) + ax1.plot([fu.inverse_full_boltzmann(x, var[0], var[1], var[2], var[3]) for x in x_values], x_values, + '.', c="red", label='inverse') + plt.legend() + plt.show() \ No newline at end of file diff --git a/functions.py b/functions.py new file mode 100644 index 0000000..624ff2f --- /dev/null +++ b/functions.py @@ -0,0 +1,40 @@ + +import numpy as np + + +def exponential_function(x, a, b, c): + return (a-c)*np.exp(-x/b)+c + + +def upper_boltzmann(x, f_max, k, x_zero): + return f_max * np.clip((2 / (1+np.power(np.e, -k*(x - x_zero)))) - 1, 0, None) + + +def full_boltzmann(x, f_max, f_min, k, x_zero): + return (f_max-f_min) * (1 / (1 + np.power(np.e, -k * (x - x_zero)))) + f_min + + +def full_boltzmann_straight_slope(f_max, f_min, k, x_zero=0): + return (f_max-f_min)*k*1/2 + + +def derivative_full_boltzmann(x, f_max, f_min, k, x_zero): + return (f_max - f_min) * k * np.power(np.e, -k * (x - x_zero)) / (1 + np.power(np.e, -k * (x - x_zero))**2) + + +def inverse_full_boltzmann(x, f_max, f_min, k, x_zero): + if x < f_min or x > f_max: + raise ValueError("Value undefined in inverse_full_boltzmann") + + return -(np.log((f_max-f_min) / (x - f_min) - 1) / k) + x_zero + + +def clipped_line(x, a, b): + return np.clip(a+b*x, 0, None) + + +def inverse_clipped_line(x, a, b): + if clipped_line(x, a, b) == 0: + raise ValueError("Value undefined in inverse_clipped_line.") + + return (x-a)/b diff --git a/generalTests.py b/generalTests.py new file mode 100644 index 0000000..50900fd --- /dev/null +++ b/generalTests.py @@ -0,0 +1,54 @@ +import numpy as np +import matplotlib.pyplot as plt +from models.LeakyIntegrateFireModel import LIFModel + + +# def calculate_step(current_v, tau, i_b, step_size=0.01): + # return current_v + (step_size * (-current_v + mem_res * i_b)) / tau + +def function_e(x): + return (0-15) * np.e**(-x/1) + 15 + +# x_values = np.arange(0, 5, 0.01) +# plt.plot(x_values, [function_e(x) for x in x_values]) +# plt.show() + + +# def function_f(i_base, tau=1, threshold=10, reset=0): +# return -1/(tau*np.log((threshold-i_base)/(reset - i_base))) +# +# x_values = np.arange(0, 20, 0.001) +# plt.plot(x_values, [function_f(x) for x in x_values]) +# plt.show() + + +# LIF test: +# Rm = 100 MOhm, Cm = 200pF +step_size = 0.01 # ms +mem_res = 100*1000000 +tau = 1 +base_freq = 30 +v_threshold = 10 +base_input = -(- v_threshold / (np.e**(-1/(base_freq*tau))) + 1) / mem_res + +stim1 = int(1000/step_size) * [base_input] + +stimulus = [] +stimulus.extend(stim1) + +lif = LIFModel(mem_res, tau, 0, 0, stimulus, 10) + +voltage, spikes_b = lif.calculate_response() +y_spikes = [] +x_spikes = [] +for i in range(len(spikes_b)): + if spikes_b[i]: + y_spikes.append(10.5) + x_spikes.append(i*step_size) + +time = np.arange(0, 1000, step_size) +plt.plot(time, voltage) +plt.plot(x_spikes, y_spikes, 'o') +plt.show() +plt.close() + diff --git a/helperFunctions.py b/helperFunctions.py new file mode 100644 index 0000000..e17c063 --- /dev/null +++ b/helperFunctions.py @@ -0,0 +1,214 @@ +import os +import pyrelacs.DataLoader as dl +import numpy as np +import matplotlib.pyplot as plt +from warnings import warn + +def get_subfolder_paths(basepath): + subfolders = [] + for content in os.listdir(basepath): + content_path = basepath + content + if os.path.isdir(content_path): + subfolders.append(content_path) + + return sorted(subfolders) + + +def get_traces(directory, trace_type, repro): + # trace_type = 1: Voltage p-unit + # trace_type = 2: EOD + # trace_type = 3: local EOD ~(EOD + stimulus) + # trace_type = 4: Stimulus + + load_iter = dl.iload_traces(directory, repro=repro) + + time_traces = [] + value_traces = [] + + nothing = True + + for info, key, time, x in load_iter: + nothing = False + time_traces.append(time) + value_traces.append(x[trace_type-1]) + + if nothing: + print("iload_traces found nothing for the BaselineActivity repro!") + + return time_traces, value_traces + + +def get_all_traces(directory, repro): + load_iter = dl.iload_traces(directory, repro=repro) + + time_traces = [] + v1_traces = [] + eod_traces = [] + local_eod_traces = [] + stimulus_traces = [] + + nothing = True + + for info, key, time, x in load_iter: + nothing = False + time_traces.append(time) + v1_traces.append(x[0]) + eod_traces.append(x[1]) + local_eod_traces.append(x[2]) + stimulus_traces.append(x[3]) + print(info) + + traces = [v1_traces, eod_traces, local_eod_traces, stimulus_traces] + + if nothing: + print("iload_traces found nothing for the BaselineActivity repro!") + + return time_traces, traces + + +def merge_similar_intensities(intensities, spiketimes, trans_amplitudes): + i = 0 + + diffs = np.diff(sorted(intensities)) + margin = np.mean(diffs) * 0.6666 + + while True: + if i >= len(intensities): + break + intensities, spiketimes, trans_amplitudes = merge_intensities_similar_to_index(intensities, spiketimes, trans_amplitudes, i, margin) + i += 1 + + # Sort the lists so that intensities are increasing + x = [list(x) for x in zip(*sorted(zip(intensities, spiketimes), key=lambda pair: pair[0]))] + intensities = x[0] + spiketimes = x[1] + + return intensities, spiketimes, trans_amplitudes + + +def merge_intensities_similar_to_index(intensities, spiketimes, trans_amplitudes, index, margin): + intensity = intensities[index] + + indices_to_merge = [] + for i in range(index+1, len(intensities)): + if np.abs(intensities[i]-intensity) < margin: + indices_to_merge.append(i) + + if len(indices_to_merge) != 0: + indices_to_merge.reverse() + + trans_amplitude_values = [trans_amplitudes[k] for k in indices_to_merge] + + all_the_same = True + for j in range(1, len(trans_amplitude_values)): + if not trans_amplitude_values[0] == trans_amplitude_values[j]: + all_the_same = False + break + + if all_the_same: + for idx in indices_to_merge: + del trans_amplitudes[idx] + else: + raise RuntimeError("Trans_amplitudes not the same....") + for idx in indices_to_merge: + spiketimes[index].extend(spiketimes[idx]) + del spiketimes[idx] + del intensities[idx] + + return intensities, spiketimes, trans_amplitudes + + +def all_calculate_mean_isi_frequencies(spiketimes, time_start, sampling_interval): + times = [] + mean_frequencies = [] + + for i in range(len(spiketimes)): + trial_times = [] + trial_means = [] + for j in range(len(spiketimes[i])): + time, isi_freq = calculate_isi_frequency(spiketimes[i][j], time_start, sampling_interval) + trial_means.append(isi_freq) + trial_times.append(time) + + time, mean_freq = calculate_mean_frequency(trial_times, trial_means) + times.append(time) + mean_frequencies.append(mean_freq) + + return times, mean_frequencies + + +def calculate_isi_frequency(spiketimes, time_start, sampling_interval): + first_isi = spiketimes[0] - time_start + isis = [first_isi] + isis.extend(np.diff(spiketimes)) + time = np.arange(time_start, spiketimes[-1], sampling_interval) + + full_frequency = [] + i = 0 + for isi in isis: + if isi == 0: + warn("An ISI was zero in FiCurve:__calculate_mean_isi_frequency__()") + continue + freq = 1 / isi + frequency_step = int(round(isi * (1 / sampling_interval))) * [freq] + full_frequency.extend(frequency_step) + i += 1 + if len(full_frequency) != len(time): + if abs(len(full_frequency) - len(time)) == 1: + warn("FiCurve:__calculate_mean_isi_frequency__():\nFrequency and time were one of in length!") + if len(full_frequency) < len(time): + time = time[:len(full_frequency)] + else: + full_frequency = full_frequency[:len(time)] + else: + print("ERROR PRINT:") + print("freq:", len(full_frequency), "time:", len(time), "diff:", len(full_frequency) - len(time)) + raise RuntimeError("FiCurve:__calculate_mean_isi_frequency__():\n" + "Frequency and time are not the same length!") + + return time, full_frequency + + +def calculate_mean_frequency(trial_times, trial_freqs): + lengths = [len(t) for t in trial_times] + shortest = min(lengths) + + time = trial_times[0][0:shortest] + shortend_freqs = [freq[0:shortest] for freq in trial_freqs] + mean_freq = [sum(e) / len(e) for e in zip(*shortend_freqs)] + + return time, mean_freq + + +def crappy_smoothing(signal:list, window_size:int = 5) -> list: + smoothed = [] + + for i in range(len(signal)): + k = window_size + if i < window_size: + k = i + j = window_size + if i + j > len(signal): + j = len(signal) - i + + smoothed.append(np.mean(signal[i-k:i+j])) + + return smoothed + + +def plot_frequency_curve(cell_data, save_path: str = None, indices: list = None): + contrast = cell_data.get_fi_contrasts() + time_axes = cell_data.get_time_axes_mean_frequencies() + mean_freqs = cell_data.get_mean_isi_frequencies() + + if indices is None: + indices = np.arange(len(contrast)) + + for i in indices: + plt.plot(time_axes[i], mean_freqs[i], label=str(round(contrast[i], 2))) + + if save_path is None: + plt.show() + else: + plt.savefig(save_path + "mean_frequency_curves.png") + plt.close() \ No newline at end of file diff --git a/introduction/introductionBaseline.py b/introduction/introductionBaseline.py new file mode 100644 index 0000000..b28a485 --- /dev/null +++ b/introduction/introductionBaseline.py @@ -0,0 +1,283 @@ + +import pyrelacs.DataLoader as dl +import numpy as np +import matplotlib.pyplot as plt +from IPython import embed +import os +import helperFunctions as hf +from thunderfish.eventdetection import detect_peaks + + +SAVEPATH = "" + + +def get_savepath(): + global SAVEPATH + return SAVEPATH + + +def set_savepath(new_path): + global SAVEPATH + SAVEPATH = new_path + + +def main(): + for folder in hf.get_subfolder_paths("data/"): + filepath = folder + "/basespikes1.dat" + set_savepath("figures/" + folder.split('/')[1] + "/") + + print("Folder:", folder) + + if not os.path.exists(get_savepath()): + os.makedirs(get_savepath()) + + spiketimes = [] + + ran = False + for metadata, key, data in dl.iload(filepath): + ran = True + spikes = data[:, 0] + spiketimes.append(spikes) # save for calculation of vector strength + metadata = metadata[0] + #print(metadata) + # print('firing frequency1:', metadata['firing frequency1']) + # print(mean_firing_rate(spikes)) + + # print('Coefficient of Variation (CV):', metadata['CV1']) + # print(calculate_coefficient_of_variation(spikes)) + + if not ran: + print("------------ DIDN'T RUN") + + isi_histogram(spiketimes) + + times, eods = hf.get_traces(folder, 2, 'BaselineActivity') + times, v1s = hf.get_traces(folder, 1, 'BaselineActivity') + + vs = calculate_vector_strength(times, eods, spiketimes, v1s) + + # print("Calculated vector strength:", vs) + + +def mean_firing_rate(spiketimes): + # mean firing rate (number of spikes per time) + return len(spiketimes)/spiketimes[-1]*1000 + + +def calculate_coefficient_of_variation(spiketimes): + # CV (stddev of ISI divided by mean ISI (np.diff(spiketimes)) + isi = np.diff(spiketimes) + std = np.std(isi) + mean = np.mean(isi) + + return std/mean + + +def isi_histogram(spiketimes): + # ISI histogram (play around with binsize! < 1ms) + + isi = [] + for spike_list in spiketimes: + isi.extend(np.diff(spike_list)) + maximum = max(isi) + bins = np.arange(0, maximum*1.01, 0.1) + + plt.title('Phase locking of ISI without stimulus') + plt.xlabel('ISI in ms') + plt.ylabel('Count') + plt.hist(isi, bins=bins) + plt.savefig(get_savepath() + 'phase_locking_without_stimulus.png') + plt.close() + + +def calculate_vector_strength(times, eods, spiketimes, v1s): + # Vectorstaerke (use EOD frequency from header (metadata)) VS > 0.8 + # dl.iload_traces(repro='BaselineActivity') + + relative_spike_times = [] + eod_durations = [] + + if len(times) == 0: + print("-----LENGTH OF TIMES = 0") + + for recording in range(len(times)): + + rel_spikes, eod_durs = eods_around_spikes(times[recording], eods[recording], spiketimes[recording]) + relative_spike_times.extend(rel_spikes) + eod_durations.extend(eod_durs) + + vs = __vector_strength__(rel_spikes, eod_durs) + phases = calculate_phases(rel_spikes, eod_durs) + plot_polar(phases, "test_phase_locking_" + str(recording) + "_with_vs:" + str(round(vs, 3)) + ".png") + + print("VS of recording", recording, ":", vs) + + plot_phaselocking_testfigures(times[recording], eods[recording], spiketimes[recording], v1s[recording]) + + return __vector_strength__(relative_spike_times, eod_durations) + + +def eods_around_spikes(time, eod, spiketimes): + eod_durations = [] + relative_spike_times = [] + + for spike in spiketimes: + index = spike * 20 # time in s given timestamp of spike in ms - recorded at 20kHz -> timestamp/1000*20000 = idx + + if index != np.round(index): + print("INDEX NOT AN INTEGER in eods_around_spikes! index:", index) + continue + index = int(index) + + start_time, end_time = search_eod_start_and_end_times(time, eod, index) + + eod_durations.append(end_time-start_time) + relative_spike_times.append(spike/1000 - start_time) + + return relative_spike_times, eod_durations + + +def search_eod_start_and_end_times(time, eod, index): + # TODO might break if a spike is in the cut off first or last eod! + + # search start_time: + previous = index + working_idx = index-1 + while True: + if eod[working_idx] < 0 < eod[previous]: + first_value = eod[working_idx] + second_value = eod[previous] + + dif = second_value - first_value + part = np.abs(first_value/dif) + + time_dif = np.abs(time[previous] - time[working_idx]) + start_time = time[working_idx] + time_dif*part + + break + + previous = working_idx + working_idx -= 1 + + # search end_time + previous = index + working_idx = index + 1 + while True: + if eod[previous] < 0 < eod[working_idx]: + first_value = eod[previous] + second_value = eod[working_idx] + + dif = second_value - first_value + part = np.abs(first_value / dif) + + time_dif = np.abs(time[previous] - time[working_idx]) + end_time = time[working_idx] + time_dif * part + + break + + previous = working_idx + working_idx += 1 + + return start_time, end_time + + +def search_closest_index(array, value, start=0, end=-1): + # searches the array to find the closest value in the array to the given value and returns its index. + # expects sorted array! + # start hast to be smaller than end + + if end == -1: + end = len(array)-1 + + while True: + if end-start <= 1: + return end if np.abs(array[end]-value) < np.abs(array[start]-value) else start + + middle = int(np.floor((end-start)/2)+start) + if array[middle] == value: + return middle + elif array[middle] > value: + end = middle + continue + else: + start = middle + continue + + +def __vector_strength__(relative_spike_times, eod_durations): + # adapted from Ramona + + n = len(relative_spike_times) + if n == 0: + return 0 + + phase_times = np.zeros(n) + + for i in range(n): + phase_times[i] = (relative_spike_times[i] / eod_durations[i]) * 2 * np.pi + vs = np.sqrt((1 / n * sum(np.cos(phase_times))) ** 2 + (1 / n * sum(np.sin(phase_times))) ** 2) + + return vs + + +def calculate_phases(relative_spike_times, eod_durations): + phase_times = np.zeros(len(relative_spike_times)) + + for i in range(len(relative_spike_times)): + phase_times[i] = (relative_spike_times[i] / eod_durations[i]) * 2 * np.pi + + return phase_times + + +def plot_polar(phases, name=""): + fig = plt.figure() + ax = fig.add_subplot(111, polar=True) + # r = np.arange(0, 1, 0.001) + # theta = 2 * 2 * np.pi * r + # line, = ax.plot(theta, r, color='#ee8d18', lw=3) + bins = np.arange(0, np.pi*2, 0.05) + ax.hist(phases, bins=bins) + if name == "": + plt.show() + else: + plt.savefig(get_savepath() + name) + plt.close() + + +def plot_phaselocking_testfigures(time, eod, spiketimes, v1): + eod_start_times = [] + eod_end_times = [] + + for spike in spiketimes: + index = spike * 20 # time in s given timestamp of spike in ms - recorded at 20kHz -> timestamp/1000*20000 = idx + + if index != np.round(index): + print("INDEX NOT AN INTEGER in eods_around_spikes! index:", index) + continue + index = int(index) + + start_time, end_time = search_eod_start_and_end_times(time, eod, index) + + eod_start_times.append(start_time) + eod_end_times.append(end_time) + + cutoff_in_sec = 2 + sampling = 20000 + max_idx = cutoff_in_sec*sampling + spikes_part = [x/1000 for x in spiketimes if x/1000 < cutoff_in_sec] + count_spikes = len(spikes_part) + print(spiketimes) + print(len(spikes_part)) + + x_axis = time[0:max_idx] + plt.plot(spikes_part, np.ones(len(spikes_part))*-20, 'o') + plt.plot(x_axis, v1[0:max_idx]) + plt.plot(eod_start_times[: count_spikes], np.zeros(count_spikes), 'o') + plt.plot(eod_end_times[: count_spikes], np.zeros(count_spikes), 'o') + + plt.show() + plt.close() + + +if __name__ == '__main__': + main() diff --git a/introduction/introductionFICurve.py b/introduction/introductionFICurve.py new file mode 100644 index 0000000..1042f8f --- /dev/null +++ b/introduction/introductionFICurve.py @@ -0,0 +1,298 @@ +import numpy as np +import matplotlib.pyplot as plt +import pyrelacs.DataLoader as dl +import os +import helperFunctions as hf +from IPython import embed +from scipy.optimize import curve_fit +import warnings + +SAMPLING_INTERVAL = 1/20000 +STIMULUS_START = 0 +STIMULUS_DURATION = 0.400 +PRE_DURATION = 0.250 +TOTAL_DURATION = 1.25 + + +def main(): + for folder in hf.get_subfolder_paths("data/"): + filepath = folder + "/fispikes1.dat" + set_savepath("figures/" + folder.split('/')[1] + "/") + print("Folder:", folder) + + if not os.path.exists(get_savepath()): + os.makedirs(get_savepath()) + + spiketimes = [] + intensities = [] + index = -1 + for metadata, key, data in dl.iload(filepath): + # embed() + if len(metadata) != 0: + + metadata_index = 0 + if '----- Control --------------------------------------------------------' in metadata[0].keys(): + metadata_index = 1 + + print(metadata) + i = float(metadata[metadata_index]['intensity'][:-2]) + intensities.append(i) + spiketimes.append([]) + index += 1 + + spiketimes[index].append(data[:, 0]/1000) + + intensities, spiketimes = hf.merge_similar_intensities(intensities, spiketimes) + + # Sort the lists so that intensities are increasing + x = [list(x) for x in zip(*sorted(zip(intensities, spiketimes), key=lambda pair: pair[0]))] + intensities = x[0] + spiketimes = x[1] + + mean_frequencies = calculate_mean_frequencies(intensities, spiketimes) + popt, pcov = fit_exponential(intensities, mean_frequencies) + plot_frequency_curve(intensities, mean_frequencies) + + f_baseline = calculate_f_baseline(mean_frequencies) + f_infinity = calculate_f_infinity(mean_frequencies) + f_zero = calculate_f_zero(mean_frequencies) + + # plot_fi_curve(intensities, f_baseline, f_zero, f_infinity) + + +# TODO !! +def fit_exponential(intensities, mean_frequencies): + start_idx = int((PRE_DURATION + STIMULUS_START+0.005) / SAMPLING_INTERVAL) + end_idx = int((PRE_DURATION + STIMULUS_START + 0.1) / SAMPLING_INTERVAL) + time_constants = [] + #print(start_idx, end_idx) + + popts = [] + pcovs = [] + for i in range(len(mean_frequencies)): + freq = mean_frequencies[i] + y_values = freq[start_idx:end_idx+1] + x_values = np.arange(start_idx*SAMPLING_INTERVAL, end_idx*SAMPLING_INTERVAL, SAMPLING_INTERVAL) + try: + popt, pcov = curve_fit(exponential_function, x_values, y_values, p0=(1/(np.power(1, 10)), .5, 50, 180), maxfev=10000) + except RuntimeError: + print("RuntimeError happened in fit_exponential.") + continue + #print(popt) + #print(pcov) + #print() + + popts.append(popt) + pcovs.append(pcov) + + plt.plot(np.arange(-PRE_DURATION, TOTAL_DURATION, SAMPLING_INTERVAL), freq) + plt.plot(x_values-PRE_DURATION, [exponential_function(x, popt[0], popt[1], popt[2], popt[3]) for x in x_values]) + # plt.show() + save_path = get_savepath() + "exponential_fits/" + if not os.path.exists(save_path): + os.makedirs(save_path) + plt.savefig(save_path + "fit_intensity:" + str(round(intensities[i], 4)) + ".png") + plt.close() + + return popts, pcovs + + +def calculate_mean_frequency(freqs): + mean_freq = [sum(e) / len(e) for e in zip(*freqs)] + + return mean_freq + + +def gaussian_kernel(sigma, dt): + x = np.arange(-4. * sigma, 4. * sigma, dt) + y = np.exp(-0.5 * (x / sigma) ** 2) / np.sqrt(2. * np.pi) / sigma + return y + + +def calculate_kernel_frequency(spiketimes, time, sampling_interval): + sp = spiketimes + t = time # Probably goes from -200ms to some amount of ms in the positive ~1200? + dt = sampling_interval + kernel_width = 0.01 # kernel width is a time in seconds how sharp the frequency should be counted + + binary = np.zeros(t.shape) + spike_indices = ((sp - t[0]) / dt).astype(int) + binary[spike_indices[(spike_indices >= 0) & (spike_indices < len(binary))]] = 1 + g = gaussian_kernel(kernel_width, dt) + + rate = np.convolve(binary, g, mode='same') + + return rate + + +def calculate_isi_frequency(spiketimes, time): + first_isi = spiketimes[0] - (-PRE_DURATION) # diff to the start at 0 + last_isi = TOTAL_DURATION - spiketimes[-1] # diff from the last spike to the end of time :D + isis = [first_isi] + isis.extend(np.diff(spiketimes)) + isis.append(last_isi) + + if np.isnan(first_isi): + print(spiketimes[:10]) + print(isis[0:10]) + quit() + + rate = [] + for isi in isis: + if isi == 0: + print("probably a problem") + isi = 0.0000000001 + freq = 1/isi + frequency_step = int(round(isi*(1/SAMPLING_INTERVAL)))*[freq] + rate.extend(frequency_step) + + + #plt.plot((np.arange(len(rate))-PRE_DURATION)/(1/SAMPLING_INTERVAL), rate) + #plt.plot([sum(isis[:i+1]) for i in range(len(isis))], [200 for i in isis], 'o') + #plt.plot(time, [100 for t in time]) + #plt.show() + + if len(rate) != len(time): + if "12-13-af" in get_savepath(): + warnings.warn("preStimulus duration > 0 still not supported") + return [1]*len(time) + else: + print(len(rate), len(time), len(rate) - len(time)) + print(rate) + print(isis) + print("Quitting because time and rate aren't the same length") + quit() + + return rate + + +def calculate_mean_frequencies(intensities, spiketimes): + time = np.arange(-PRE_DURATION, TOTAL_DURATION, SAMPLING_INTERVAL) + + mean_frequencies = [] + for i in range(len(intensities)): + freqs = [] + for spikes in spiketimes[i]: + if len(spikes) < 2: + continue + freq = calculate_isi_frequency(spikes, time) + freqs.append(freq) + + mf = calculate_mean_frequency(freqs) + mean_frequencies.append(mf) + + return mean_frequencies + + +def calculate_f_baseline(mean_frequencies): + buffer_time = 0.05 + start_idx = int(0.05/SAMPLING_INTERVAL) + end_idx = int((PRE_DURATION - STIMULUS_START - buffer_time)/SAMPLING_INTERVAL) + + f_zeros = [] + for freq in mean_frequencies: + f_0 = np.mean(freq[start_idx:end_idx]) + f_zeros.append(f_0) + + return f_zeros + + +def calculate_f_infinity(mean_frequencies): + buffer_time = 0.05 + start_idx = int((PRE_DURATION + STIMULUS_START + STIMULUS_DURATION - 0.15 - buffer_time) / SAMPLING_INTERVAL) + end_idx = int((PRE_DURATION + STIMULUS_START + STIMULUS_DURATION - buffer_time) / SAMPLING_INTERVAL) + + f_infinity = [] + for freq in mean_frequencies: + f_inf = np.mean(freq[start_idx:end_idx]) + f_infinity.append(f_inf) + + return f_infinity + + +def calculate_f_zero(mean_frequencies): + buffer_time = 0.1 + start_idx = int((PRE_DURATION + STIMULUS_START - buffer_time) / SAMPLING_INTERVAL) + end_idx = int((PRE_DURATION + STIMULUS_START + buffer_time) / SAMPLING_INTERVAL) + f_peaks = [] + for freq in mean_frequencies: + fp = np.mean(freq[start_idx-500:start_idx]) + for i in range(start_idx+1, end_idx): + if abs(freq[i] - freq[start_idx]) > abs(fp - freq[start_idx]): + fp = freq[i] + f_peaks.append(fp) + return f_peaks + + +def plot_fi_curve(intensities, f_baseline, f_zero, f_infinity): + plt.plot(intensities, f_baseline, label="f_baseline") + plt.plot(intensities, f_zero, 'o', label="f_zero") + plt.plot(intensities, f_infinity, label="f_infinity") + + max_f0 = float(max(f_zero)) + mean_int = float(np.mean(intensities)) + start_k = float(((f_zero[-1] - f_zero[0]) / (intensities[-1] - intensities[0])*4)/f_zero[-1]) + + popt, pcov = curve_fit(fill_boltzmann, intensities, f_zero, p0=(max_f0, start_k, mean_int), maxfev=10000) + print(popt) + min_x = min(intensities) + max_x = max(intensities) + step = (max_x - min_x) / 5000 + x_values_boltzmann_fit = np.arange(min_x, max_x, step) + plt.plot(x_values_boltzmann_fit, [fill_boltzmann(i, popt[0], popt[1], popt[2]) for i in x_values_boltzmann_fit], label='fit') + + plt.title("FI-Curve") + plt.ylabel("Frequency in Hz") + plt.xlabel("Intensity in mV") + plt.legend() + # plt.show() + plt.savefig(get_savepath() + "fi_curve.png") + plt.close() + + +def plot_frequency_curve(intensities, mean_frequencies): + colors = ["red", "green", "blue", "violet", "orange", "grey"] + + time = np.arange(-PRE_DURATION, TOTAL_DURATION, SAMPLING_INTERVAL) + + for i in range(len(intensities)): + plt.plot(time, mean_frequencies[i], color=colors[i % 6], label=str(intensities[i])) + + plt.plot((0, 0), (0, 500), color="black") + plt.plot((0.4, 0.4), (0, 500), color="black") + plt.legend() + plt.xlabel("Time in seconds") + plt.ylabel("Frequency in Hz") + plt.title("Frequency curve") + + plt.savefig(get_savepath() + "mean_frequency_curves.png") + plt.close() + + +def exponential_function(x, a, b, c, d): + return a*np.exp(-c*(x-b))+d + + +def upper_boltzmann(x, f_max, k, x_zero): + return f_max * np.clip((2 / (1+np.power(np.e, -k*(x - x_zero)))) - 1, 0, None) + + +def fill_boltzmann(x, f_max, k, x_zero): + return f_max * (1 / (1 + np.power(np.e, -k * (x - x_zero)))) + + +SAVEPATH = "" + + +def get_savepath(): + global SAVEPATH + return SAVEPATH + + +def set_savepath(new_path): + global SAVEPATH + SAVEPATH = new_path + + +if __name__ == '__main__': + main() diff --git a/introduction/janExample.py b/introduction/janExample.py new file mode 100644 index 0000000..272aaa6 --- /dev/null +++ b/introduction/janExample.py @@ -0,0 +1,20 @@ +import pyrelacs.DataLoader as dl + +for metadata, key, data in dl.iload('2012-06-27-ah-invivo-1/basespikes1.dat'): + print(data.shape) + break + +# mean firing rate (number of spikes per time) +# CV (stdev of ISI divided by mean ISI (np.diff(spiketimes)) +# ISI histogram (play around with binsize! < 1ms) +# Vectorstaerke (use EOD frequency from header (metadata)) VS > 0.8 +# dl.iload_traces(repro='BaselineActivity') + +def test(): + for metadata, key, data in dl.iload('data/2012-06-27-ah-invivo-1/basespikes1.dat'): + print(data.shape) + for i in metadata: + for key in i.keys(): + print(key, ":", i[key]) + + break \ No newline at end of file diff --git a/main.py b/main.py new file mode 100644 index 0000000..244ae25 --- /dev/null +++ b/main.py @@ -0,0 +1,61 @@ + +from FiCurve import FICurve +from CellData import icelldata_of_dir +import os +import helperFunctions as hf +from AdaptionCurrent import Adaption +from models.NeuronModel import NeuronModel +from functionalityTests import * +# TODO command line interface needed/nice ? + + +def main(): + run_tests() + quit() + for cell_data in icelldata_of_dir("./data/"): + print() + print(cell_data.get_data_path()) + + model = NeuronModel(cell_data) + + x_values = np.arange(0, 1000, 0.01) + stimulus = [0]*int(200/0.01) + stimulus.extend([0.19]*int(400/0.01)) + stimulus.extend([0]*int(400/0.01)) + + v, spikes = model.simulate(0, 1000, stimulus) + + # plt.plot(x_values, v) + + spikes = [s/1000 for s in spikes] + time, freq = hf.calculate_isi_frequency(spikes, 0, 0.01/1000) + + plt.plot(time, freq) + plt.show() + + quit() + continue + + figures_save_path = "./figures/" + os.path.basename(cell_data.get_data_path()) + "/" + ficurve = FICurve(cell_data) + ficurve.plot_fi_curve(figures_save_path) + + adaption = Adaption(cell_data, ficurve) + adaption.plot_exponential_fits(figures_save_path + "exponential_fits/", delete_previous=True) + + for i in range(len(adaption.exponential_fit_vars)): + if len(adaption.exponential_fit_vars[i]) == 0: + continue + tau = round(adaption.exponential_fit_vars[i][1]*1000, 2) + contrast = round(ficurve.stimulus_value[i], 3) + # print(tau, "ms - tau_eff at", contrast, "contrast") + + # test_plot_inverses(ficurve) + print("Chosen tau [ms]:", adaption.tau_real) + + +def run_tests(): + test_lifac() + +if __name__ == '__main__': + main() diff --git a/models/LIFAC.py b/models/LIFAC.py new file mode 100644 index 0000000..61b1daf --- /dev/null +++ b/models/LIFAC.py @@ -0,0 +1,88 @@ + +from stimuli.AbstractStimulus import AbstractStimulus +import numpy as np + + +class LIFACModel: + # all times in milliseconds + KEYS = ["mem_res", "mem_tau", "v_base", "v_zero", "threshold", "step_size", "delta_a", "tau_a"] + VALUES = [100 * 1000000, 0.1 * 200, 0, 0, 10, 0.01, 1, 200] + + # membrane time constant tau = mem_cap*mem_res + def __init__(self, params: dict = None): + self.parameters = {} + if params is None: + self._set_default_parameters() + else: + self._test_given_parameters(params) + self.set_parameters(params) + + self.last_v = [] + self.last_adaption = [] + self.last_spiketimes = [] + + def __call__(self, stimulus: AbstractStimulus, total_time_s): + output_voltage = [] + adaption = [] + spiketimes = [] + current_v = self.parameters["v_zero"] + current_a = 0 + + for time_point in np.arange(0, total_time_s*1000, self.parameters["step_size"]): + v_next = self._calculate_voltage_step(current_v, stimulus.value_at_time_in_ms(time_point) - current_a) + a_next = self._calculate_adaption_step(current_a) + + if v_next > self.parameters["threshold"]: + v_next = self.parameters["v_base"] + spiketimes.append(time_point/1000) + a_next += self.parameters["delta_a"] + + output_voltage.append(v_next) + adaption.append(a_next) + + current_v = v_next + current_a = a_next + + self.last_v = output_voltage + self.last_adaption = adaption + self.last_spiketimes = spiketimes + + return output_voltage, spiketimes + + def _calculate_voltage_step(self, current_v, input_v): + v_base = self.parameters["v_base"] + step_size = self.parameters["step_size"] + # mem_res = self.parameters["mem_res"] + mem_tau = self.parameters["mem_tau"] + return current_v + (step_size * (v_base - current_v + input_v)) / mem_tau + + def _calculate_adaption_step(self, current_a): + step_size = self.parameters["step_size"] + return current_a + (step_size * (-current_a)) / self.parameters["tau_a"] + + def set_parameters(self, params): + for k in params.keys(): + self.parameters[k] = params[k] + + for i in range(len(self.KEYS)): + if self.KEYS[i] not in self.parameters.keys(): + self.parameters[self.KEYS[i]] = self.VALUES[i] + + def get_parameters(self): + return self.parameters + + def set_variable(self, key, value): + if key not in self.KEYS: + raise ValueError("Given key is unknown!\n" + "Please check spelling and refer to list LIFAC.KEYS.") + self.parameters[key] = value + + def _set_default_parameters(self): + for i in range(len(self.KEYS)): + self.parameters[self.KEYS[i]] = self.VALUES[i] + + def _test_given_parameters(self, params): + for k in params.keys(): + if k not in self.KEYS: + err_msg = "Unknown key in the given parameters:" + str(k) + raise ValueError(err_msg) \ No newline at end of file diff --git a/models/LeakyIntegrateFireModel.py b/models/LeakyIntegrateFireModel.py new file mode 100644 index 0000000..567f155 --- /dev/null +++ b/models/LeakyIntegrateFireModel.py @@ -0,0 +1,37 @@ + + +class LIFModel: + # all times in milliseconds + def __init__(self, mem_res, mem_tau, v_base, v_zero, input_current, threshold, input_offset=0, step_size=0.01): + self.mem_res = mem_res + # self.membrane_capacitance = mem_cap + self.mem_tau = mem_tau # membrane time constant tau = mem_cap*mem_res + self.v_base = v_base + self.v_zero = v_zero + self.threshold = threshold + + self.step_size = step_size + self.input_current = input_current + self.input_offset = input_offset + + def calculate_response(self): + output_voltage = [self.v_zero] + spikes = [] + + for idx in range(1, len(self.input_current)): + v_next = self.__calculate_next_step__(output_voltage[idx-1], self.input_current[idx-1]) + if v_next > self.threshold: + v_next = self.v_base + spikes.append(True) + else: + spikes.append(False) + output_voltage.append(v_next) + + return output_voltage, spikes + + def set_input_current(self, input_current, offset=0): + self.input_current = input_current + self.input_offset = offset + + def __calculate_next_step__(self, current_v, input_i): + return current_v + (self.step_size * (self.v_base - current_v + self.mem_res * input_i)) / self.mem_tau diff --git a/models/NeuronModel.py b/models/NeuronModel.py new file mode 100644 index 0000000..c912f1b --- /dev/null +++ b/models/NeuronModel.py @@ -0,0 +1,110 @@ + +from CellData import CellData +from FiCurve import FICurve +from AdaptionCurrent import Adaption +import numpy as np +import matplotlib.pyplot as plt + + +class NeuronModel: + KEYS = ["mem_res", "mem_tau", "v_base", "v_zero", "threshold", "step_size"] + VALUES = [100 * 1000000, 0.1 * 200, 0, 0, 10, 0.01] + + def __init__(self, cell_data: CellData, variables: dict = None): + self.cell_data = cell_data + self.fi_curve = FICurve(cell_data) + self.adaption = Adaption(cell_data, self.fi_curve) + + if variables is not None: + self._test_given_variables(variables) + self.variables = variables + else: + self.variables = {} + self._add_standard_variables() + + def __call__(self, stimulus): + raise NotImplementedError("Soon. sorry!") + + def _approximate_variables_from_data(self): + # TODO don't return but save in class in some form! approximate/calculate other variables? + base_input = self._calculate_input_fro_base_frequency() + return base_input + + def simulate(self, start_v, time_in_ms, stimulus): + response = [] + spikes = [] + current_v = start_v + current_a = 0 + base_input = self._calculate_input_fro_base_frequency() + + adaption_values = [] + a_infties = [] + print("base input:", base_input) + for time_step in np.arange(0, time_in_ms, self.variables["step_size"]): + stimulus_input = stimulus[int(time_step/self.variables["step_size"])] - current_a + + new_v = self._calculate_next_step(current_v, current_a*base_input, base_input + base_input*stimulus_input) + new_a, a_infty = self._calculate_adaption_step(current_a, stimulus_input) + + if new_v > self.variables["threshold"]: + new_v = self.variables["v_base"] + spikes.append(time_step) + response.append(new_v) + + adaption_values.append(current_a) + a_infties.append(a_infty) + current_v = new_v + current_a = new_a + + plt.title("Adaption variable") + plt.plot(np.arange(0, time_in_ms, self.variables["step_size"]), np.array(adaption_values), label="adaption") + plt.plot(np.arange(0, time_in_ms, self.variables["step_size"]), np.array(a_infties), label="a_inf") + plt.plot(np.arange(0, time_in_ms, self.variables["step_size"]), stimulus, label="stimulus") + plt.legend() + plt.xlabel("time in ms") + plt.ylabel("value as contrast?") + plt.show() + plt.close() + + return response, spikes + + def _calculate_next_step(self, current_v, current_a, input_v): + step_size = self.variables["step_size"] + v_base = self.variables["v_base"] + mem_tau = self.variables["mem_tau"] + + return current_v + (step_size * (- current_v + v_base + input_v - current_a)) / mem_tau + + def _calculate_adaption_step(self, current_a, stimulus_input): + step_size = self.variables["step_size"] + tau_a = self.adaption.tau_real + f_infty_freq = self.fi_curve.get_f_infinity_frequency_at_stimulus_value(stimulus_input) + a_infinity = stimulus_input - self.fi_curve.get_f_zero_inverse_at_frequency(f_infty_freq) + return current_a + (step_size * (- current_a + a_infinity)) / tau_a, a_infinity + + def set_variable(self, key, value): + if key not in self.KEYS: + raise ValueError("Given key is unknown!\n" + "Please check spelling and refer to list NeuronModel.KEYS.") + self.variables[key] = value + + def set_variables(self, variables: dict): + self._test_given_variables(variables) + + for k in variables.keys(): + self.variables[k] = variables[k] + + def _calculate_input_fro_base_frequency(self): + return - self.variables["threshold"] / ( + np.e ** (-1 / (self.cell_data.get_base_frequency()/1000 * self.variables["mem_tau"])) - 1) + + def _test_given_variables(self, variables: dict): + for k in variables.keys(): + if k not in self.KEYS: + raise ValueError("Unknown key in given model variables. \n" + "Please check spelling and refer to list NeuronModel.KEYS.") + + def _add_standard_variables(self): + for i in range(len(self.KEYS)): + if self.KEYS[i] not in self.variables: + self.variables[self.KEYS[i]] = self.VALUES[i] diff --git a/models/__init__.py b/models/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/stimuli/AbstractStimulus.py b/stimuli/AbstractStimulus.py new file mode 100644 index 0000000..0745fa5 --- /dev/null +++ b/stimuli/AbstractStimulus.py @@ -0,0 +1,8 @@ + +class AbstractStimulus: + + def value_at_time_in_ms(self, time_point): + raise NotImplementedError("This is an abstract class!") + + def value_at_time_in_s(self, time_point): + raise NotImplementedError("This is an abstract class!") diff --git a/stimuli/StepStimulus.py b/stimuli/StepStimulus.py new file mode 100644 index 0000000..261aefb --- /dev/null +++ b/stimuli/StepStimulus.py @@ -0,0 +1,26 @@ + +from stimuli.AbstractStimulus import AbstractStimulus + + +class StepStimulus(AbstractStimulus): + + def __init__(self, start, duration, value, base_value=0, seconds=True): + self.start = 0 + self.duration = 0 + self.base_value = base_value + self.value = value + if seconds: + self.start = start + self.duration = duration + else: + self.start = start / 1000 + self.duration = duration / 1000 + + def value_at_time_in_ms(self, time_point): + return self.value_at_time_in_s(time_point/1000) + + def value_at_time_in_s(self, time_point): + if self.start <= time_point <= self.start + self.duration: + return self.value + else: + return self.base_value diff --git a/stimuli/__init__.py b/stimuli/__init__.py new file mode 100644 index 0000000..e69de29