From f5dc213e42d8abfbced5eb0128cb38e8aa5a35df Mon Sep 17 00:00:00 2001 From: "a.ott" Date: Fri, 20 Dec 2019 13:33:34 +0100 Subject: [PATCH] commit all existing code --- AdaptionCurrent.py | 168 +++++++++++++++ CellData.py | 156 ++++++++++++++ DataParserFactory.py | 237 +++++++++++++++++++++ FiCurve.py | 153 ++++++++++++++ functionalityTests.py | 44 ++++ functions.py | 40 ++++ generalTests.py | 54 +++++ helperFunctions.py | 214 +++++++++++++++++++ introduction/introductionBaseline.py | 283 +++++++++++++++++++++++++ introduction/introductionFICurve.py | 298 +++++++++++++++++++++++++++ introduction/janExample.py | 20 ++ main.py | 61 ++++++ models/LIFAC.py | 88 ++++++++ models/LeakyIntegrateFireModel.py | 37 ++++ models/NeuronModel.py | 110 ++++++++++ models/__init__.py | 0 stimuli/AbstractStimulus.py | 8 + stimuli/StepStimulus.py | 26 +++ stimuli/__init__.py | 0 19 files changed, 1997 insertions(+) create mode 100644 AdaptionCurrent.py create mode 100644 CellData.py create mode 100644 DataParserFactory.py create mode 100644 FiCurve.py create mode 100644 functionalityTests.py create mode 100644 functions.py create mode 100644 generalTests.py create mode 100644 helperFunctions.py create mode 100644 introduction/introductionBaseline.py create mode 100644 introduction/introductionFICurve.py create mode 100644 introduction/janExample.py create mode 100644 main.py create mode 100644 models/LIFAC.py create mode 100644 models/LeakyIntegrateFireModel.py create mode 100644 models/NeuronModel.py create mode 100644 models/__init__.py create mode 100644 stimuli/AbstractStimulus.py create mode 100644 stimuli/StepStimulus.py create mode 100644 stimuli/__init__.py diff --git a/AdaptionCurrent.py b/AdaptionCurrent.py new file mode 100644 index 0000000..97e54d2 --- /dev/null +++ b/AdaptionCurrent.py @@ -0,0 +1,168 @@ + +from FiCurve import FICurve +from CellData import CellData +import matplotlib.pyplot as plt +from scipy.optimize import curve_fit +import os +import numpy as np +import functions as fu + + +class Adaption: + + def __init__(self, cell_data: CellData, fi_curve: FICurve = None): + self.cell_data = cell_data + if fi_curve is None: + self.fi_curve = FICurve(cell_data) + else: + self.fi_curve = fi_curve + + # [[a, tau_eff, c], [], [a, tau_eff, c], ...] + self.exponential_fit_vars = [] + self.tau_real = [] + + self.fit_exponential() + self.calculate_tau_from_tau_eff() + + def fit_exponential(self, length_of_fit=0.05): + mean_frequencies = self.cell_data.get_mean_isi_frequencies() + time_axes = self.cell_data.get_time_axes_mean_frequencies() + for i in range(len(mean_frequencies)): + start_idx = self.__find_start_idx_for_exponential_fit(i) + + if start_idx == -1: + self.exponential_fit_vars.append([]) + continue + + # shorten length of fit to stay in stimulus region if given length is too long + sampling_interval = self.cell_data.get_sampling_interval() + used_length_of_fit = length_of_fit + if (start_idx * sampling_interval) - self.cell_data.get_delay() + length_of_fit > self.cell_data.get_stimulus_end(): + print(start_idx * sampling_interval, "start - end", start_idx * sampling_interval + length_of_fit) + print("Shortened length of fit to keep it in the stimulus region!") + used_length_of_fit = self.cell_data.get_stimulus_end() - (start_idx * sampling_interval) + + end_idx = start_idx + int(used_length_of_fit/sampling_interval) + y_values = mean_frequencies[i][start_idx:end_idx+1] + x_values = time_axes[i][start_idx:end_idx+1] + + tau = self.__approximate_tau_for_exponential_fit(x_values, y_values, i) + + # start the actual fit: + try: + p0 = (self.fi_curve.f_zeros[i], tau, self.fi_curve.f_infinities[i]) + popt, pcov = curve_fit(fu.exponential_function, x_values, y_values, + p0=p0, maxfev=10000, bounds=([-np.inf, 0, -np.inf], [np.inf, np.inf, np.inf])) + except RuntimeError: + print("RuntimeError happened in fit_exponential.") + self.exponential_fit_vars.append([]) + continue + + # Obviously a bad fit - time constant, expected in range 3-10ms, has value over 1 second or is negative + if abs(popt[1] > 1) or popt[1] < 0: + self.exponential_fit_vars.append([]) + else: + self.exponential_fit_vars.append(popt) + + def __approximate_tau_for_exponential_fit(self, x_values, y_values, mean_freq_idx): + if self.fi_curve.f_infinities[mean_freq_idx] < self.fi_curve.f_baselines[mean_freq_idx] * 0.95: + test_val = [y > 0.65 * self.fi_curve.f_infinities[mean_freq_idx] for y in y_values] + else: + test_val = [y < 0.65 * self.fi_curve.f_zeros[mean_freq_idx] for y in y_values] + + try: + idx = test_val.index(True) + if idx == 0: + idx = 1 + tau = x_values[idx] - x_values[0] + except ValueError: + tau = x_values[-1] - x_values[0] + + return tau + + def __find_start_idx_for_exponential_fit(self, mean_freq_idx): + stimulus_start_idx = int((self.cell_data.get_delay() + self.cell_data.get_stimulus_start()) / self.cell_data.get_sampling_interval()) + if self.fi_curve.f_infinities[mean_freq_idx] > self.fi_curve.f_baselines[mean_freq_idx] * 1.1: + # start setting starting variables for the fit + # search for the start_index by searching for the max + j = 0 + while True: + try: + if self.cell_data.get_mean_isi_frequencies()[mean_freq_idx][stimulus_start_idx + j] == self.fi_curve.f_zeros[mean_freq_idx]: + start_idx = stimulus_start_idx + j + break + except IndexError as e: + return -1 + + j += 1 + + elif self.fi_curve.f_infinities[mean_freq_idx] < self.fi_curve.f_baselines[mean_freq_idx] * 0.9: + # start setting starting variables for the fit + # search for start by finding the end of the minimum + found_min = False + j = int(0.05 / self.cell_data.get_sampling_interval()) + nothing_to_fit = False + while True: + if not found_min: + if self.cell_data.get_mean_isi_frequencies()[mean_freq_idx][stimulus_start_idx + j] == self.fi_curve.f_zeros[mean_freq_idx]: + found_min = True + else: + if self.cell_data.get_mean_isi_frequencies()[mean_freq_idx][stimulus_start_idx + j + 1] > self.fi_curve.f_zeros[mean_freq_idx]: + start_idx = stimulus_start_idx + j + break + if j > 0.1 / self.cell_data.get_sampling_interval(): + # no rise in freq until to close to the end of the stimulus (to little place to fit) + return -1 + j += 1 + + if nothing_to_fit: + return -1 + else: + # there is nothing to fit to: + return -1 + + return start_idx + + def calculate_tau_from_tau_eff(self): + taus = [] + for i in range(len(self.exponential_fit_vars)): + if len(self.exponential_fit_vars[i]) == 0: + continue + tau_eff = self.exponential_fit_vars[i][1]*1000 # tau_eff in ms + # intensity = self.fi_curve.stimulus_value[i] + f_infinity_slope = self.fi_curve.get_f_infinity_slope() + fi_curve_slope = self.fi_curve.get_fi_curve_slope_of_straight() + + taus.append(tau_eff*(fi_curve_slope/f_infinity_slope)) + # print((fi_curve_slope/f_infinity_slope)) + # print(tau_eff*(fi_curve_slope/f_infinity_slope), "=", tau_eff, "*", (fi_curve_slope/f_infinity_slope)) + + self.tau_real = np.median(taus) + + def plot_exponential_fits(self, save_path: str = None, indices: list = None, delete_previous: bool = False): + if delete_previous: + for val in self.cell_data.get_fi_contrasts(): + + prev_path = save_path + "mean_freq_exp_fit_contrast:" + str(round(val, 3)) + ".png" + + if os.path.exists(prev_path): + os.remove(prev_path) + + for i in range(len(self.cell_data.get_fi_contrasts())): + if self.exponential_fit_vars[i] == []: + continue + + plt.plot(self.cell_data.get_time_axes_mean_frequencies()[i], self.cell_data.get_mean_isi_frequencies()[i]) + vars = self.exponential_fit_vars[i] + fit_x = np.arange(0, 0.4, self.cell_data.get_sampling_interval()) + plt.plot(fit_x, [fu.exponential_function(x, vars[0], vars[1], vars[2]) for x in fit_x]) + plt.ylim([0, max(self.fi_curve.f_zeros[i], self.fi_curve.f_baselines[i])*1.1]) + plt.xlabel("Time [s]") + plt.ylabel("Frequency [Hz]") + + if save_path is None: + plt.show() + else: + plt.savefig(save_path + "mean_freq_exp_fit_contrast:" + str(round(self.cell_data.get_fi_contrasts()[i], 3)) + ".png") + + plt.close() \ No newline at end of file diff --git a/CellData.py b/CellData.py new file mode 100644 index 0000000..1eb7661 --- /dev/null +++ b/CellData.py @@ -0,0 +1,156 @@ + +import DataParserFactory as dpf +from warnings import warn +from os import listdir +import helperFunctions as hf +import numpy as np + + +def icelldata_of_dir(base_path): + for item in sorted(listdir(base_path)): + item_path = base_path + item + + try: + yield CellData(item_path) + except TypeError as e: + warn_msg = str(e) + warn(warn_msg) + + +class CellData: + # Class to capture all the data of a single cell across all experiments (base rate, FI-curve, .?.) + # should be abstract from the way the data is saved in the background .dat vs .nix + + # traces list of lists with traces: [[time], [voltage (v1)], [EOD], [local eod], [stimulus]] + TIME = 0 + V1 = 1 + EOD = 2 + LOCAL_EOD = 3 + STIMULUS = 4 + + def __init__(self, data_path): + self.data_path = data_path + self.base_traces = None + # self.fi_traces = None + self.fi_intensities = None + self.fi_spiketimes = None + self.fi_trans_amplitudes = None + self.mean_isi_frequencies = None + self.time_axes = None + # self.metadata = None + self.parser = dpf.get_parser(data_path) + + self.sampling_interval = self.parser.get_sampling_interval() + self.recording_times = self.parser.get_recording_times() + + def get_data_path(self): + return self.data_path + + def get_base_traces(self, trace_type=None): + if self.base_traces is None: + self.base_traces = self.parser.get_baseline_traces() + + if trace_type is None: + return self.base_traces + else: + return self.base_traces[trace_type] + + def get_fi_traces(self): + raise NotImplementedError("CellData:get_fi_traces():\n" + + "Getting the Fi-Traces currently overflows the RAM and causes swapping! Reimplement if really needed!") + # if self.fi_traces is None: + # self.fi_traces = self.parser.get_fi_curve_traces() + # return self.fi_traces + + def get_fi_spiketimes(self): + self.__read_fi_spiketimes_info__() + return self.fi_spiketimes + + def get_fi_intensities(self): + self.__read_fi_spiketimes_info__() + return self.fi_intensities + + def get_fi_contrasts(self): + self.__read_fi_spiketimes_info__() + contrast = [] + for i in range(len(self.fi_intensities)): + contrast.append((self.fi_intensities[i] - self.fi_trans_amplitudes[i]) / self.fi_trans_amplitudes[i]) + + return contrast + + def get_mean_isi_frequencies(self): + if self.mean_isi_frequencies is None: + self.time_axes, self.mean_isi_frequencies = hf.all_calculate_mean_isi_frequencies(self.get_fi_spiketimes(), + self.get_time_start(), + self.get_sampling_interval()) + return self.mean_isi_frequencies + + def get_time_axes_mean_frequencies(self): + if self.time_axes is None: + self.time_axes, self.mean_isi_frequencies = hf.all_calculate_mean_isi_frequencies(self.get_fi_spiketimes(), + self.get_time_start(), + self.get_sampling_interval()) + return self.time_axes + + def get_base_frequency(self): + base_freqs = [] + for freq in self.get_mean_isi_frequencies(): + delay = self.get_delay() + sampling_interval = self.get_sampling_interval() + if delay < 0.1: + warn("FICurve:__calculate_f_baseline__(): Quite short delay at the start.") + + idx_start = int(0.025 / sampling_interval) + idx_end = int((delay - 0.025) / sampling_interval) + base_freqs.append(np.mean(freq[idx_start:idx_end])) + + return np.median(base_freqs) + + def get_sampling_interval(self) -> float: + return self.sampling_interval + + def get_recording_times(self) -> list: + return self.recording_times + + def get_time_start(self) -> float: + return self.recording_times[0] + + def get_delay(self) -> float: + return abs(self.recording_times[0]) + + def get_time_end(self) -> float: + return self.recording_times[2] + self.recording_times[3] + + def get_stimulus_start(self) -> float: + return self.recording_times[1] + + def get_stimulus_duration(self) -> float: + return self.recording_times[2] + + def get_stimulus_end(self) -> float: + return self.get_stimulus_start() + self.get_stimulus_duration() + + def get_after_stimulus_duration(self) -> float: + return self.recording_times[3] + + def __read_fi_spiketimes_info__(self): + if self.fi_spiketimes is None: + trans_amplitudes, intensities, spiketimes = self.parser.get_fi_curve_spiketimes() + + self.fi_intensities, self.fi_spiketimes, self.fi_trans_amplitudes = hf.merge_similar_intensities(intensities, spiketimes, trans_amplitudes) + + # def get_metadata(self): + # self.__read_metadata__() + # return self.metadata + # + # def get_metadata_item(self, item): + # self.__read_metadata__() + # if item in self.metadata.keys(): + # return self.metadata[item] + # else: + # raise KeyError("CellData:get_metadata_item: Item not found in metadata! - " + str(item)) + # + # def __read_metadata__(self): + # if self.metadata is None: + # # TODO!! + # pass diff --git a/DataParserFactory.py b/DataParserFactory.py new file mode 100644 index 0000000..3ae7cea --- /dev/null +++ b/DataParserFactory.py @@ -0,0 +1,237 @@ + +from os.path import isdir, exists +from warnings import warn +import pyrelacs.DataLoader as Dl + +UNKNOWN = -1 +DAT_FORMAT = 0 +NIX_FORMAT = 1 + + +class AbstractParser: + + def cell_get_metadata(self): + raise NotImplementedError("NOT YET OVERRIDDEN FROM ABSTRACT CLASS") + + def get_baseline_traces(self): + raise NotImplementedError("NOT YET OVERRIDDEN FROM ABSTRACT CLASS") + + def get_fi_curve_traces(self): + raise NotImplementedError("NOT YET OVERRIDDEN FROM ABSTRACT CLASS") + + def get_fi_curve_spiketimes(self): + raise NotImplementedError("NOT YET OVERRIDDEN FROM ABSTRACT CLASS") + + def get_sampling_interval(self): + raise NotImplementedError("NOT YET OVERRIDDEN FROM ABSTRACT CLASS") + + def get_recording_times(self): + raise NotImplementedError("NOT YET OVERRIDDEN FROM ABSTRACT CLASS") + + +class DatParser(AbstractParser): + + def __init__(self, dir_path): + self.base_path = dir_path + self.fi_file = self.base_path + "/fispikes1.dat" + self.stimuli_file = self.base_path + "/stimuli.dat" + self.__test_data_file_existence__() + + self.fi_recording_times = [] + self.sampling_interval = -1 + + def cell_get_metadata(self): + pass + + def get_sampling_interval(self): + if self.sampling_interval == -1: + self.__read_sampling_interval__() + + return self.sampling_interval + + def get_recording_times(self): + if len(self.fi_recording_times) == 0: + self.__read_fi_recording_times__() + return self.fi_recording_times + + def get_baseline_traces(self): + return self.__get_traces__("BaselineActivity") + + def get_fi_curve_traces(self): + return self.__get_traces__("FICurve") + + # TODO clean up/ rewrite + def get_fi_curve_spiketimes(self): + spiketimes = [] + pre_intensities = [] + pre_durations = [] + intensities = [] + trans_amplitudes = [] + pre_duration = -1 + index = -1 + skip = False + trans_amplitude = float('nan') + for metadata, key, data in Dl.iload(self.fi_file): + if len(metadata) != 0: + + metadata_index = 0 + + if '----- Control --------------------------------------------------------' in metadata[0].keys(): + metadata_index = 1 + pre_duration = float(metadata[0]["----- Pre-Intensities ------------------------------------------------"]["preduration"][:-2]) + trans_amplitude = float(metadata[0]["trans. amplitude"][:-2]) + if pre_duration == 0: + skip = False + else: + skip = True + continue + + if skip: + continue + + intensity = float(metadata[metadata_index]['intensity'][:-2]) + pre_intensity = float(metadata[metadata_index]['preintensity'][:-2]) + + intensities.append(intensity) + pre_durations.append(pre_duration) + pre_intensities.append(pre_intensity) + trans_amplitudes.append(trans_amplitude) + spiketimes.append([]) + index += 1 + + if skip: + continue + + if data.shape[1] != 1: + raise RuntimeError("DatParser:get_fi_curve_spiketimes():\n read data has more than one dimension!") + + spike_time_data = data[:, 0]/1000 + if len(spike_time_data) < 10: + continue + if spike_time_data[-1] < 1: + print("# ignoring spike-train that ends before one second.") + continue + + spiketimes[index].append(spike_time_data) + + # TODO add merging for similar intensities? hf.merge_similar_intensities() + trans_amplitudes + + return trans_amplitudes, intensities, spiketimes + + def __get_traces__(self, repro): + time_traces = [] + v1_traces = [] + eod_traces = [] + local_eod_traces = [] + stimulus_traces = [] + + nothing = True + + for info, key, time, x in Dl.iload_traces(self.base_path, repro=repro): + nothing = False + time_traces.append(time) + v1_traces.append(x[0]) + eod_traces.append(x[1]) + local_eod_traces.append(x[2]) + stimulus_traces.append(x[3]) + + traces = [time_traces, v1_traces, eod_traces, local_eod_traces, stimulus_traces] + + if nothing: + warn_msg = "pyrelacs: iload_traces found nothing for the " + str(repro) + " repro!" + warn(warn_msg) + + return traces + + def __read_fi_recording_times__(self): + + delays = [] + stim_duration = [] + pause = [] + + for metadata, key, data in Dl.iload(self.fi_file): + if len(metadata) != 0: + control_key = '----- Control --------------------------------------------------------' + if control_key in metadata[0].keys(): + delays.append(float(metadata[0][control_key]["delay"][:-2])/1000) + pause.append(float(metadata[0][control_key]["pause"][:-2])/1000) + stim_key = "----- Test-Intensities -----------------------------------------------" + stim_duration.append(float(metadata[0][stim_key]["duration"][:-2])/1000) + + for l in [delays, stim_duration, pause]: + if len(l) == 0: + raise RuntimeError("DatParser:__read_fi_recording_times__:\n" + + "Couldn't find any delay, stimulus duration and or pause in the metadata.\n" + + "In file:" + self.base_path) + elif len(set(l)) != 1: + raise RuntimeError("DatParser:__read_fi_recording_times__:\n" + + "Found multiple different delay, stimulus duration and or pause in the metadata.\n" + + "In file:" + self.base_path) + else: + self.fi_recording_times = [-delays[0], 0, stim_duration[0], pause[0] - delays[0]] + + def __read_sampling_interval__(self): + stop = False + sampling_intervals = [] + for metadata, key, data in Dl.iload(self.stimuli_file): + for md in metadata: + for i in range(4): + key = "sample interval" + str(i+1) + if key in md.keys(): + + sampling_intervals.append(float(md[key][:-2]) / 1000) + stop = True + else: + break + + if stop: + break + + if len(sampling_intervals) == 0: + raise RuntimeError("DatParser:__read_sampling_interval__:\n" + + "Sampling intervals not found in stimuli.dat this is not handled!\n" + + "with File:" + self.base_path) + + if len(set(sampling_intervals)) != 1: + raise RuntimeError("DatParser:__read_sampling_interval__:\n" + + "Sampling intervals not the same for all traces this is not handled!\n" + + "with File:" + self.base_path) + else: + self.sampling_interval = sampling_intervals[0] + + def __test_data_file_existence__(self): + if not exists(self.stimuli_file): + raise RuntimeError(self.stimuli_file + " file doesn't exist!") + if not exists(self.fi_file): + raise RuntimeError(self.fi_file + " file doesn't exist!") + + +# TODO #################################### +class NixParser(AbstractParser): + + def __init__(self, nix_file_path): + self.file_path = nix_file_path + warn("NIX PARSER: NOT YET IMPLEMENTED!") +# TODO #################################### + + +def get_parser(data_path: str) -> AbstractParser: + data_format = __test_for_format__(data_path) + + if data_format == DAT_FORMAT: + return DatParser(data_path) + elif data_format == NIX_FORMAT: + return NixParser(data_path) + elif UNKNOWN: + raise TypeError("DataParserFactory:get_parser(data_path):\nCannot determine type of data for:" + data_path) + + +def __test_for_format__(data_path): + if isdir(data_path): + if exists(data_path + "/fispikes1.dat"): + return DAT_FORMAT + + elif data_path.endswith(".nix"): + return NIX_FORMAT + else: + return UNKNOWN diff --git a/FiCurve.py b/FiCurve.py new file mode 100644 index 0000000..3903518 --- /dev/null +++ b/FiCurve.py @@ -0,0 +1,153 @@ + +from CellData import CellData +import numpy as np +from scipy.optimize import curve_fit +import matplotlib.pyplot as plt +from warnings import warn +import functions as fu + + +class FICurve: + + def __init__(self, cell_data: CellData, contrast: bool = True): + self.cell_data = cell_data + self.using_contrast = contrast + + if contrast: + self.stimulus_value = cell_data.get_fi_contrasts() + else: + self.stimulus_value = cell_data.get_fi_intensities() + + self.f_zeros = [] + self.f_infinities = [] + self.f_baselines = [] + + # f_max, f_min, k, x_zero + self.boltzmann_fit_vars = [] + # offset increase + self.f_infinity_fit = [] + + self.all_calculate_frequency_points() + self.fit_line() + self.fit_boltzmann() + + def all_calculate_frequency_points(self): + mean_frequencies = self.cell_data.get_mean_isi_frequencies() + if len(mean_frequencies) == 0: + warn("FICurve:all_calculate_frequency_points(): mean_frequencies is empty.\n" + "Was all_calculate_mean_isi_frequencies already called?") + + for freq in mean_frequencies: + self.f_zeros.append(self.__calculate_f_zero__(freq)) + self.f_baselines.append(self.__calculate_f_baseline__(freq)) + self.f_infinities.append(self.__calculate_f_infinity__(freq)) + + def fit_line(self): + popt, pcov = curve_fit(fu.clipped_line, self.stimulus_value, self.f_infinities) + self.f_infinity_fit = popt + + def fit_boltzmann(self): + max_f0 = float(max(self.f_zeros)) + min_f0 = float(min(self.f_zeros)) + mean_int = float(np.mean(self.stimulus_value)) + + total_increase = max_f0 - min_f0 + total_change_int = max(self.stimulus_value) - min(self.stimulus_value) + start_k = float((total_increase / total_change_int * 4) / max_f0) + + popt, pcov = curve_fit(fu.full_boltzmann, self.stimulus_value, self.f_zeros, + p0=(max_f0, min_f0, start_k, mean_int), + maxfev=10000, bounds=([0, 0, -np.inf, -np.inf], [3000, 3000, np.inf, np.inf])) + + self.boltzmann_fit_vars = popt + + def plot_fi_curve(self, savepath: str = None): + min_x = min(self.stimulus_value) + max_x = max(self.stimulus_value) + step = (max_x - min_x) / 5000 + x_values = np.arange(min_x, max_x, step) + + plt.plot(self.stimulus_value, self.f_baselines, color='blue', label='f_base') + + plt.plot(self.stimulus_value, self.f_infinities, 'o', color='lime', label='f_inf') + plt.plot(x_values, [fu.clipped_line(x, self.f_infinity_fit[0], self.f_infinity_fit[1]) for x in x_values], + color='darkgreen', label='f_inf_fit') + + plt.plot(self.stimulus_value, self.f_zeros, 'o', color='orange', label='f_zero') + popt = self.boltzmann_fit_vars + plt.plot(x_values, [fu.full_boltzmann(x, popt[0], popt[1], popt[2], popt[3]) for x in x_values], + color='red', label='f_0_fit') + + plt.legend() + plt.ylabel("Frequency [Hz]") + if self.using_contrast: + plt.xlabel("Stimulus contrast") + else: + plt.xlabel("Stimulus intensity [mv]") + if savepath is None: + plt.show() + else: + plt.savefig(savepath + "fi_curve.png") + plt.close() + + def __calculate_f_baseline__(self, frequency, buffer=0.025): + delay = self.cell_data.get_delay() + sampling_interval = self.cell_data.get_sampling_interval() + if delay < 0.1: + warn("FICurve:__calculate_f_baseline__(): Quite short delay at the start.") + + idx_start = int(buffer/sampling_interval) + idx_end = int((delay-buffer)/sampling_interval) + return np.mean(frequency[idx_start:idx_end]) + + def __calculate_f_zero__(self, frequency, length_of_mean=0.1, buffer=0.025): + stimulus_start = self.cell_data.get_delay() + self.cell_data.get_stimulus_start() + sampling_interval = self.cell_data.get_sampling_interval() + + start_idx = int((stimulus_start - buffer) / sampling_interval) + end_idx = int((stimulus_start + buffer*2) / sampling_interval) + + freq_before = frequency[start_idx-(int(length_of_mean/sampling_interval)):start_idx] + fb_mean = np.mean(freq_before) + fb_std = np.std(freq_before) + + peak_frequency = fb_mean + count = 0 + for i in range(start_idx + 1, end_idx): + if fb_mean-3*fb_std <= frequency[i] <= fb_mean+3*fb_std: + continue + + if abs(frequency[i] - fb_mean) > abs(peak_frequency - fb_mean): + peak_frequency = frequency[i] + count += 1 + + return peak_frequency + + def __calculate_f_infinity__(self, frequency, length=0.2, buffer=0.025): + stimulus_end_time = \ + self.cell_data.get_delay() + self.cell_data.get_stimulus_start() + self.cell_data.get_stimulus_duration() + + start_idx = int((stimulus_end_time - length - buffer) / self.cell_data.get_sampling_interval()) + end_idx = int((stimulus_end_time - buffer) / self.cell_data.get_sampling_interval()) + + return np.mean(frequency[start_idx:end_idx]) + + def get_f_zero_inverse_at_frequency(self, frequency): + b_vars = self.boltzmann_fit_vars + return fu.inverse_full_boltzmann(frequency, b_vars[0], b_vars[1], b_vars[2], b_vars[3]) + + def get_f_infinity_frequency_at_stimulus_value(self, stimulus_value): + infty_vars = self.f_infinity_fit + return fu.clipped_line(stimulus_value, infty_vars[0], infty_vars[1]) + + + def get_f_infinity_slope(self): + return self.f_infinity_fit[1] + + def get_fi_curve_slope_at(self, stimulus_value): + fit_vars = self.boltzmann_fit_vars + return fu.derivative_full_boltzmann(stimulus_value, fit_vars[0], fit_vars[1], fit_vars[2], fit_vars[3]) + + def get_fi_curve_slope_of_straight(self): + fit_vars = self.boltzmann_fit_vars + return fu.full_boltzmann_straight_slope(fit_vars[0], fit_vars[1], fit_vars[2], fit_vars[3]) diff --git a/functionalityTests.py b/functionalityTests.py new file mode 100644 index 0000000..25b14fd --- /dev/null +++ b/functionalityTests.py @@ -0,0 +1,44 @@ + +from models.LIFAC import LIFACModel +from stimuli.StepStimulus import StepStimulus +import numpy as np +import matplotlib.pyplot as plt +import functions as fu + + +def test_lifac(): + model = LIFACModel() + stimulus = StepStimulus(0.5, 1, 15) + + for step_size in [0.001, 0.1]: + model.set_variable("step_size", step_size) + + v, spiketimes = model(stimulus, 2) + + plt.plot(np.arange(0, 2, step_size/1000), v, label="step_size:" + str(step_size)) + + plt.xlabel("time in seconds") + plt.ylabel("Voltage") + plt.title("Voltage in the LIFAC-Model with different step sizes") + plt.show() + plt.close() + + +def test_plot_inverses(ficurve): + var = ficurve.boltzmann_fit_vars + + fig, ax1 = plt.subplots(1, 1, figsize=(4.5, 4.5), tight_layout=True) + start = min(ficurve.stimulus_value) + end = max(ficurve.stimulus_value) + x_values = np.arange(start, end, (end-start)/5000) + ax1.plot(x_values, [fu.full_boltzmann(x, var[0], var[1], var[2], var[3]) for x in x_values], label="fit") + ax1.set_ylabel('freq') + ax1.set_xlabel('stimulus') + + start = var[1] + end = var[0] + x_values = np.arange(start, end, (end - start) / 50) + ax1.plot([fu.inverse_full_boltzmann(x, var[0], var[1], var[2], var[3]) for x in x_values], x_values, + '.', c="red", label='inverse') + plt.legend() + plt.show() \ No newline at end of file diff --git a/functions.py b/functions.py new file mode 100644 index 0000000..624ff2f --- /dev/null +++ b/functions.py @@ -0,0 +1,40 @@ + +import numpy as np + + +def exponential_function(x, a, b, c): + return (a-c)*np.exp(-x/b)+c + + +def upper_boltzmann(x, f_max, k, x_zero): + return f_max * np.clip((2 / (1+np.power(np.e, -k*(x - x_zero)))) - 1, 0, None) + + +def full_boltzmann(x, f_max, f_min, k, x_zero): + return (f_max-f_min) * (1 / (1 + np.power(np.e, -k * (x - x_zero)))) + f_min + + +def full_boltzmann_straight_slope(f_max, f_min, k, x_zero=0): + return (f_max-f_min)*k*1/2 + + +def derivative_full_boltzmann(x, f_max, f_min, k, x_zero): + return (f_max - f_min) * k * np.power(np.e, -k * (x - x_zero)) / (1 + np.power(np.e, -k * (x - x_zero))**2) + + +def inverse_full_boltzmann(x, f_max, f_min, k, x_zero): + if x < f_min or x > f_max: + raise ValueError("Value undefined in inverse_full_boltzmann") + + return -(np.log((f_max-f_min) / (x - f_min) - 1) / k) + x_zero + + +def clipped_line(x, a, b): + return np.clip(a+b*x, 0, None) + + +def inverse_clipped_line(x, a, b): + if clipped_line(x, a, b) == 0: + raise ValueError("Value undefined in inverse_clipped_line.") + + return (x-a)/b diff --git a/generalTests.py b/generalTests.py new file mode 100644 index 0000000..50900fd --- /dev/null +++ b/generalTests.py @@ -0,0 +1,54 @@ +import numpy as np +import matplotlib.pyplot as plt +from models.LeakyIntegrateFireModel import LIFModel + + +# def calculate_step(current_v, tau, i_b, step_size=0.01): + # return current_v + (step_size * (-current_v + mem_res * i_b)) / tau + +def function_e(x): + return (0-15) * np.e**(-x/1) + 15 + +# x_values = np.arange(0, 5, 0.01) +# plt.plot(x_values, [function_e(x) for x in x_values]) +# plt.show() + + +# def function_f(i_base, tau=1, threshold=10, reset=0): +# return -1/(tau*np.log((threshold-i_base)/(reset - i_base))) +# +# x_values = np.arange(0, 20, 0.001) +# plt.plot(x_values, [function_f(x) for x in x_values]) +# plt.show() + + +# LIF test: +# Rm = 100 MOhm, Cm = 200pF +step_size = 0.01 # ms +mem_res = 100*1000000 +tau = 1 +base_freq = 30 +v_threshold = 10 +base_input = -(- v_threshold / (np.e**(-1/(base_freq*tau))) + 1) / mem_res + +stim1 = int(1000/step_size) * [base_input] + +stimulus = [] +stimulus.extend(stim1) + +lif = LIFModel(mem_res, tau, 0, 0, stimulus, 10) + +voltage, spikes_b = lif.calculate_response() +y_spikes = [] +x_spikes = [] +for i in range(len(spikes_b)): + if spikes_b[i]: + y_spikes.append(10.5) + x_spikes.append(i*step_size) + +time = np.arange(0, 1000, step_size) +plt.plot(time, voltage) +plt.plot(x_spikes, y_spikes, 'o') +plt.show() +plt.close() + diff --git a/helperFunctions.py b/helperFunctions.py new file mode 100644 index 0000000..e17c063 --- /dev/null +++ b/helperFunctions.py @@ -0,0 +1,214 @@ +import os +import pyrelacs.DataLoader as dl +import numpy as np +import matplotlib.pyplot as plt +from warnings import warn + +def get_subfolder_paths(basepath): + subfolders = [] + for content in os.listdir(basepath): + content_path = basepath + content + if os.path.isdir(content_path): + subfolders.append(content_path) + + return sorted(subfolders) + + +def get_traces(directory, trace_type, repro): + # trace_type = 1: Voltage p-unit + # trace_type = 2: EOD + # trace_type = 3: local EOD ~(EOD + stimulus) + # trace_type = 4: Stimulus + + load_iter = dl.iload_traces(directory, repro=repro) + + time_traces = [] + value_traces = [] + + nothing = True + + for info, key, time, x in load_iter: + nothing = False + time_traces.append(time) + value_traces.append(x[trace_type-1]) + + if nothing: + print("iload_traces found nothing for the BaselineActivity repro!") + + return time_traces, value_traces + + +def get_all_traces(directory, repro): + load_iter = dl.iload_traces(directory, repro=repro) + + time_traces = [] + v1_traces = [] + eod_traces = [] + local_eod_traces = [] + stimulus_traces = [] + + nothing = True + + for info, key, time, x in load_iter: + nothing = False + time_traces.append(time) + v1_traces.append(x[0]) + eod_traces.append(x[1]) + local_eod_traces.append(x[2]) + stimulus_traces.append(x[3]) + print(info) + + traces = [v1_traces, eod_traces, local_eod_traces, stimulus_traces] + + if nothing: + print("iload_traces found nothing for the BaselineActivity repro!") + + return time_traces, traces + + +def merge_similar_intensities(intensities, spiketimes, trans_amplitudes): + i = 0 + + diffs = np.diff(sorted(intensities)) + margin = np.mean(diffs) * 0.6666 + + while True: + if i >= len(intensities): + break + intensities, spiketimes, trans_amplitudes = merge_intensities_similar_to_index(intensities, spiketimes, trans_amplitudes, i, margin) + i += 1 + + # Sort the lists so that intensities are increasing + x = [list(x) for x in zip(*sorted(zip(intensities, spiketimes), key=lambda pair: pair[0]))] + intensities = x[0] + spiketimes = x[1] + + return intensities, spiketimes, trans_amplitudes + + +def merge_intensities_similar_to_index(intensities, spiketimes, trans_amplitudes, index, margin): + intensity = intensities[index] + + indices_to_merge = [] + for i in range(index+1, len(intensities)): + if np.abs(intensities[i]-intensity) < margin: + indices_to_merge.append(i) + + if len(indices_to_merge) != 0: + indices_to_merge.reverse() + + trans_amplitude_values = [trans_amplitudes[k] for k in indices_to_merge] + + all_the_same = True + for j in range(1, len(trans_amplitude_values)): + if not trans_amplitude_values[0] == trans_amplitude_values[j]: + all_the_same = False + break + + if all_the_same: + for idx in indices_to_merge: + del trans_amplitudes[idx] + else: + raise RuntimeError("Trans_amplitudes not the same....") + for idx in indices_to_merge: + spiketimes[index].extend(spiketimes[idx]) + del spiketimes[idx] + del intensities[idx] + + return intensities, spiketimes, trans_amplitudes + + +def all_calculate_mean_isi_frequencies(spiketimes, time_start, sampling_interval): + times = [] + mean_frequencies = [] + + for i in range(len(spiketimes)): + trial_times = [] + trial_means = [] + for j in range(len(spiketimes[i])): + time, isi_freq = calculate_isi_frequency(spiketimes[i][j], time_start, sampling_interval) + trial_means.append(isi_freq) + trial_times.append(time) + + time, mean_freq = calculate_mean_frequency(trial_times, trial_means) + times.append(time) + mean_frequencies.append(mean_freq) + + return times, mean_frequencies + + +def calculate_isi_frequency(spiketimes, time_start, sampling_interval): + first_isi = spiketimes[0] - time_start + isis = [first_isi] + isis.extend(np.diff(spiketimes)) + time = np.arange(time_start, spiketimes[-1], sampling_interval) + + full_frequency = [] + i = 0 + for isi in isis: + if isi == 0: + warn("An ISI was zero in FiCurve:__calculate_mean_isi_frequency__()") + continue + freq = 1 / isi + frequency_step = int(round(isi * (1 / sampling_interval))) * [freq] + full_frequency.extend(frequency_step) + i += 1 + if len(full_frequency) != len(time): + if abs(len(full_frequency) - len(time)) == 1: + warn("FiCurve:__calculate_mean_isi_frequency__():\nFrequency and time were one of in length!") + if len(full_frequency) < len(time): + time = time[:len(full_frequency)] + else: + full_frequency = full_frequency[:len(time)] + else: + print("ERROR PRINT:") + print("freq:", len(full_frequency), "time:", len(time), "diff:", len(full_frequency) - len(time)) + raise RuntimeError("FiCurve:__calculate_mean_isi_frequency__():\n" + "Frequency and time are not the same length!") + + return time, full_frequency + + +def calculate_mean_frequency(trial_times, trial_freqs): + lengths = [len(t) for t in trial_times] + shortest = min(lengths) + + time = trial_times[0][0:shortest] + shortend_freqs = [freq[0:shortest] for freq in trial_freqs] + mean_freq = [sum(e) / len(e) for e in zip(*shortend_freqs)] + + return time, mean_freq + + +def crappy_smoothing(signal:list, window_size:int = 5) -> list: + smoothed = [] + + for i in range(len(signal)): + k = window_size + if i < window_size: + k = i + j = window_size + if i + j > len(signal): + j = len(signal) - i + + smoothed.append(np.mean(signal[i-k:i+j])) + + return smoothed + + +def plot_frequency_curve(cell_data, save_path: str = None, indices: list = None): + contrast = cell_data.get_fi_contrasts() + time_axes = cell_data.get_time_axes_mean_frequencies() + mean_freqs = cell_data.get_mean_isi_frequencies() + + if indices is None: + indices = np.arange(len(contrast)) + + for i in indices: + plt.plot(time_axes[i], mean_freqs[i], label=str(round(contrast[i], 2))) + + if save_path is None: + plt.show() + else: + plt.savefig(save_path + "mean_frequency_curves.png") + plt.close() \ No newline at end of file diff --git a/introduction/introductionBaseline.py b/introduction/introductionBaseline.py new file mode 100644 index 0000000..b28a485 --- /dev/null +++ b/introduction/introductionBaseline.py @@ -0,0 +1,283 @@ + +import pyrelacs.DataLoader as dl +import numpy as np +import matplotlib.pyplot as plt +from IPython import embed +import os +import helperFunctions as hf +from thunderfish.eventdetection import detect_peaks + + +SAVEPATH = "" + + +def get_savepath(): + global SAVEPATH + return SAVEPATH + + +def set_savepath(new_path): + global SAVEPATH + SAVEPATH = new_path + + +def main(): + for folder in hf.get_subfolder_paths("data/"): + filepath = folder + "/basespikes1.dat" + set_savepath("figures/" + folder.split('/')[1] + "/") + + print("Folder:", folder) + + if not os.path.exists(get_savepath()): + os.makedirs(get_savepath()) + + spiketimes = [] + + ran = False + for metadata, key, data in dl.iload(filepath): + ran = True + spikes = data[:, 0] + spiketimes.append(spikes) # save for calculation of vector strength + metadata = metadata[0] + #print(metadata) + # print('firing frequency1:', metadata['firing frequency1']) + # print(mean_firing_rate(spikes)) + + # print('Coefficient of Variation (CV):', metadata['CV1']) + # print(calculate_coefficient_of_variation(spikes)) + + if not ran: + print("------------ DIDN'T RUN") + + isi_histogram(spiketimes) + + times, eods = hf.get_traces(folder, 2, 'BaselineActivity') + times, v1s = hf.get_traces(folder, 1, 'BaselineActivity') + + vs = calculate_vector_strength(times, eods, spiketimes, v1s) + + # print("Calculated vector strength:", vs) + + +def mean_firing_rate(spiketimes): + # mean firing rate (number of spikes per time) + return len(spiketimes)/spiketimes[-1]*1000 + + +def calculate_coefficient_of_variation(spiketimes): + # CV (stddev of ISI divided by mean ISI (np.diff(spiketimes)) + isi = np.diff(spiketimes) + std = np.std(isi) + mean = np.mean(isi) + + return std/mean + + +def isi_histogram(spiketimes): + # ISI histogram (play around with binsize! < 1ms) + + isi = [] + for spike_list in spiketimes: + isi.extend(np.diff(spike_list)) + maximum = max(isi) + bins = np.arange(0, maximum*1.01, 0.1) + + plt.title('Phase locking of ISI without stimulus') + plt.xlabel('ISI in ms') + plt.ylabel('Count') + plt.hist(isi, bins=bins) + plt.savefig(get_savepath() + 'phase_locking_without_stimulus.png') + plt.close() + + +def calculate_vector_strength(times, eods, spiketimes, v1s): + # Vectorstaerke (use EOD frequency from header (metadata)) VS > 0.8 + # dl.iload_traces(repro='BaselineActivity') + + relative_spike_times = [] + eod_durations = [] + + if len(times) == 0: + print("-----LENGTH OF TIMES = 0") + + for recording in range(len(times)): + + rel_spikes, eod_durs = eods_around_spikes(times[recording], eods[recording], spiketimes[recording]) + relative_spike_times.extend(rel_spikes) + eod_durations.extend(eod_durs) + + vs = __vector_strength__(rel_spikes, eod_durs) + phases = calculate_phases(rel_spikes, eod_durs) + plot_polar(phases, "test_phase_locking_" + str(recording) + "_with_vs:" + str(round(vs, 3)) + ".png") + + print("VS of recording", recording, ":", vs) + + plot_phaselocking_testfigures(times[recording], eods[recording], spiketimes[recording], v1s[recording]) + + return __vector_strength__(relative_spike_times, eod_durations) + + +def eods_around_spikes(time, eod, spiketimes): + eod_durations = [] + relative_spike_times = [] + + for spike in spiketimes: + index = spike * 20 # time in s given timestamp of spike in ms - recorded at 20kHz -> timestamp/1000*20000 = idx + + if index != np.round(index): + print("INDEX NOT AN INTEGER in eods_around_spikes! index:", index) + continue + index = int(index) + + start_time, end_time = search_eod_start_and_end_times(time, eod, index) + + eod_durations.append(end_time-start_time) + relative_spike_times.append(spike/1000 - start_time) + + return relative_spike_times, eod_durations + + +def search_eod_start_and_end_times(time, eod, index): + # TODO might break if a spike is in the cut off first or last eod! + + # search start_time: + previous = index + working_idx = index-1 + while True: + if eod[working_idx] < 0 < eod[previous]: + first_value = eod[working_idx] + second_value = eod[previous] + + dif = second_value - first_value + part = np.abs(first_value/dif) + + time_dif = np.abs(time[previous] - time[working_idx]) + start_time = time[working_idx] + time_dif*part + + break + + previous = working_idx + working_idx -= 1 + + # search end_time + previous = index + working_idx = index + 1 + while True: + if eod[previous] < 0 < eod[working_idx]: + first_value = eod[previous] + second_value = eod[working_idx] + + dif = second_value - first_value + part = np.abs(first_value / dif) + + time_dif = np.abs(time[previous] - time[working_idx]) + end_time = time[working_idx] + time_dif * part + + break + + previous = working_idx + working_idx += 1 + + return start_time, end_time + + +def search_closest_index(array, value, start=0, end=-1): + # searches the array to find the closest value in the array to the given value and returns its index. + # expects sorted array! + # start hast to be smaller than end + + if end == -1: + end = len(array)-1 + + while True: + if end-start <= 1: + return end if np.abs(array[end]-value) < np.abs(array[start]-value) else start + + middle = int(np.floor((end-start)/2)+start) + if array[middle] == value: + return middle + elif array[middle] > value: + end = middle + continue + else: + start = middle + continue + + +def __vector_strength__(relative_spike_times, eod_durations): + # adapted from Ramona + + n = len(relative_spike_times) + if n == 0: + return 0 + + phase_times = np.zeros(n) + + for i in range(n): + phase_times[i] = (relative_spike_times[i] / eod_durations[i]) * 2 * np.pi + vs = np.sqrt((1 / n * sum(np.cos(phase_times))) ** 2 + (1 / n * sum(np.sin(phase_times))) ** 2) + + return vs + + +def calculate_phases(relative_spike_times, eod_durations): + phase_times = np.zeros(len(relative_spike_times)) + + for i in range(len(relative_spike_times)): + phase_times[i] = (relative_spike_times[i] / eod_durations[i]) * 2 * np.pi + + return phase_times + + +def plot_polar(phases, name=""): + fig = plt.figure() + ax = fig.add_subplot(111, polar=True) + # r = np.arange(0, 1, 0.001) + # theta = 2 * 2 * np.pi * r + # line, = ax.plot(theta, r, color='#ee8d18', lw=3) + bins = np.arange(0, np.pi*2, 0.05) + ax.hist(phases, bins=bins) + if name == "": + plt.show() + else: + plt.savefig(get_savepath() + name) + plt.close() + + +def plot_phaselocking_testfigures(time, eod, spiketimes, v1): + eod_start_times = [] + eod_end_times = [] + + for spike in spiketimes: + index = spike * 20 # time in s given timestamp of spike in ms - recorded at 20kHz -> timestamp/1000*20000 = idx + + if index != np.round(index): + print("INDEX NOT AN INTEGER in eods_around_spikes! index:", index) + continue + index = int(index) + + start_time, end_time = search_eod_start_and_end_times(time, eod, index) + + eod_start_times.append(start_time) + eod_end_times.append(end_time) + + cutoff_in_sec = 2 + sampling = 20000 + max_idx = cutoff_in_sec*sampling + spikes_part = [x/1000 for x in spiketimes if x/1000 < cutoff_in_sec] + count_spikes = len(spikes_part) + print(spiketimes) + print(len(spikes_part)) + + x_axis = time[0:max_idx] + plt.plot(spikes_part, np.ones(len(spikes_part))*-20, 'o') + plt.plot(x_axis, v1[0:max_idx]) + plt.plot(eod_start_times[: count_spikes], np.zeros(count_spikes), 'o') + plt.plot(eod_end_times[: count_spikes], np.zeros(count_spikes), 'o') + + plt.show() + plt.close() + + +if __name__ == '__main__': + main() diff --git a/introduction/introductionFICurve.py b/introduction/introductionFICurve.py new file mode 100644 index 0000000..1042f8f --- /dev/null +++ b/introduction/introductionFICurve.py @@ -0,0 +1,298 @@ +import numpy as np +import matplotlib.pyplot as plt +import pyrelacs.DataLoader as dl +import os +import helperFunctions as hf +from IPython import embed +from scipy.optimize import curve_fit +import warnings + +SAMPLING_INTERVAL = 1/20000 +STIMULUS_START = 0 +STIMULUS_DURATION = 0.400 +PRE_DURATION = 0.250 +TOTAL_DURATION = 1.25 + + +def main(): + for folder in hf.get_subfolder_paths("data/"): + filepath = folder + "/fispikes1.dat" + set_savepath("figures/" + folder.split('/')[1] + "/") + print("Folder:", folder) + + if not os.path.exists(get_savepath()): + os.makedirs(get_savepath()) + + spiketimes = [] + intensities = [] + index = -1 + for metadata, key, data in dl.iload(filepath): + # embed() + if len(metadata) != 0: + + metadata_index = 0 + if '----- Control --------------------------------------------------------' in metadata[0].keys(): + metadata_index = 1 + + print(metadata) + i = float(metadata[metadata_index]['intensity'][:-2]) + intensities.append(i) + spiketimes.append([]) + index += 1 + + spiketimes[index].append(data[:, 0]/1000) + + intensities, spiketimes = hf.merge_similar_intensities(intensities, spiketimes) + + # Sort the lists so that intensities are increasing + x = [list(x) for x in zip(*sorted(zip(intensities, spiketimes), key=lambda pair: pair[0]))] + intensities = x[0] + spiketimes = x[1] + + mean_frequencies = calculate_mean_frequencies(intensities, spiketimes) + popt, pcov = fit_exponential(intensities, mean_frequencies) + plot_frequency_curve(intensities, mean_frequencies) + + f_baseline = calculate_f_baseline(mean_frequencies) + f_infinity = calculate_f_infinity(mean_frequencies) + f_zero = calculate_f_zero(mean_frequencies) + + # plot_fi_curve(intensities, f_baseline, f_zero, f_infinity) + + +# TODO !! +def fit_exponential(intensities, mean_frequencies): + start_idx = int((PRE_DURATION + STIMULUS_START+0.005) / SAMPLING_INTERVAL) + end_idx = int((PRE_DURATION + STIMULUS_START + 0.1) / SAMPLING_INTERVAL) + time_constants = [] + #print(start_idx, end_idx) + + popts = [] + pcovs = [] + for i in range(len(mean_frequencies)): + freq = mean_frequencies[i] + y_values = freq[start_idx:end_idx+1] + x_values = np.arange(start_idx*SAMPLING_INTERVAL, end_idx*SAMPLING_INTERVAL, SAMPLING_INTERVAL) + try: + popt, pcov = curve_fit(exponential_function, x_values, y_values, p0=(1/(np.power(1, 10)), .5, 50, 180), maxfev=10000) + except RuntimeError: + print("RuntimeError happened in fit_exponential.") + continue + #print(popt) + #print(pcov) + #print() + + popts.append(popt) + pcovs.append(pcov) + + plt.plot(np.arange(-PRE_DURATION, TOTAL_DURATION, SAMPLING_INTERVAL), freq) + plt.plot(x_values-PRE_DURATION, [exponential_function(x, popt[0], popt[1], popt[2], popt[3]) for x in x_values]) + # plt.show() + save_path = get_savepath() + "exponential_fits/" + if not os.path.exists(save_path): + os.makedirs(save_path) + plt.savefig(save_path + "fit_intensity:" + str(round(intensities[i], 4)) + ".png") + plt.close() + + return popts, pcovs + + +def calculate_mean_frequency(freqs): + mean_freq = [sum(e) / len(e) for e in zip(*freqs)] + + return mean_freq + + +def gaussian_kernel(sigma, dt): + x = np.arange(-4. * sigma, 4. * sigma, dt) + y = np.exp(-0.5 * (x / sigma) ** 2) / np.sqrt(2. * np.pi) / sigma + return y + + +def calculate_kernel_frequency(spiketimes, time, sampling_interval): + sp = spiketimes + t = time # Probably goes from -200ms to some amount of ms in the positive ~1200? + dt = sampling_interval + kernel_width = 0.01 # kernel width is a time in seconds how sharp the frequency should be counted + + binary = np.zeros(t.shape) + spike_indices = ((sp - t[0]) / dt).astype(int) + binary[spike_indices[(spike_indices >= 0) & (spike_indices < len(binary))]] = 1 + g = gaussian_kernel(kernel_width, dt) + + rate = np.convolve(binary, g, mode='same') + + return rate + + +def calculate_isi_frequency(spiketimes, time): + first_isi = spiketimes[0] - (-PRE_DURATION) # diff to the start at 0 + last_isi = TOTAL_DURATION - spiketimes[-1] # diff from the last spike to the end of time :D + isis = [first_isi] + isis.extend(np.diff(spiketimes)) + isis.append(last_isi) + + if np.isnan(first_isi): + print(spiketimes[:10]) + print(isis[0:10]) + quit() + + rate = [] + for isi in isis: + if isi == 0: + print("probably a problem") + isi = 0.0000000001 + freq = 1/isi + frequency_step = int(round(isi*(1/SAMPLING_INTERVAL)))*[freq] + rate.extend(frequency_step) + + + #plt.plot((np.arange(len(rate))-PRE_DURATION)/(1/SAMPLING_INTERVAL), rate) + #plt.plot([sum(isis[:i+1]) for i in range(len(isis))], [200 for i in isis], 'o') + #plt.plot(time, [100 for t in time]) + #plt.show() + + if len(rate) != len(time): + if "12-13-af" in get_savepath(): + warnings.warn("preStimulus duration > 0 still not supported") + return [1]*len(time) + else: + print(len(rate), len(time), len(rate) - len(time)) + print(rate) + print(isis) + print("Quitting because time and rate aren't the same length") + quit() + + return rate + + +def calculate_mean_frequencies(intensities, spiketimes): + time = np.arange(-PRE_DURATION, TOTAL_DURATION, SAMPLING_INTERVAL) + + mean_frequencies = [] + for i in range(len(intensities)): + freqs = [] + for spikes in spiketimes[i]: + if len(spikes) < 2: + continue + freq = calculate_isi_frequency(spikes, time) + freqs.append(freq) + + mf = calculate_mean_frequency(freqs) + mean_frequencies.append(mf) + + return mean_frequencies + + +def calculate_f_baseline(mean_frequencies): + buffer_time = 0.05 + start_idx = int(0.05/SAMPLING_INTERVAL) + end_idx = int((PRE_DURATION - STIMULUS_START - buffer_time)/SAMPLING_INTERVAL) + + f_zeros = [] + for freq in mean_frequencies: + f_0 = np.mean(freq[start_idx:end_idx]) + f_zeros.append(f_0) + + return f_zeros + + +def calculate_f_infinity(mean_frequencies): + buffer_time = 0.05 + start_idx = int((PRE_DURATION + STIMULUS_START + STIMULUS_DURATION - 0.15 - buffer_time) / SAMPLING_INTERVAL) + end_idx = int((PRE_DURATION + STIMULUS_START + STIMULUS_DURATION - buffer_time) / SAMPLING_INTERVAL) + + f_infinity = [] + for freq in mean_frequencies: + f_inf = np.mean(freq[start_idx:end_idx]) + f_infinity.append(f_inf) + + return f_infinity + + +def calculate_f_zero(mean_frequencies): + buffer_time = 0.1 + start_idx = int((PRE_DURATION + STIMULUS_START - buffer_time) / SAMPLING_INTERVAL) + end_idx = int((PRE_DURATION + STIMULUS_START + buffer_time) / SAMPLING_INTERVAL) + f_peaks = [] + for freq in mean_frequencies: + fp = np.mean(freq[start_idx-500:start_idx]) + for i in range(start_idx+1, end_idx): + if abs(freq[i] - freq[start_idx]) > abs(fp - freq[start_idx]): + fp = freq[i] + f_peaks.append(fp) + return f_peaks + + +def plot_fi_curve(intensities, f_baseline, f_zero, f_infinity): + plt.plot(intensities, f_baseline, label="f_baseline") + plt.plot(intensities, f_zero, 'o', label="f_zero") + plt.plot(intensities, f_infinity, label="f_infinity") + + max_f0 = float(max(f_zero)) + mean_int = float(np.mean(intensities)) + start_k = float(((f_zero[-1] - f_zero[0]) / (intensities[-1] - intensities[0])*4)/f_zero[-1]) + + popt, pcov = curve_fit(fill_boltzmann, intensities, f_zero, p0=(max_f0, start_k, mean_int), maxfev=10000) + print(popt) + min_x = min(intensities) + max_x = max(intensities) + step = (max_x - min_x) / 5000 + x_values_boltzmann_fit = np.arange(min_x, max_x, step) + plt.plot(x_values_boltzmann_fit, [fill_boltzmann(i, popt[0], popt[1], popt[2]) for i in x_values_boltzmann_fit], label='fit') + + plt.title("FI-Curve") + plt.ylabel("Frequency in Hz") + plt.xlabel("Intensity in mV") + plt.legend() + # plt.show() + plt.savefig(get_savepath() + "fi_curve.png") + plt.close() + + +def plot_frequency_curve(intensities, mean_frequencies): + colors = ["red", "green", "blue", "violet", "orange", "grey"] + + time = np.arange(-PRE_DURATION, TOTAL_DURATION, SAMPLING_INTERVAL) + + for i in range(len(intensities)): + plt.plot(time, mean_frequencies[i], color=colors[i % 6], label=str(intensities[i])) + + plt.plot((0, 0), (0, 500), color="black") + plt.plot((0.4, 0.4), (0, 500), color="black") + plt.legend() + plt.xlabel("Time in seconds") + plt.ylabel("Frequency in Hz") + plt.title("Frequency curve") + + plt.savefig(get_savepath() + "mean_frequency_curves.png") + plt.close() + + +def exponential_function(x, a, b, c, d): + return a*np.exp(-c*(x-b))+d + + +def upper_boltzmann(x, f_max, k, x_zero): + return f_max * np.clip((2 / (1+np.power(np.e, -k*(x - x_zero)))) - 1, 0, None) + + +def fill_boltzmann(x, f_max, k, x_zero): + return f_max * (1 / (1 + np.power(np.e, -k * (x - x_zero)))) + + +SAVEPATH = "" + + +def get_savepath(): + global SAVEPATH + return SAVEPATH + + +def set_savepath(new_path): + global SAVEPATH + SAVEPATH = new_path + + +if __name__ == '__main__': + main() diff --git a/introduction/janExample.py b/introduction/janExample.py new file mode 100644 index 0000000..272aaa6 --- /dev/null +++ b/introduction/janExample.py @@ -0,0 +1,20 @@ +import pyrelacs.DataLoader as dl + +for metadata, key, data in dl.iload('2012-06-27-ah-invivo-1/basespikes1.dat'): + print(data.shape) + break + +# mean firing rate (number of spikes per time) +# CV (stdev of ISI divided by mean ISI (np.diff(spiketimes)) +# ISI histogram (play around with binsize! < 1ms) +# Vectorstaerke (use EOD frequency from header (metadata)) VS > 0.8 +# dl.iload_traces(repro='BaselineActivity') + +def test(): + for metadata, key, data in dl.iload('data/2012-06-27-ah-invivo-1/basespikes1.dat'): + print(data.shape) + for i in metadata: + for key in i.keys(): + print(key, ":", i[key]) + + break \ No newline at end of file diff --git a/main.py b/main.py new file mode 100644 index 0000000..244ae25 --- /dev/null +++ b/main.py @@ -0,0 +1,61 @@ + +from FiCurve import FICurve +from CellData import icelldata_of_dir +import os +import helperFunctions as hf +from AdaptionCurrent import Adaption +from models.NeuronModel import NeuronModel +from functionalityTests import * +# TODO command line interface needed/nice ? + + +def main(): + run_tests() + quit() + for cell_data in icelldata_of_dir("./data/"): + print() + print(cell_data.get_data_path()) + + model = NeuronModel(cell_data) + + x_values = np.arange(0, 1000, 0.01) + stimulus = [0]*int(200/0.01) + stimulus.extend([0.19]*int(400/0.01)) + stimulus.extend([0]*int(400/0.01)) + + v, spikes = model.simulate(0, 1000, stimulus) + + # plt.plot(x_values, v) + + spikes = [s/1000 for s in spikes] + time, freq = hf.calculate_isi_frequency(spikes, 0, 0.01/1000) + + plt.plot(time, freq) + plt.show() + + quit() + continue + + figures_save_path = "./figures/" + os.path.basename(cell_data.get_data_path()) + "/" + ficurve = FICurve(cell_data) + ficurve.plot_fi_curve(figures_save_path) + + adaption = Adaption(cell_data, ficurve) + adaption.plot_exponential_fits(figures_save_path + "exponential_fits/", delete_previous=True) + + for i in range(len(adaption.exponential_fit_vars)): + if len(adaption.exponential_fit_vars[i]) == 0: + continue + tau = round(adaption.exponential_fit_vars[i][1]*1000, 2) + contrast = round(ficurve.stimulus_value[i], 3) + # print(tau, "ms - tau_eff at", contrast, "contrast") + + # test_plot_inverses(ficurve) + print("Chosen tau [ms]:", adaption.tau_real) + + +def run_tests(): + test_lifac() + +if __name__ == '__main__': + main() diff --git a/models/LIFAC.py b/models/LIFAC.py new file mode 100644 index 0000000..61b1daf --- /dev/null +++ b/models/LIFAC.py @@ -0,0 +1,88 @@ + +from stimuli.AbstractStimulus import AbstractStimulus +import numpy as np + + +class LIFACModel: + # all times in milliseconds + KEYS = ["mem_res", "mem_tau", "v_base", "v_zero", "threshold", "step_size", "delta_a", "tau_a"] + VALUES = [100 * 1000000, 0.1 * 200, 0, 0, 10, 0.01, 1, 200] + + # membrane time constant tau = mem_cap*mem_res + def __init__(self, params: dict = None): + self.parameters = {} + if params is None: + self._set_default_parameters() + else: + self._test_given_parameters(params) + self.set_parameters(params) + + self.last_v = [] + self.last_adaption = [] + self.last_spiketimes = [] + + def __call__(self, stimulus: AbstractStimulus, total_time_s): + output_voltage = [] + adaption = [] + spiketimes = [] + current_v = self.parameters["v_zero"] + current_a = 0 + + for time_point in np.arange(0, total_time_s*1000, self.parameters["step_size"]): + v_next = self._calculate_voltage_step(current_v, stimulus.value_at_time_in_ms(time_point) - current_a) + a_next = self._calculate_adaption_step(current_a) + + if v_next > self.parameters["threshold"]: + v_next = self.parameters["v_base"] + spiketimes.append(time_point/1000) + a_next += self.parameters["delta_a"] + + output_voltage.append(v_next) + adaption.append(a_next) + + current_v = v_next + current_a = a_next + + self.last_v = output_voltage + self.last_adaption = adaption + self.last_spiketimes = spiketimes + + return output_voltage, spiketimes + + def _calculate_voltage_step(self, current_v, input_v): + v_base = self.parameters["v_base"] + step_size = self.parameters["step_size"] + # mem_res = self.parameters["mem_res"] + mem_tau = self.parameters["mem_tau"] + return current_v + (step_size * (v_base - current_v + input_v)) / mem_tau + + def _calculate_adaption_step(self, current_a): + step_size = self.parameters["step_size"] + return current_a + (step_size * (-current_a)) / self.parameters["tau_a"] + + def set_parameters(self, params): + for k in params.keys(): + self.parameters[k] = params[k] + + for i in range(len(self.KEYS)): + if self.KEYS[i] not in self.parameters.keys(): + self.parameters[self.KEYS[i]] = self.VALUES[i] + + def get_parameters(self): + return self.parameters + + def set_variable(self, key, value): + if key not in self.KEYS: + raise ValueError("Given key is unknown!\n" + "Please check spelling and refer to list LIFAC.KEYS.") + self.parameters[key] = value + + def _set_default_parameters(self): + for i in range(len(self.KEYS)): + self.parameters[self.KEYS[i]] = self.VALUES[i] + + def _test_given_parameters(self, params): + for k in params.keys(): + if k not in self.KEYS: + err_msg = "Unknown key in the given parameters:" + str(k) + raise ValueError(err_msg) \ No newline at end of file diff --git a/models/LeakyIntegrateFireModel.py b/models/LeakyIntegrateFireModel.py new file mode 100644 index 0000000..567f155 --- /dev/null +++ b/models/LeakyIntegrateFireModel.py @@ -0,0 +1,37 @@ + + +class LIFModel: + # all times in milliseconds + def __init__(self, mem_res, mem_tau, v_base, v_zero, input_current, threshold, input_offset=0, step_size=0.01): + self.mem_res = mem_res + # self.membrane_capacitance = mem_cap + self.mem_tau = mem_tau # membrane time constant tau = mem_cap*mem_res + self.v_base = v_base + self.v_zero = v_zero + self.threshold = threshold + + self.step_size = step_size + self.input_current = input_current + self.input_offset = input_offset + + def calculate_response(self): + output_voltage = [self.v_zero] + spikes = [] + + for idx in range(1, len(self.input_current)): + v_next = self.__calculate_next_step__(output_voltage[idx-1], self.input_current[idx-1]) + if v_next > self.threshold: + v_next = self.v_base + spikes.append(True) + else: + spikes.append(False) + output_voltage.append(v_next) + + return output_voltage, spikes + + def set_input_current(self, input_current, offset=0): + self.input_current = input_current + self.input_offset = offset + + def __calculate_next_step__(self, current_v, input_i): + return current_v + (self.step_size * (self.v_base - current_v + self.mem_res * input_i)) / self.mem_tau diff --git a/models/NeuronModel.py b/models/NeuronModel.py new file mode 100644 index 0000000..c912f1b --- /dev/null +++ b/models/NeuronModel.py @@ -0,0 +1,110 @@ + +from CellData import CellData +from FiCurve import FICurve +from AdaptionCurrent import Adaption +import numpy as np +import matplotlib.pyplot as plt + + +class NeuronModel: + KEYS = ["mem_res", "mem_tau", "v_base", "v_zero", "threshold", "step_size"] + VALUES = [100 * 1000000, 0.1 * 200, 0, 0, 10, 0.01] + + def __init__(self, cell_data: CellData, variables: dict = None): + self.cell_data = cell_data + self.fi_curve = FICurve(cell_data) + self.adaption = Adaption(cell_data, self.fi_curve) + + if variables is not None: + self._test_given_variables(variables) + self.variables = variables + else: + self.variables = {} + self._add_standard_variables() + + def __call__(self, stimulus): + raise NotImplementedError("Soon. sorry!") + + def _approximate_variables_from_data(self): + # TODO don't return but save in class in some form! approximate/calculate other variables? + base_input = self._calculate_input_fro_base_frequency() + return base_input + + def simulate(self, start_v, time_in_ms, stimulus): + response = [] + spikes = [] + current_v = start_v + current_a = 0 + base_input = self._calculate_input_fro_base_frequency() + + adaption_values = [] + a_infties = [] + print("base input:", base_input) + for time_step in np.arange(0, time_in_ms, self.variables["step_size"]): + stimulus_input = stimulus[int(time_step/self.variables["step_size"])] - current_a + + new_v = self._calculate_next_step(current_v, current_a*base_input, base_input + base_input*stimulus_input) + new_a, a_infty = self._calculate_adaption_step(current_a, stimulus_input) + + if new_v > self.variables["threshold"]: + new_v = self.variables["v_base"] + spikes.append(time_step) + response.append(new_v) + + adaption_values.append(current_a) + a_infties.append(a_infty) + current_v = new_v + current_a = new_a + + plt.title("Adaption variable") + plt.plot(np.arange(0, time_in_ms, self.variables["step_size"]), np.array(adaption_values), label="adaption") + plt.plot(np.arange(0, time_in_ms, self.variables["step_size"]), np.array(a_infties), label="a_inf") + plt.plot(np.arange(0, time_in_ms, self.variables["step_size"]), stimulus, label="stimulus") + plt.legend() + plt.xlabel("time in ms") + plt.ylabel("value as contrast?") + plt.show() + plt.close() + + return response, spikes + + def _calculate_next_step(self, current_v, current_a, input_v): + step_size = self.variables["step_size"] + v_base = self.variables["v_base"] + mem_tau = self.variables["mem_tau"] + + return current_v + (step_size * (- current_v + v_base + input_v - current_a)) / mem_tau + + def _calculate_adaption_step(self, current_a, stimulus_input): + step_size = self.variables["step_size"] + tau_a = self.adaption.tau_real + f_infty_freq = self.fi_curve.get_f_infinity_frequency_at_stimulus_value(stimulus_input) + a_infinity = stimulus_input - self.fi_curve.get_f_zero_inverse_at_frequency(f_infty_freq) + return current_a + (step_size * (- current_a + a_infinity)) / tau_a, a_infinity + + def set_variable(self, key, value): + if key not in self.KEYS: + raise ValueError("Given key is unknown!\n" + "Please check spelling and refer to list NeuronModel.KEYS.") + self.variables[key] = value + + def set_variables(self, variables: dict): + self._test_given_variables(variables) + + for k in variables.keys(): + self.variables[k] = variables[k] + + def _calculate_input_fro_base_frequency(self): + return - self.variables["threshold"] / ( + np.e ** (-1 / (self.cell_data.get_base_frequency()/1000 * self.variables["mem_tau"])) - 1) + + def _test_given_variables(self, variables: dict): + for k in variables.keys(): + if k not in self.KEYS: + raise ValueError("Unknown key in given model variables. \n" + "Please check spelling and refer to list NeuronModel.KEYS.") + + def _add_standard_variables(self): + for i in range(len(self.KEYS)): + if self.KEYS[i] not in self.variables: + self.variables[self.KEYS[i]] = self.VALUES[i] diff --git a/models/__init__.py b/models/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/stimuli/AbstractStimulus.py b/stimuli/AbstractStimulus.py new file mode 100644 index 0000000..0745fa5 --- /dev/null +++ b/stimuli/AbstractStimulus.py @@ -0,0 +1,8 @@ + +class AbstractStimulus: + + def value_at_time_in_ms(self, time_point): + raise NotImplementedError("This is an abstract class!") + + def value_at_time_in_s(self, time_point): + raise NotImplementedError("This is an abstract class!") diff --git a/stimuli/StepStimulus.py b/stimuli/StepStimulus.py new file mode 100644 index 0000000..261aefb --- /dev/null +++ b/stimuli/StepStimulus.py @@ -0,0 +1,26 @@ + +from stimuli.AbstractStimulus import AbstractStimulus + + +class StepStimulus(AbstractStimulus): + + def __init__(self, start, duration, value, base_value=0, seconds=True): + self.start = 0 + self.duration = 0 + self.base_value = base_value + self.value = value + if seconds: + self.start = start + self.duration = duration + else: + self.start = start / 1000 + self.duration = duration / 1000 + + def value_at_time_in_ms(self, time_point): + return self.value_at_time_in_s(time_point/1000) + + def value_at_time_in_s(self, time_point): + if self.start <= time_point <= self.start + self.duration: + return self.value + else: + return self.base_value diff --git a/stimuli/__init__.py b/stimuli/__init__.py new file mode 100644 index 0000000..e69de29