commit all existing code

a.ott 2019-12-20 13:33:34 +01:00
parent e7ce44273e
commit f5dc213e42
19 changed files with 1997 additions and 0 deletions

AdaptionCurrent.py Normal file
@@ -0,0 +1,168 @@
from FiCurve import FICurve
from CellData import CellData
import matplotlib.pyplot as plt
from scipy.optimize import curve_fit
import os
import numpy as np
import functions as fu
class Adaption:
def __init__(self, cell_data: CellData, fi_curve: FICurve = None):
self.cell_data = cell_data
if fi_curve is None:
self.fi_curve = FICurve(cell_data)
else:
self.fi_curve = fi_curve
# [[a, tau_eff, c], [], [a, tau_eff, c], ...]
self.exponential_fit_vars = []
self.tau_real = []
self.fit_exponential()
self.calculate_tau_from_tau_eff()
def fit_exponential(self, length_of_fit=0.05):
mean_frequencies = self.cell_data.get_mean_isi_frequencies()
time_axes = self.cell_data.get_time_axes_mean_frequencies()
for i in range(len(mean_frequencies)):
start_idx = self.__find_start_idx_for_exponential_fit(i)
if start_idx == -1:
self.exponential_fit_vars.append([])
continue
# shorten length of fit to stay in stimulus region if given length is too long
sampling_interval = self.cell_data.get_sampling_interval()
used_length_of_fit = length_of_fit
if (start_idx * sampling_interval) - self.cell_data.get_delay() + length_of_fit > self.cell_data.get_stimulus_end():
print(start_idx * sampling_interval, "start - end", start_idx * sampling_interval + length_of_fit)
print("Shortened length of fit to keep it in the stimulus region!")
used_length_of_fit = self.cell_data.get_stimulus_end() - ((start_idx * sampling_interval) - self.cell_data.get_delay())
end_idx = start_idx + int(used_length_of_fit/sampling_interval)
y_values = mean_frequencies[i][start_idx:end_idx+1]
x_values = time_axes[i][start_idx:end_idx+1]
tau = self.__approximate_tau_for_exponential_fit(x_values, y_values, i)
# start the actual fit:
try:
p0 = (self.fi_curve.f_zeros[i], tau, self.fi_curve.f_infinities[i])
popt, pcov = curve_fit(fu.exponential_function, x_values, y_values,
p0=p0, maxfev=10000, bounds=([-np.inf, 0, -np.inf], [np.inf, np.inf, np.inf]))
except RuntimeError:
print("RuntimeError happened in fit_exponential.")
self.exponential_fit_vars.append([])
continue
# Obviously a bad fit: the time constant, expected in the range of 3-10 ms, is over one second or negative
if popt[1] > 1 or popt[1] < 0:
self.exponential_fit_vars.append([])
else:
self.exponential_fit_vars.append(popt)
def __approximate_tau_for_exponential_fit(self, x_values, y_values, mean_freq_idx):
if self.fi_curve.f_infinities[mean_freq_idx] < self.fi_curve.f_baselines[mean_freq_idx] * 0.95:
test_val = [y > 0.65 * self.fi_curve.f_infinities[mean_freq_idx] for y in y_values]
else:
test_val = [y < 0.65 * self.fi_curve.f_zeros[mean_freq_idx] for y in y_values]
try:
idx = test_val.index(True)
if idx == 0:
idx = 1
tau = x_values[idx] - x_values[0]
except ValueError:
tau = x_values[-1] - x_values[0]
return tau
def __find_start_idx_for_exponential_fit(self, mean_freq_idx):
stimulus_start_idx = int((self.cell_data.get_delay() + self.cell_data.get_stimulus_start()) / self.cell_data.get_sampling_interval())
if self.fi_curve.f_infinities[mean_freq_idx] > self.fi_curve.f_baselines[mean_freq_idx] * 1.1:
# start setting starting variables for the fit
# search for the start_index by searching for the max
j = 0
while True:
try:
if self.cell_data.get_mean_isi_frequencies()[mean_freq_idx][stimulus_start_idx + j] == self.fi_curve.f_zeros[mean_freq_idx]:
start_idx = stimulus_start_idx + j
break
except IndexError as e:
return -1
j += 1
elif self.fi_curve.f_infinities[mean_freq_idx] < self.fi_curve.f_baselines[mean_freq_idx] * 0.9:
# start setting starting variables for the fit
# search for start by finding the end of the minimum
found_min = False
j = int(0.05 / self.cell_data.get_sampling_interval())
while True:
if not found_min:
if self.cell_data.get_mean_isi_frequencies()[mean_freq_idx][stimulus_start_idx + j] == self.fi_curve.f_zeros[mean_freq_idx]:
found_min = True
else:
if self.cell_data.get_mean_isi_frequencies()[mean_freq_idx][stimulus_start_idx + j + 1] > self.fi_curve.f_zeros[mean_freq_idx]:
start_idx = stimulus_start_idx + j
break
if j > 0.1 / self.cell_data.get_sampling_interval():
# no rise in frequency until too close to the end of the stimulus (too little room to fit)
return -1
j += 1
else:
# there is nothing to fit to:
return -1
return start_idx
def calculate_tau_from_tau_eff(self):
taus = []
for i in range(len(self.exponential_fit_vars)):
if len(self.exponential_fit_vars[i]) == 0:
continue
tau_eff = self.exponential_fit_vars[i][1]*1000 # tau_eff in ms
# intensity = self.fi_curve.stimulus_value[i]
f_infinity_slope = self.fi_curve.get_f_infinity_slope()
fi_curve_slope = self.fi_curve.get_fi_curve_slope_of_straight()
taus.append(tau_eff*(fi_curve_slope/f_infinity_slope))
# print((fi_curve_slope/f_infinity_slope))
# print(tau_eff*(fi_curve_slope/f_infinity_slope), "=", tau_eff, "*", (fi_curve_slope/f_infinity_slope))
self.tau_real = np.median(taus)
def plot_exponential_fits(self, save_path: str = None, indices: list = None, delete_previous: bool = False):
if delete_previous:
for val in self.cell_data.get_fi_contrasts():
prev_path = save_path + "mean_freq_exp_fit_contrast:" + str(round(val, 3)) + ".png"
if os.path.exists(prev_path):
os.remove(prev_path)
for i in range(len(self.cell_data.get_fi_contrasts())):
if self.exponential_fit_vars[i] == []:
continue
plt.plot(self.cell_data.get_time_axes_mean_frequencies()[i], self.cell_data.get_mean_isi_frequencies()[i])
vars = self.exponential_fit_vars[i]
fit_x = np.arange(0, 0.4, self.cell_data.get_sampling_interval())
plt.plot(fit_x, [fu.exponential_function(x, vars[0], vars[1], vars[2]) for x in fit_x])
plt.ylim([0, max(self.fi_curve.f_zeros[i], self.fi_curve.f_baselines[i])*1.1])
plt.xlabel("Time [s]")
plt.ylabel("Frequency [Hz]")
if save_path is None:
plt.show()
else:
plt.savefig(save_path + "mean_freq_exp_fit_contrast:" + str(round(self.cell_data.get_fi_contrasts()[i], 3)) + ".png")
plt.close()
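A minimal usage sketch for the Adaption class (the data path and figure directory are illustrative; it assumes a recording directory that DataParserFactory can parse):

from CellData import CellData
from FiCurve import FICurve
from AdaptionCurrent import Adaption

cell_data = CellData("./data/2012-06-27-ah-invivo-1/")  # illustrative path
fi_curve = FICurve(cell_data)
adaption = Adaption(cell_data, fi_curve)
print("median tau [ms]:", adaption.tau_real)
adaption.plot_exponential_fits("./figures/example/", delete_previous=True)  # directory must exist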

CellData.py Normal file
@@ -0,0 +1,156 @@
import DataParserFactory as dpf
from warnings import warn
from os import listdir
import helperFunctions as hf
import numpy as np
def icelldata_of_dir(base_path):
for item in sorted(listdir(base_path)):
item_path = base_path + item
try:
yield CellData(item_path)
except TypeError as e:
warn_msg = str(e)
warn(warn_msg)
class CellData:
# Class to capture all the data of a single cell across all experiments (base rate, FI-curve, .?.)
# should abstract away from how the data is saved on disk (.dat vs .nix)
# traces list of lists with traces: [[time], [voltage (v1)], [EOD], [local eod], [stimulus]]
TIME = 0
V1 = 1
EOD = 2
LOCAL_EOD = 3
STIMULUS = 4
def __init__(self, data_path):
self.data_path = data_path
self.base_traces = None
# self.fi_traces = None
self.fi_intensities = None
self.fi_spiketimes = None
self.fi_trans_amplitudes = None
self.mean_isi_frequencies = None
self.time_axes = None
# self.metadata = None
self.parser = dpf.get_parser(data_path)
self.sampling_interval = self.parser.get_sampling_interval()
self.recording_times = self.parser.get_recording_times()
def get_data_path(self):
return self.data_path
def get_base_traces(self, trace_type=None):
if self.base_traces is None:
self.base_traces = self.parser.get_baseline_traces()
if trace_type is None:
return self.base_traces
else:
return self.base_traces[trace_type]
def get_fi_traces(self):
raise NotImplementedError("CellData:get_fi_traces():\n" +
"Getting the Fi-Traces currently overflows the RAM and causes swapping! Reimplement if really needed!")
# if self.fi_traces is None:
# self.fi_traces = self.parser.get_fi_curve_traces()
# return self.fi_traces
def get_fi_spiketimes(self):
self.__read_fi_spiketimes_info__()
return self.fi_spiketimes
def get_fi_intensities(self):
self.__read_fi_spiketimes_info__()
return self.fi_intensities
def get_fi_contrasts(self):
self.__read_fi_spiketimes_info__()
contrast = []
for i in range(len(self.fi_intensities)):
contrast.append((self.fi_intensities[i] - self.fi_trans_amplitudes[i]) / self.fi_trans_amplitudes[i])
return contrast
def get_mean_isi_frequencies(self):
if self.mean_isi_frequencies is None:
self.time_axes, self.mean_isi_frequencies = hf.all_calculate_mean_isi_frequencies(self.get_fi_spiketimes(),
self.get_time_start(),
self.get_sampling_interval())
return self.mean_isi_frequencies
def get_time_axes_mean_frequencies(self):
if self.time_axes is None:
self.time_axes, self.mean_isi_frequencies = hf.all_calculate_mean_isi_frequencies(self.get_fi_spiketimes(),
self.get_time_start(),
self.get_sampling_interval())
return self.time_axes
def get_base_frequency(self):
base_freqs = []
for freq in self.get_mean_isi_frequencies():
delay = self.get_delay()
sampling_interval = self.get_sampling_interval()
if delay < 0.1:
warn("FICurve:__calculate_f_baseline__(): Quite short delay at the start.")
idx_start = int(0.025 / sampling_interval)
idx_end = int((delay - 0.025) / sampling_interval)
base_freqs.append(np.mean(freq[idx_start:idx_end]))
return np.median(base_freqs)
def get_sampling_interval(self) -> float:
return self.sampling_interval
def get_recording_times(self) -> list:
return self.recording_times
def get_time_start(self) -> float:
return self.recording_times[0]
def get_delay(self) -> float:
return abs(self.recording_times[0])
def get_time_end(self) -> float:
return self.recording_times[2] + self.recording_times[3]
def get_stimulus_start(self) -> float:
return self.recording_times[1]
def get_stimulus_duration(self) -> float:
return self.recording_times[2]
def get_stimulus_end(self) -> float:
return self.get_stimulus_start() + self.get_stimulus_duration()
def get_after_stimulus_duration(self) -> float:
return self.recording_times[3]
def __read_fi_spiketimes_info__(self):
if self.fi_spiketimes is None:
trans_amplitudes, intensities, spiketimes = self.parser.get_fi_curve_spiketimes()
self.fi_intensities, self.fi_spiketimes, self.fi_trans_amplitudes = hf.merge_similar_intensities(intensities, spiketimes, trans_amplitudes)
# def get_metadata(self):
# self.__read_metadata__()
# return self.metadata
#
# def get_metadata_item(self, item):
# self.__read_metadata__()
# if item in self.metadata.keys():
# return self.metadata[item]
# else:
# raise KeyError("CellData:get_metadata_item: Item not found in metadata! - " + str(item))
#
# def __read_metadata__(self):
# if self.metadata is None:
# # TODO!!
# pass
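A short usage sketch for CellData (assuming a ./data/ directory of recordings, as used in main.py):

from CellData import icelldata_of_dir

for cell_data in icelldata_of_dir("./data/"):
    print(cell_data.get_data_path())
    print("sampling interval [s]:", cell_data.get_sampling_interval())
    print("stimulus start/end [s]:", cell_data.get_stimulus_start(), cell_data.get_stimulus_end())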

DataParserFactory.py Normal file
@@ -0,0 +1,237 @@
from os.path import isdir, exists
from warnings import warn
import pyrelacs.DataLoader as Dl
UNKNOWN = -1
DAT_FORMAT = 0
NIX_FORMAT = 1
class AbstractParser:
def cell_get_metadata(self):
raise NotImplementedError("NOT YET OVERRIDDEN FROM ABSTRACT CLASS")
def get_baseline_traces(self):
raise NotImplementedError("NOT YET OVERRIDDEN FROM ABSTRACT CLASS")
def get_fi_curve_traces(self):
raise NotImplementedError("NOT YET OVERRIDDEN FROM ABSTRACT CLASS")
def get_fi_curve_spiketimes(self):
raise NotImplementedError("NOT YET OVERRIDDEN FROM ABSTRACT CLASS")
def get_sampling_interval(self):
raise NotImplementedError("NOT YET OVERRIDDEN FROM ABSTRACT CLASS")
def get_recording_times(self):
raise NotImplementedError("NOT YET OVERRIDDEN FROM ABSTRACT CLASS")
class DatParser(AbstractParser):
def __init__(self, dir_path):
self.base_path = dir_path
self.fi_file = self.base_path + "/fispikes1.dat"
self.stimuli_file = self.base_path + "/stimuli.dat"
self.__test_data_file_existence__()
self.fi_recording_times = []
self.sampling_interval = -1
def cell_get_metadata(self):
pass
def get_sampling_interval(self):
if self.sampling_interval == -1:
self.__read_sampling_interval__()
return self.sampling_interval
def get_recording_times(self):
if len(self.fi_recording_times) == 0:
self.__read_fi_recording_times__()
return self.fi_recording_times
def get_baseline_traces(self):
return self.__get_traces__("BaselineActivity")
def get_fi_curve_traces(self):
return self.__get_traces__("FICurve")
# TODO clean up/ rewrite
def get_fi_curve_spiketimes(self):
spiketimes = []
pre_intensities = []
pre_durations = []
intensities = []
trans_amplitudes = []
pre_duration = -1
index = -1
skip = False
trans_amplitude = float('nan')
for metadata, key, data in Dl.iload(self.fi_file):
if len(metadata) != 0:
metadata_index = 0
if '----- Control --------------------------------------------------------' in metadata[0].keys():
metadata_index = 1
pre_duration = float(metadata[0]["----- Pre-Intensities ------------------------------------------------"]["preduration"][:-2])
trans_amplitude = float(metadata[0]["trans. amplitude"][:-2])
if pre_duration == 0:
skip = False
else:
skip = True
continue
if skip:
continue
intensity = float(metadata[metadata_index]['intensity'][:-2])
pre_intensity = float(metadata[metadata_index]['preintensity'][:-2])
intensities.append(intensity)
pre_durations.append(pre_duration)
pre_intensities.append(pre_intensity)
trans_amplitudes.append(trans_amplitude)
spiketimes.append([])
index += 1
if skip:
continue
if data.shape[1] != 1:
raise RuntimeError("DatParser:get_fi_curve_spiketimes():\n read data has more than one dimension!")
spike_time_data = data[:, 0]/1000
if len(spike_time_data) < 10:
continue
if spike_time_data[-1] < 1:
print("# ignoring spike-train that ends before one second.")
continue
spiketimes[index].append(spike_time_data)
# TODO add merging for similar intensities? hf.merge_similar_intensities() + trans_amplitudes
return trans_amplitudes, intensities, spiketimes
def __get_traces__(self, repro):
time_traces = []
v1_traces = []
eod_traces = []
local_eod_traces = []
stimulus_traces = []
nothing = True
for info, key, time, x in Dl.iload_traces(self.base_path, repro=repro):
nothing = False
time_traces.append(time)
v1_traces.append(x[0])
eod_traces.append(x[1])
local_eod_traces.append(x[2])
stimulus_traces.append(x[3])
traces = [time_traces, v1_traces, eod_traces, local_eod_traces, stimulus_traces]
if nothing:
warn_msg = "pyrelacs: iload_traces found nothing for the " + str(repro) + " repro!"
warn(warn_msg)
return traces
def __read_fi_recording_times__(self):
delays = []
stim_duration = []
pause = []
for metadata, key, data in Dl.iload(self.fi_file):
if len(metadata) != 0:
control_key = '----- Control --------------------------------------------------------'
if control_key in metadata[0].keys():
delays.append(float(metadata[0][control_key]["delay"][:-2])/1000)
pause.append(float(metadata[0][control_key]["pause"][:-2])/1000)
stim_key = "----- Test-Intensities -----------------------------------------------"
stim_duration.append(float(metadata[0][stim_key]["duration"][:-2])/1000)
for l in [delays, stim_duration, pause]:
if len(l) == 0:
raise RuntimeError("DatParser:__read_fi_recording_times__:\n" +
"Couldn't find any delay, stimulus duration and or pause in the metadata.\n" +
"In file:" + self.base_path)
elif len(set(l)) != 1:
raise RuntimeError("DatParser:__read_fi_recording_times__:\n" +
"Found multiple different delay, stimulus duration and or pause in the metadata.\n" +
"In file:" + self.base_path)
else:
self.fi_recording_times = [-delays[0], 0, stim_duration[0], pause[0] - delays[0]]
def __read_sampling_interval__(self):
stop = False
sampling_intervals = []
for metadata, key, data in Dl.iload(self.stimuli_file):
for md in metadata:
for i in range(4):
key = "sample interval" + str(i+1)
if key in md.keys():
sampling_intervals.append(float(md[key][:-2]) / 1000)
stop = True
else:
break
if stop:
break
if len(sampling_intervals) == 0:
raise RuntimeError("DatParser:__read_sampling_interval__:\n" +
"Sampling intervals not found in stimuli.dat this is not handled!\n" +
"with File:" + self.base_path)
if len(set(sampling_intervals)) != 1:
raise RuntimeError("DatParser:__read_sampling_interval__:\n" +
"Sampling intervals not the same for all traces this is not handled!\n" +
"with File:" + self.base_path)
else:
self.sampling_interval = sampling_intervals[0]
def __test_data_file_existence__(self):
if not exists(self.stimuli_file):
raise RuntimeError(self.stimuli_file + " file doesn't exist!")
if not exists(self.fi_file):
raise RuntimeError(self.fi_file + " file doesn't exist!")
# TODO ####################################
class NixParser(AbstractParser):
def __init__(self, nix_file_path):
self.file_path = nix_file_path
warn("NIX PARSER: NOT YET IMPLEMENTED!")
# TODO ####################################
def get_parser(data_path: str) -> AbstractParser:
data_format = __test_for_format__(data_path)
if data_format == DAT_FORMAT:
return DatParser(data_path)
elif data_format == NIX_FORMAT:
return NixParser(data_path)
else:
raise TypeError("DataParserFactory:get_parser(data_path):\nCannot determine type of data for:" + data_path)
def __test_for_format__(data_path):
if isdir(data_path):
if exists(data_path + "/fispikes1.dat"):
return DAT_FORMAT
elif data_path.endswith(".nix"):
return NIX_FORMAT
else:
return UNKNOWN
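Usage sketch for the factory (illustrative path; only .dat directories are fully supported so far):

from DataParserFactory import get_parser

parser = get_parser("./data/2012-06-27-ah-invivo-1")  # illustrative path
print("sampling interval [s]:", parser.get_sampling_interval())
# recording times are [-delay, stimulus_start, stimulus_duration, after_stimulus_duration]
print("recording times:", parser.get_recording_times())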

FiCurve.py Normal file
@@ -0,0 +1,153 @@
from CellData import CellData
import numpy as np
from scipy.optimize import curve_fit
import matplotlib.pyplot as plt
from warnings import warn
import functions as fu
class FICurve:
def __init__(self, cell_data: CellData, contrast: bool = True):
self.cell_data = cell_data
self.using_contrast = contrast
if contrast:
self.stimulus_value = cell_data.get_fi_contrasts()
else:
self.stimulus_value = cell_data.get_fi_intensities()
self.f_zeros = []
self.f_infinities = []
self.f_baselines = []
# f_max, f_min, k, x_zero
self.boltzmann_fit_vars = []
# offset increase
self.f_infinity_fit = []
self.all_calculate_frequency_points()
self.fit_line()
self.fit_boltzmann()
def all_calculate_frequency_points(self):
mean_frequencies = self.cell_data.get_mean_isi_frequencies()
if len(mean_frequencies) == 0:
warn("FICurve:all_calculate_frequency_points(): mean_frequencies is empty.\n"
"Was all_calculate_mean_isi_frequencies already called?")
for freq in mean_frequencies:
self.f_zeros.append(self.__calculate_f_zero__(freq))
self.f_baselines.append(self.__calculate_f_baseline__(freq))
self.f_infinities.append(self.__calculate_f_infinity__(freq))
def fit_line(self):
popt, pcov = curve_fit(fu.clipped_line, self.stimulus_value, self.f_infinities)
self.f_infinity_fit = popt
def fit_boltzmann(self):
max_f0 = float(max(self.f_zeros))
min_f0 = float(min(self.f_zeros))
mean_int = float(np.mean(self.stimulus_value))
total_increase = max_f0 - min_f0
total_change_int = max(self.stimulus_value) - min(self.stimulus_value)
start_k = float((total_increase / total_change_int * 4) / max_f0)
popt, pcov = curve_fit(fu.full_boltzmann, self.stimulus_value, self.f_zeros,
p0=(max_f0, min_f0, start_k, mean_int),
maxfev=10000, bounds=([0, 0, -np.inf, -np.inf], [3000, 3000, np.inf, np.inf]))
self.boltzmann_fit_vars = popt
def plot_fi_curve(self, savepath: str = None):
min_x = min(self.stimulus_value)
max_x = max(self.stimulus_value)
step = (max_x - min_x) / 5000
x_values = np.arange(min_x, max_x, step)
plt.plot(self.stimulus_value, self.f_baselines, color='blue', label='f_base')
plt.plot(self.stimulus_value, self.f_infinities, 'o', color='lime', label='f_inf')
plt.plot(x_values, [fu.clipped_line(x, self.f_infinity_fit[0], self.f_infinity_fit[1]) for x in x_values],
color='darkgreen', label='f_inf_fit')
plt.plot(self.stimulus_value, self.f_zeros, 'o', color='orange', label='f_zero')
popt = self.boltzmann_fit_vars
plt.plot(x_values, [fu.full_boltzmann(x, popt[0], popt[1], popt[2], popt[3]) for x in x_values],
color='red', label='f_0_fit')
plt.legend()
plt.ylabel("Frequency [Hz]")
if self.using_contrast:
plt.xlabel("Stimulus contrast")
else:
plt.xlabel("Stimulus intensity [mv]")
if savepath is None:
plt.show()
else:
plt.savefig(savepath + "fi_curve.png")
plt.close()
def __calculate_f_baseline__(self, frequency, buffer=0.025):
delay = self.cell_data.get_delay()
sampling_interval = self.cell_data.get_sampling_interval()
if delay < 0.1:
warn("FICurve:__calculate_f_baseline__(): Quite short delay at the start.")
idx_start = int(buffer/sampling_interval)
idx_end = int((delay-buffer)/sampling_interval)
return np.mean(frequency[idx_start:idx_end])
def __calculate_f_zero__(self, frequency, length_of_mean=0.1, buffer=0.025):
stimulus_start = self.cell_data.get_delay() + self.cell_data.get_stimulus_start()
sampling_interval = self.cell_data.get_sampling_interval()
start_idx = int((stimulus_start - buffer) / sampling_interval)
end_idx = int((stimulus_start + buffer*2) / sampling_interval)
freq_before = frequency[start_idx-(int(length_of_mean/sampling_interval)):start_idx]
fb_mean = np.mean(freq_before)
fb_std = np.std(freq_before)
peak_frequency = fb_mean
count = 0
for i in range(start_idx + 1, end_idx):
if fb_mean-3*fb_std <= frequency[i] <= fb_mean+3*fb_std:
continue
if abs(frequency[i] - fb_mean) > abs(peak_frequency - fb_mean):
peak_frequency = frequency[i]
count += 1
return peak_frequency
def __calculate_f_infinity__(self, frequency, length=0.2, buffer=0.025):
stimulus_end_time = \
self.cell_data.get_delay() + self.cell_data.get_stimulus_start() + self.cell_data.get_stimulus_duration()
start_idx = int((stimulus_end_time - length - buffer) / self.cell_data.get_sampling_interval())
end_idx = int((stimulus_end_time - buffer) / self.cell_data.get_sampling_interval())
return np.mean(frequency[start_idx:end_idx])
def get_f_zero_inverse_at_frequency(self, frequency):
b_vars = self.boltzmann_fit_vars
return fu.inverse_full_boltzmann(frequency, b_vars[0], b_vars[1], b_vars[2], b_vars[3])
def get_f_infinity_frequency_at_stimulus_value(self, stimulus_value):
infty_vars = self.f_infinity_fit
return fu.clipped_line(stimulus_value, infty_vars[0], infty_vars[1])
def get_f_infinity_slope(self):
return self.f_infinity_fit[1]
def get_fi_curve_slope_at(self, stimulus_value):
fit_vars = self.boltzmann_fit_vars
return fu.derivative_full_boltzmann(stimulus_value, fit_vars[0], fit_vars[1], fit_vars[2], fit_vars[3])
def get_fi_curve_slope_of_straight(self):
fit_vars = self.boltzmann_fit_vars
return fu.full_boltzmann_straight_slope(fit_vars[0], fit_vars[1], fit_vars[2], fit_vars[3])
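Usage sketch for FICurve (illustrative path):

from CellData import CellData
from FiCurve import FICurve

fi_curve = FICurve(CellData("./data/2012-06-27-ah-invivo-1/"))
print("f_inf fit [offset, slope]:", fi_curve.f_infinity_fit)
print("boltzmann fit [f_max, f_min, k, x_zero]:", fi_curve.boltzmann_fit_vars)
fi_curve.plot_fi_curve()  # no save path given -> shows the figure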

functionalityTests.py Normal file
@@ -0,0 +1,44 @@
from models.LIFAC import LIFACModel
from stimuli.StepStimulus import StepStimulus
import numpy as np
import matplotlib.pyplot as plt
import functions as fu
def test_lifac():
model = LIFACModel()
stimulus = StepStimulus(0.5, 1, 15)
for step_size in [0.001, 0.1]:
model.set_variable("step_size", step_size)
v, spiketimes = model(stimulus, 2)
plt.plot(np.arange(0, 2, step_size/1000), v, label="step_size:" + str(step_size))
plt.xlabel("time in seconds")
plt.ylabel("Voltage")
plt.title("Voltage in the LIFAC-Model with different step sizes")
plt.show()
plt.close()
def test_plot_inverses(ficurve):
var = ficurve.boltzmann_fit_vars
fig, ax1 = plt.subplots(1, 1, figsize=(4.5, 4.5), tight_layout=True)
start = min(ficurve.stimulus_value)
end = max(ficurve.stimulus_value)
x_values = np.arange(start, end, (end-start)/5000)
ax1.plot(x_values, [fu.full_boltzmann(x, var[0], var[1], var[2], var[3]) for x in x_values], label="fit")
ax1.set_ylabel('freq')
ax1.set_xlabel('stimulus')
start = var[1]
end = var[0]
x_values = np.arange(start, end, (end - start) / 50)
ax1.plot([fu.inverse_full_boltzmann(x, var[0], var[1], var[2], var[3]) for x in x_values], x_values,
'.', c="red", label='inverse')
plt.legend()
plt.show()

functions.py Normal file
@@ -0,0 +1,40 @@
import numpy as np
def exponential_function(x, a, b, c):
return (a-c)*np.exp(-x/b)+c
def upper_boltzmann(x, f_max, k, x_zero):
return f_max * np.clip((2 / (1+np.power(np.e, -k*(x - x_zero)))) - 1, 0, None)
def full_boltzmann(x, f_max, f_min, k, x_zero):
return (f_max-f_min) * (1 / (1 + np.power(np.e, -k * (x - x_zero)))) + f_min
def full_boltzmann_straight_slope(f_max, f_min, k, x_zero=0):
return (f_max-f_min)*k*1/2
def derivative_full_boltzmann(x, f_max, f_min, k, x_zero):
return (f_max - f_min) * k * np.power(np.e, -k * (x - x_zero)) / (1 + np.power(np.e, -k * (x - x_zero)))**2
def inverse_full_boltzmann(x, f_max, f_min, k, x_zero):
if x < f_min or x > f_max:
raise ValueError("Value undefined in inverse_full_boltzmann")
return -(np.log((f_max-f_min) / (x - f_min) - 1) / k) + x_zero
def clipped_line(x, a, b):
return np.clip(a+b*x, 0, None)
def inverse_clipped_line(x, a, b):
if clipped_line(x, a, b) == 0:
raise ValueError("Value undefined in inverse_clipped_line.")
return (x-a)/b
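A quick numeric sanity check of these functions (not part of the module; the parameters are arbitrary examples): the inverse should undo full_boltzmann, and the analytic derivative should match a central difference.

import numpy as np
import functions as fu

f_max, f_min, k, x_zero = 200.0, 20.0, 8.0, 0.1
x = 0.25
y = fu.full_boltzmann(x, f_max, f_min, k, x_zero)
# the inverse should recover x
assert np.isclose(fu.inverse_full_boltzmann(y, f_max, f_min, k, x_zero), x)
# the analytic derivative should match a numeric central difference
h = 1e-6
numeric = (fu.full_boltzmann(x + h, f_max, f_min, k, x_zero)
           - fu.full_boltzmann(x - h, f_max, f_min, k, x_zero)) / (2 * h)
assert np.isclose(fu.derivative_full_boltzmann(x, f_max, f_min, k, x_zero), numeric)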

generalTests.py Normal file
@@ -0,0 +1,54 @@
import numpy as np
import matplotlib.pyplot as plt
from models.LeakyIntegrateFireModel import LIFModel
# def calculate_step(current_v, tau, i_b, step_size=0.01):
# return current_v + (step_size * (-current_v + mem_res * i_b)) / tau
def function_e(x):
return (0-15) * np.e**(-x/1) + 15
# x_values = np.arange(0, 5, 0.01)
# plt.plot(x_values, [function_e(x) for x in x_values])
# plt.show()
# def function_f(i_base, tau=1, threshold=10, reset=0):
# return -1/(tau*np.log((threshold-i_base)/(reset - i_base)))
#
# x_values = np.arange(0, 20, 0.001)
# plt.plot(x_values, [function_f(x) for x in x_values])
# plt.show()
# LIF test:
# Rm = 100 MOhm, Cm = 200pF
step_size = 0.01 # ms
mem_res = 100*1000000
tau = 1
base_freq = 30
v_threshold = 10
base_input = -v_threshold / ((np.e**(-1/(base_freq*tau)) - 1) * mem_res)
stim1 = int(1000/step_size) * [base_input]
stimulus = []
stimulus.extend(stim1)
lif = LIFModel(mem_res, tau, 0, 0, stimulus, 10)
voltage, spikes_b = lif.calculate_response()
y_spikes = []
x_spikes = []
for i in range(len(spikes_b)):
if spikes_b[i]:
y_spikes.append(10.5)
x_spikes.append(i*step_size)
time = np.arange(0, 1000, step_size)
plt.plot(time, voltage)
plt.plot(x_spikes, y_spikes, 'o')
plt.show()
plt.close()

helperFunctions.py Normal file
@@ -0,0 +1,214 @@
import os
import pyrelacs.DataLoader as dl
import numpy as np
import matplotlib.pyplot as plt
from warnings import warn
def get_subfolder_paths(basepath):
subfolders = []
for content in os.listdir(basepath):
content_path = basepath + content
if os.path.isdir(content_path):
subfolders.append(content_path)
return sorted(subfolders)
def get_traces(directory, trace_type, repro):
# trace_type = 1: Voltage p-unit
# trace_type = 2: EOD
# trace_type = 3: local EOD ~(EOD + stimulus)
# trace_type = 4: Stimulus
load_iter = dl.iload_traces(directory, repro=repro)
time_traces = []
value_traces = []
nothing = True
for info, key, time, x in load_iter:
nothing = False
time_traces.append(time)
value_traces.append(x[trace_type-1])
if nothing:
print("iload_traces found nothing for the BaselineActivity repro!")
return time_traces, value_traces
def get_all_traces(directory, repro):
load_iter = dl.iload_traces(directory, repro=repro)
time_traces = []
v1_traces = []
eod_traces = []
local_eod_traces = []
stimulus_traces = []
nothing = True
for info, key, time, x in load_iter:
nothing = False
time_traces.append(time)
v1_traces.append(x[0])
eod_traces.append(x[1])
local_eod_traces.append(x[2])
stimulus_traces.append(x[3])
print(info)
traces = [v1_traces, eod_traces, local_eod_traces, stimulus_traces]
if nothing:
print("iload_traces found nothing for the BaselineActivity repro!")
return time_traces, traces
def merge_similar_intensities(intensities, spiketimes, trans_amplitudes):
i = 0
diffs = np.diff(sorted(intensities))
margin = np.mean(diffs) * 0.6666
while True:
if i >= len(intensities):
break
intensities, spiketimes, trans_amplitudes = merge_intensities_similar_to_index(intensities, spiketimes, trans_amplitudes, i, margin)
i += 1
# Sort the lists so that intensities are increasing
x = [list(x) for x in zip(*sorted(zip(intensities, spiketimes), key=lambda pair: pair[0]))]
intensities = x[0]
spiketimes = x[1]
return intensities, spiketimes, trans_amplitudes
def merge_intensities_similar_to_index(intensities, spiketimes, trans_amplitudes, index, margin):
intensity = intensities[index]
indices_to_merge = []
for i in range(index+1, len(intensities)):
if np.abs(intensities[i]-intensity) < margin:
indices_to_merge.append(i)
if len(indices_to_merge) != 0:
indices_to_merge.reverse()
trans_amplitude_values = [trans_amplitudes[k] for k in indices_to_merge]
all_the_same = True
for j in range(1, len(trans_amplitude_values)):
if not trans_amplitude_values[0] == trans_amplitude_values[j]:
all_the_same = False
break
if all_the_same:
for idx in indices_to_merge:
del trans_amplitudes[idx]
else:
raise RuntimeError("Trans_amplitudes not the same....")
for idx in indices_to_merge:
spiketimes[index].extend(spiketimes[idx])
del spiketimes[idx]
del intensities[idx]
return intensities, spiketimes, trans_amplitudes
def all_calculate_mean_isi_frequencies(spiketimes, time_start, sampling_interval):
times = []
mean_frequencies = []
for i in range(len(spiketimes)):
trial_times = []
trial_means = []
for j in range(len(spiketimes[i])):
time, isi_freq = calculate_isi_frequency(spiketimes[i][j], time_start, sampling_interval)
trial_means.append(isi_freq)
trial_times.append(time)
time, mean_freq = calculate_mean_frequency(trial_times, trial_means)
times.append(time)
mean_frequencies.append(mean_freq)
return times, mean_frequencies
def calculate_isi_frequency(spiketimes, time_start, sampling_interval):
first_isi = spiketimes[0] - time_start
isis = [first_isi]
isis.extend(np.diff(spiketimes))
time = np.arange(time_start, spiketimes[-1], sampling_interval)
full_frequency = []
i = 0
for isi in isis:
if isi == 0:
warn("An ISI was zero in FiCurve:__calculate_mean_isi_frequency__()")
continue
freq = 1 / isi
frequency_step = int(round(isi * (1 / sampling_interval))) * [freq]
full_frequency.extend(frequency_step)
i += 1
if len(full_frequency) != len(time):
if abs(len(full_frequency) - len(time)) == 1:
warn("FiCurve:__calculate_mean_isi_frequency__():\nFrequency and time were one of in length!")
if len(full_frequency) < len(time):
time = time[:len(full_frequency)]
else:
full_frequency = full_frequency[:len(time)]
else:
print("ERROR PRINT:")
print("freq:", len(full_frequency), "time:", len(time), "diff:", len(full_frequency) - len(time))
raise RuntimeError("FiCurve:__calculate_mean_isi_frequency__():\n"
"Frequency and time are not the same length!")
return time, full_frequency
def calculate_mean_frequency(trial_times, trial_freqs):
lengths = [len(t) for t in trial_times]
shortest = min(lengths)
time = trial_times[0][0:shortest]
shortened_freqs = [freq[0:shortest] for freq in trial_freqs]
mean_freq = [sum(e) / len(e) for e in zip(*shortened_freqs)]
return time, mean_freq
def crappy_smoothing(signal:list, window_size:int = 5) -> list:
smoothed = []
for i in range(len(signal)):
k = window_size
if i < window_size:
k = i
j = window_size
if i + j > len(signal):
j = len(signal) - i
smoothed.append(np.mean(signal[i-k:i+j]))
return smoothed
def plot_frequency_curve(cell_data, save_path: str = None, indices: list = None):
contrast = cell_data.get_fi_contrasts()
time_axes = cell_data.get_time_axes_mean_frequencies()
mean_freqs = cell_data.get_mean_isi_frequencies()
if indices is None:
indices = np.arange(len(contrast))
for i in indices:
plt.plot(time_axes[i], mean_freqs[i], label=str(round(contrast[i], 2)))
if save_path is None:
plt.show()
else:
plt.savefig(save_path + "mean_frequency_curves.png")
plt.close()
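A toy example for the ISI-frequency helper (not part of the module): a regular 10 Hz spike train yields a constant 10 Hz rate.

import helperFunctions as hf

spiketimes = [0.1, 0.2, 0.3]  # seconds
time, freq = hf.calculate_isi_frequency(spiketimes, 0, 0.001)
print(len(time), len(freq))  # equal lengths, ~300 samples up to the last spike
print(set(freq))             # {10.0}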

@@ -0,0 +1,283 @@
import pyrelacs.DataLoader as dl
import numpy as np
import matplotlib.pyplot as plt
from IPython import embed
import os
import helperFunctions as hf
from thunderfish.eventdetection import detect_peaks
SAVEPATH = ""
def get_savepath():
global SAVEPATH
return SAVEPATH
def set_savepath(new_path):
global SAVEPATH
SAVEPATH = new_path
def main():
for folder in hf.get_subfolder_paths("data/"):
filepath = folder + "/basespikes1.dat"
set_savepath("figures/" + folder.split('/')[1] + "/")
print("Folder:", folder)
if not os.path.exists(get_savepath()):
os.makedirs(get_savepath())
spiketimes = []
ran = False
for metadata, key, data in dl.iload(filepath):
ran = True
spikes = data[:, 0]
spiketimes.append(spikes) # save for calculation of vector strength
metadata = metadata[0]
#print(metadata)
# print('firing frequency1:', metadata['firing frequency1'])
# print(mean_firing_rate(spikes))
# print('Coefficient of Variation (CV):', metadata['CV1'])
# print(calculate_coefficient_of_variation(spikes))
if not ran:
print("------------ DIDN'T RUN")
isi_histogram(spiketimes)
times, eods = hf.get_traces(folder, 2, 'BaselineActivity')
times, v1s = hf.get_traces(folder, 1, 'BaselineActivity')
vs = calculate_vector_strength(times, eods, spiketimes, v1s)
# print("Calculated vector strength:", vs)
def mean_firing_rate(spiketimes):
# mean firing rate (number of spikes per time)
return len(spiketimes)/spiketimes[-1]*1000
def calculate_coefficient_of_variation(spiketimes):
# CV: stddev of the ISIs divided by the mean ISI (ISIs = np.diff(spiketimes))
isi = np.diff(spiketimes)
std = np.std(isi)
mean = np.mean(isi)
return std/mean
def isi_histogram(spiketimes):
# ISI histogram (play around with binsize! < 1ms)
isi = []
for spike_list in spiketimes:
isi.extend(np.diff(spike_list))
maximum = max(isi)
bins = np.arange(0, maximum*1.01, 0.1)
plt.title('Phase locking of ISI without stimulus')
plt.xlabel('ISI in ms')
plt.ylabel('Count')
plt.hist(isi, bins=bins)
plt.savefig(get_savepath() + 'phase_locking_without_stimulus.png')
plt.close()
def calculate_vector_strength(times, eods, spiketimes, v1s):
# Vector strength (use EOD frequency from header (metadata)) VS > 0.8
# dl.iload_traces(repro='BaselineActivity')
relative_spike_times = []
eod_durations = []
if len(times) == 0:
print("-----LENGTH OF TIMES = 0")
for recording in range(len(times)):
rel_spikes, eod_durs = eods_around_spikes(times[recording], eods[recording], spiketimes[recording])
relative_spike_times.extend(rel_spikes)
eod_durations.extend(eod_durs)
vs = __vector_strength__(rel_spikes, eod_durs)
phases = calculate_phases(rel_spikes, eod_durs)
plot_polar(phases, "test_phase_locking_" + str(recording) + "_with_vs:" + str(round(vs, 3)) + ".png")
print("VS of recording", recording, ":", vs)
plot_phaselocking_testfigures(times[recording], eods[recording], spiketimes[recording], v1s[recording])
return __vector_strength__(relative_spike_times, eod_durations)
def eods_around_spikes(time, eod, spiketimes):
eod_durations = []
relative_spike_times = []
for spike in spiketimes:
index = spike * 20  # spike time is in ms, sampled at 20 kHz -> index = t/1000 * 20000 = t * 20
if index != np.round(index):
print("INDEX NOT AN INTEGER in eods_around_spikes! index:", index)
continue
index = int(index)
start_time, end_time = search_eod_start_and_end_times(time, eod, index)
eod_durations.append(end_time-start_time)
relative_spike_times.append(spike/1000 - start_time)
return relative_spike_times, eod_durations
def search_eod_start_and_end_times(time, eod, index):
# TODO might break if a spike is in the cut off first or last eod!
# search start_time:
previous = index
working_idx = index-1
while True:
if eod[working_idx] < 0 < eod[previous]:
first_value = eod[working_idx]
second_value = eod[previous]
dif = second_value - first_value
part = np.abs(first_value/dif)
time_dif = np.abs(time[previous] - time[working_idx])
start_time = time[working_idx] + time_dif*part
break
previous = working_idx
working_idx -= 1
# search end_time
previous = index
working_idx = index + 1
while True:
if eod[previous] < 0 < eod[working_idx]:
first_value = eod[previous]
second_value = eod[working_idx]
dif = second_value - first_value
part = np.abs(first_value / dif)
time_dif = np.abs(time[previous] - time[working_idx])
end_time = time[working_idx] + time_dif * part
break
previous = working_idx
working_idx += 1
return start_time, end_time
def search_closest_index(array, value, start=0, end=-1):
# searches the array to find the closest value in the array to the given value and returns its index.
# expects sorted array!
# start hast to be smaller than end
if end == -1:
end = len(array)-1
while True:
if end-start <= 1:
return end if np.abs(array[end]-value) < np.abs(array[start]-value) else start
middle = int(np.floor((end-start)/2)+start)
if array[middle] == value:
return middle
elif array[middle] > value:
end = middle
continue
else:
start = middle
continue
def __vector_strength__(relative_spike_times, eod_durations):
# adapted from Ramona
n = len(relative_spike_times)
if n == 0:
return 0
phase_times = np.zeros(n)
for i in range(n):
phase_times[i] = (relative_spike_times[i] / eod_durations[i]) * 2 * np.pi
vs = np.sqrt((1 / n * sum(np.cos(phase_times))) ** 2 + (1 / n * sum(np.sin(phase_times))) ** 2)
return vs
def calculate_phases(relative_spike_times, eod_durations):
phase_times = np.zeros(len(relative_spike_times))
for i in range(len(relative_spike_times)):
phase_times[i] = (relative_spike_times[i] / eod_durations[i]) * 2 * np.pi
return phase_times
def plot_polar(phases, name=""):
fig = plt.figure()
ax = fig.add_subplot(111, polar=True)
# r = np.arange(0, 1, 0.001)
# theta = 2 * 2 * np.pi * r
# line, = ax.plot(theta, r, color='#ee8d18', lw=3)
bins = np.arange(0, np.pi*2, 0.05)
ax.hist(phases, bins=bins)
if name == "":
plt.show()
else:
plt.savefig(get_savepath() + name)
plt.close()
def plot_phaselocking_testfigures(time, eod, spiketimes, v1):
eod_start_times = []
eod_end_times = []
for spike in spiketimes:
index = spike * 20  # spike time is in ms, sampled at 20 kHz -> index = t/1000 * 20000 = t * 20
if index != np.round(index):
print("INDEX NOT AN INTEGER in eods_around_spikes! index:", index)
continue
index = int(index)
start_time, end_time = search_eod_start_and_end_times(time, eod, index)
eod_start_times.append(start_time)
eod_end_times.append(end_time)
cutoff_in_sec = 2
sampling = 20000
max_idx = cutoff_in_sec*sampling
spikes_part = [x/1000 for x in spiketimes if x/1000 < cutoff_in_sec]
count_spikes = len(spikes_part)
print(spiketimes)
print(len(spikes_part))
x_axis = time[0:max_idx]
plt.plot(spikes_part, np.ones(len(spikes_part))*-20, 'o')
plt.plot(x_axis, v1[0:max_idx])
plt.plot(eod_start_times[: count_spikes], np.zeros(count_spikes), 'o')
plt.plot(eod_end_times[: count_spikes], np.zeros(count_spikes), 'o')
plt.show()
plt.close()
if __name__ == '__main__':
main()
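A quick sanity check of __vector_strength__ (to be run inside this module): perfectly phase-locked spikes give VS = 1, uniformly distributed phases give VS close to 0.

import numpy as np

n = 1000
eod_durations = [0.001] * n                               # 1 ms EOD cycles
locked = [0.00025] * n                                    # always the same phase -> VS = 1
uniform = list(np.linspace(0, 0.001, n, endpoint=False))  # evenly spread phases -> VS ~ 0
print(__vector_strength__(locked, eod_durations))
print(__vector_strength__(uniform, eod_durations))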

@@ -0,0 +1,298 @@
import numpy as np
import matplotlib.pyplot as plt
import pyrelacs.DataLoader as dl
import os
import helperFunctions as hf
from IPython import embed
from scipy.optimize import curve_fit
import warnings
SAMPLING_INTERVAL = 1/20000
STIMULUS_START = 0
STIMULUS_DURATION = 0.400
PRE_DURATION = 0.250
TOTAL_DURATION = 1.25
def main():
for folder in hf.get_subfolder_paths("data/"):
filepath = folder + "/fispikes1.dat"
set_savepath("figures/" + folder.split('/')[1] + "/")
print("Folder:", folder)
if not os.path.exists(get_savepath()):
os.makedirs(get_savepath())
spiketimes = []
intensities = []
index = -1
for metadata, key, data in dl.iload(filepath):
# embed()
if len(metadata) != 0:
metadata_index = 0
if '----- Control --------------------------------------------------------' in metadata[0].keys():
metadata_index = 1
print(metadata)
i = float(metadata[metadata_index]['intensity'][:-2])
intensities.append(i)
spiketimes.append([])
index += 1
spiketimes[index].append(data[:, 0]/1000)
intensities, spiketimes = hf.merge_similar_intensities(intensities, spiketimes)
# Sort the lists so that intensities are increasing
x = [list(x) for x in zip(*sorted(zip(intensities, spiketimes), key=lambda pair: pair[0]))]
intensities = x[0]
spiketimes = x[1]
mean_frequencies = calculate_mean_frequencies(intensities, spiketimes)
popt, pcov = fit_exponential(intensities, mean_frequencies)
plot_frequency_curve(intensities, mean_frequencies)
f_baseline = calculate_f_baseline(mean_frequencies)
f_infinity = calculate_f_infinity(mean_frequencies)
f_zero = calculate_f_zero(mean_frequencies)
# plot_fi_curve(intensities, f_baseline, f_zero, f_infinity)
# TODO !!
def fit_exponential(intensities, mean_frequencies):
start_idx = int((PRE_DURATION + STIMULUS_START+0.005) / SAMPLING_INTERVAL)
end_idx = int((PRE_DURATION + STIMULUS_START + 0.1) / SAMPLING_INTERVAL)
time_constants = []
#print(start_idx, end_idx)
popts = []
pcovs = []
for i in range(len(mean_frequencies)):
freq = mean_frequencies[i]
y_values = freq[start_idx:end_idx+1]
x_values = np.arange(start_idx*SAMPLING_INTERVAL, end_idx*SAMPLING_INTERVAL, SAMPLING_INTERVAL)
try:
popt, pcov = curve_fit(exponential_function, x_values, y_values, p0=(1/(np.power(1, 10)), .5, 50, 180), maxfev=10000)
except RuntimeError:
print("RuntimeError happened in fit_exponential.")
continue
#print(popt)
#print(pcov)
#print()
popts.append(popt)
pcovs.append(pcov)
plt.plot(np.arange(-PRE_DURATION, TOTAL_DURATION, SAMPLING_INTERVAL), freq)
plt.plot(x_values-PRE_DURATION, [exponential_function(x, popt[0], popt[1], popt[2], popt[3]) for x in x_values])
# plt.show()
save_path = get_savepath() + "exponential_fits/"
if not os.path.exists(save_path):
os.makedirs(save_path)
plt.savefig(save_path + "fit_intensity:" + str(round(intensities[i], 4)) + ".png")
plt.close()
return popts, pcovs
def calculate_mean_frequency(freqs):
mean_freq = [sum(e) / len(e) for e in zip(*freqs)]
return mean_freq
def gaussian_kernel(sigma, dt):
x = np.arange(-4. * sigma, 4. * sigma, dt)
y = np.exp(-0.5 * (x / sigma) ** 2) / np.sqrt(2. * np.pi) / sigma
return y
def calculate_kernel_frequency(spiketimes, time, sampling_interval):
sp = spiketimes
t = time # Probably goes from -200ms to some amount of ms in the positive ~1200?
dt = sampling_interval
kernel_width = 0.01 # kernel width is a time in seconds how sharp the frequency should be counted
binary = np.zeros(t.shape)
spike_indices = ((sp - t[0]) / dt).astype(int)
binary[spike_indices[(spike_indices >= 0) & (spike_indices < len(binary))]] = 1
g = gaussian_kernel(kernel_width, dt)
rate = np.convolve(binary, g, mode='same')
return rate
def calculate_isi_frequency(spiketimes, time):
first_isi = spiketimes[0] - (-PRE_DURATION) # diff to the start at 0
last_isi = TOTAL_DURATION - spiketimes[-1] # diff from the last spike to the end of time :D
isis = [first_isi]
isis.extend(np.diff(spiketimes))
isis.append(last_isi)
if np.isnan(first_isi):
print(spiketimes[:10])
print(isis[0:10])
quit()
rate = []
for isi in isis:
if isi == 0:
print("probably a problem")
isi = 0.0000000001
freq = 1/isi
frequency_step = int(round(isi*(1/SAMPLING_INTERVAL)))*[freq]
rate.extend(frequency_step)
#plt.plot((np.arange(len(rate))-PRE_DURATION)/(1/SAMPLING_INTERVAL), rate)
#plt.plot([sum(isis[:i+1]) for i in range(len(isis))], [200 for i in isis], 'o')
#plt.plot(time, [100 for t in time])
#plt.show()
if len(rate) != len(time):
if "12-13-af" in get_savepath():
warnings.warn("preStimulus duration > 0 still not supported")
return [1]*len(time)
else:
print(len(rate), len(time), len(rate) - len(time))
print(rate)
print(isis)
print("Quitting because time and rate aren't the same length")
quit()
return rate
def calculate_mean_frequencies(intensities, spiketimes):
time = np.arange(-PRE_DURATION, TOTAL_DURATION, SAMPLING_INTERVAL)
mean_frequencies = []
for i in range(len(intensities)):
freqs = []
for spikes in spiketimes[i]:
if len(spikes) < 2:
continue
freq = calculate_isi_frequency(spikes, time)
freqs.append(freq)
mf = calculate_mean_frequency(freqs)
mean_frequencies.append(mf)
return mean_frequencies
def calculate_f_baseline(mean_frequencies):
buffer_time = 0.05
start_idx = int(0.05/SAMPLING_INTERVAL)
end_idx = int((PRE_DURATION - STIMULUS_START - buffer_time)/SAMPLING_INTERVAL)
f_zeros = []
for freq in mean_frequencies:
f_0 = np.mean(freq[start_idx:end_idx])
f_zeros.append(f_0)
return f_zeros
def calculate_f_infinity(mean_frequencies):
buffer_time = 0.05
start_idx = int((PRE_DURATION + STIMULUS_START + STIMULUS_DURATION - 0.15 - buffer_time) / SAMPLING_INTERVAL)
end_idx = int((PRE_DURATION + STIMULUS_START + STIMULUS_DURATION - buffer_time) / SAMPLING_INTERVAL)
f_infinity = []
for freq in mean_frequencies:
f_inf = np.mean(freq[start_idx:end_idx])
f_infinity.append(f_inf)
return f_infinity
def calculate_f_zero(mean_frequencies):
buffer_time = 0.1
start_idx = int((PRE_DURATION + STIMULUS_START - buffer_time) / SAMPLING_INTERVAL)
end_idx = int((PRE_DURATION + STIMULUS_START + buffer_time) / SAMPLING_INTERVAL)
f_peaks = []
for freq in mean_frequencies:
fp = np.mean(freq[start_idx-500:start_idx])
for i in range(start_idx+1, end_idx):
if abs(freq[i] - freq[start_idx]) > abs(fp - freq[start_idx]):
fp = freq[i]
f_peaks.append(fp)
return f_peaks
def plot_fi_curve(intensities, f_baseline, f_zero, f_infinity):
plt.plot(intensities, f_baseline, label="f_baseline")
plt.plot(intensities, f_zero, 'o', label="f_zero")
plt.plot(intensities, f_infinity, label="f_infinity")
max_f0 = float(max(f_zero))
mean_int = float(np.mean(intensities))
start_k = float(((f_zero[-1] - f_zero[0]) / (intensities[-1] - intensities[0])*4)/f_zero[-1])
popt, pcov = curve_fit(fill_boltzmann, intensities, f_zero, p0=(max_f0, start_k, mean_int), maxfev=10000)
print(popt)
min_x = min(intensities)
max_x = max(intensities)
step = (max_x - min_x) / 5000
x_values_boltzmann_fit = np.arange(min_x, max_x, step)
plt.plot(x_values_boltzmann_fit, [fill_boltzmann(i, popt[0], popt[1], popt[2]) for i in x_values_boltzmann_fit], label='fit')
plt.title("FI-Curve")
plt.ylabel("Frequency in Hz")
plt.xlabel("Intensity in mV")
plt.legend()
# plt.show()
plt.savefig(get_savepath() + "fi_curve.png")
plt.close()
def plot_frequency_curve(intensities, mean_frequencies):
colors = ["red", "green", "blue", "violet", "orange", "grey"]
time = np.arange(-PRE_DURATION, TOTAL_DURATION, SAMPLING_INTERVAL)
for i in range(len(intensities)):
plt.plot(time, mean_frequencies[i], color=colors[i % 6], label=str(intensities[i]))
plt.plot((0, 0), (0, 500), color="black")
plt.plot((0.4, 0.4), (0, 500), color="black")
plt.legend()
plt.xlabel("Time in seconds")
plt.ylabel("Frequency in Hz")
plt.title("Frequency curve")
plt.savefig(get_savepath() + "mean_frequency_curves.png")
plt.close()
def exponential_function(x, a, b, c, d):
return a*np.exp(-c*(x-b))+d
def upper_boltzmann(x, f_max, k, x_zero):
return f_max * np.clip((2 / (1+np.power(np.e, -k*(x - x_zero)))) - 1, 0, None)
def fill_boltzmann(x, f_max, k, x_zero):
return f_max * (1 / (1 + np.power(np.e, -k * (x - x_zero))))
SAVEPATH = ""
def get_savepath():
global SAVEPATH
return SAVEPATH
def set_savepath(new_path):
global SAVEPATH
SAVEPATH = new_path
if __name__ == '__main__':
main()

@@ -0,0 +1,20 @@
import pyrelacs.DataLoader as dl
for metadata, key, data in dl.iload('2012-06-27-ah-invivo-1/basespikes1.dat'):
print(data.shape)
break
# mean firing rate (number of spikes per time)
# CV: stdev of the ISIs divided by the mean ISI (ISIs = np.diff(spiketimes))
# ISI histogram (play around with binsize! < 1ms)
# Vector strength (use EOD frequency from header (metadata)) VS > 0.8
# dl.iload_traces(repro='BaselineActivity')
def test():
for metadata, key, data in dl.iload('data/2012-06-27-ah-invivo-1/basespikes1.dat'):
print(data.shape)
for i in metadata:
for key in i.keys():
print(key, ":", i[key])
break

main.py Normal file
@@ -0,0 +1,61 @@
from FiCurve import FICurve
from CellData import icelldata_of_dir
import os
import helperFunctions as hf
from AdaptionCurrent import Adaption
from models.NeuronModel import NeuronModel
from functionalityTests import *
# TODO command line interface needed/nice ?
def main():
run_tests()
quit()
for cell_data in icelldata_of_dir("./data/"):
print()
print(cell_data.get_data_path())
model = NeuronModel(cell_data)
x_values = np.arange(0, 1000, 0.01)
stimulus = [0]*int(200/0.01)
stimulus.extend([0.19]*int(400/0.01))
stimulus.extend([0]*int(400/0.01))
v, spikes = model.simulate(0, 1000, stimulus)
# plt.plot(x_values, v)
spikes = [s/1000 for s in spikes]
time, freq = hf.calculate_isi_frequency(spikes, 0, 0.01/1000)
plt.plot(time, freq)
plt.show()
quit()
continue
figures_save_path = "./figures/" + os.path.basename(cell_data.get_data_path()) + "/"
ficurve = FICurve(cell_data)
ficurve.plot_fi_curve(figures_save_path)
adaption = Adaption(cell_data, ficurve)
adaption.plot_exponential_fits(figures_save_path + "exponential_fits/", delete_previous=True)
for i in range(len(adaption.exponential_fit_vars)):
if len(adaption.exponential_fit_vars[i]) == 0:
continue
tau = round(adaption.exponential_fit_vars[i][1]*1000, 2)
contrast = round(ficurve.stimulus_value[i], 3)
# print(tau, "ms - tau_eff at", contrast, "contrast")
# test_plot_inverses(ficurve)
print("Chosen tau [ms]:", adaption.tau_real)
def run_tests():
test_lifac()
if __name__ == '__main__':
main()

models/LIFAC.py Normal file
@@ -0,0 +1,88 @@
from stimuli.AbstractStimulus import AbstractStimulus
import numpy as np
class LIFACModel:
# all times in milliseconds
KEYS = ["mem_res", "mem_tau", "v_base", "v_zero", "threshold", "step_size", "delta_a", "tau_a"]
VALUES = [100 * 1000000, 0.1 * 200, 0, 0, 10, 0.01, 1, 200]
# membrane time constant tau = mem_cap*mem_res
def __init__(self, params: dict = None):
self.parameters = {}
if params is None:
self._set_default_parameters()
else:
self._test_given_parameters(params)
self.set_parameters(params)
self.last_v = []
self.last_adaption = []
self.last_spiketimes = []
def __call__(self, stimulus: AbstractStimulus, total_time_s):
output_voltage = []
adaption = []
spiketimes = []
current_v = self.parameters["v_zero"]
current_a = 0
for time_point in np.arange(0, total_time_s*1000, self.parameters["step_size"]):
v_next = self._calculate_voltage_step(current_v, stimulus.value_at_time_in_ms(time_point) - current_a)
a_next = self._calculate_adaption_step(current_a)
if v_next > self.parameters["threshold"]:
v_next = self.parameters["v_base"]
spiketimes.append(time_point/1000)
a_next += self.parameters["delta_a"]
output_voltage.append(v_next)
adaption.append(a_next)
current_v = v_next
current_a = a_next
self.last_v = output_voltage
self.last_adaption = adaption
self.last_spiketimes = spiketimes
return output_voltage, spiketimes
def _calculate_voltage_step(self, current_v, input_v):
v_base = self.parameters["v_base"]
step_size = self.parameters["step_size"]
# mem_res = self.parameters["mem_res"]
mem_tau = self.parameters["mem_tau"]
return current_v + (step_size * (v_base - current_v + input_v)) / mem_tau
def _calculate_adaption_step(self, current_a):
step_size = self.parameters["step_size"]
return current_a + (step_size * (-current_a)) / self.parameters["tau_a"]
def set_parameters(self, params):
for k in params.keys():
self.parameters[k] = params[k]
for i in range(len(self.KEYS)):
if self.KEYS[i] not in self.parameters.keys():
self.parameters[self.KEYS[i]] = self.VALUES[i]
def get_parameters(self):
return self.parameters
def set_variable(self, key, value):
if key not in self.KEYS:
raise ValueError("Given key is unknown!\n"
"Please check spelling and refer to list LIFAC.KEYS.")
self.parameters[key] = value
def _set_default_parameters(self):
for i in range(len(self.KEYS)):
self.parameters[self.KEYS[i]] = self.VALUES[i]
def _test_given_parameters(self, params):
for k in params.keys():
if k not in self.KEYS:
err_msg = "Unknown key in the given parameters:" + str(k)
raise ValueError(err_msg)

models/LeakyIntegrateFireModel.py Normal file
@@ -0,0 +1,37 @@
class LIFModel:
# all times in milliseconds
def __init__(self, mem_res, mem_tau, v_base, v_zero, input_current, threshold, input_offset=0, step_size=0.01):
self.mem_res = mem_res
# self.membrane_capacitance = mem_cap
self.mem_tau = mem_tau # membrane time constant tau = mem_cap*mem_res
self.v_base = v_base
self.v_zero = v_zero
self.threshold = threshold
self.step_size = step_size
self.input_current = input_current
self.input_offset = input_offset
def calculate_response(self):
output_voltage = [self.v_zero]
spikes = []
for idx in range(1, len(self.input_current)):
v_next = self.__calculate_next_step__(output_voltage[idx-1], self.input_current[idx-1])
if v_next > self.threshold:
v_next = self.v_base
spikes.append(True)
else:
spikes.append(False)
output_voltage.append(v_next)
return output_voltage, spikes
def set_input_current(self, input_current, offset=0):
self.input_current = input_current
self.input_offset = offset
def __calculate_next_step__(self, current_v, input_i):
return current_v + (self.step_size * (self.v_base - current_v + self.mem_res * input_i)) / self.mem_tau

models/NeuronModel.py Normal file
@@ -0,0 +1,110 @@
from CellData import CellData
from FiCurve import FICurve
from AdaptionCurrent import Adaption
import numpy as np
import matplotlib.pyplot as plt
class NeuronModel:
KEYS = ["mem_res", "mem_tau", "v_base", "v_zero", "threshold", "step_size"]
VALUES = [100 * 1000000, 0.1 * 200, 0, 0, 10, 0.01]
def __init__(self, cell_data: CellData, variables: dict = None):
self.cell_data = cell_data
self.fi_curve = FICurve(cell_data)
self.adaption = Adaption(cell_data, self.fi_curve)
if variables is not None:
self._test_given_variables(variables)
self.variables = variables
else:
self.variables = {}
self._add_standard_variables()
def __call__(self, stimulus):
raise NotImplementedError("Soon. sorry!")
def _approximate_variables_from_data(self):
# TODO don't return but save in class in some form! approximate/calculate other variables?
base_input = self._calculate_input_for_base_frequency()
return base_input
def simulate(self, start_v, time_in_ms, stimulus):
response = []
spikes = []
current_v = start_v
current_a = 0
base_input = self._calculate_input_for_base_frequency()
adaption_values = []
a_infties = []
print("base input:", base_input)
for time_step in np.arange(0, time_in_ms, self.variables["step_size"]):
stimulus_input = stimulus[int(time_step/self.variables["step_size"])] - current_a
new_v = self._calculate_next_step(current_v, current_a*base_input, base_input + base_input*stimulus_input)
new_a, a_infty = self._calculate_adaption_step(current_a, stimulus_input)
if new_v > self.variables["threshold"]:
new_v = self.variables["v_base"]
spikes.append(time_step)
response.append(new_v)
adaption_values.append(current_a)
a_infties.append(a_infty)
current_v = new_v
current_a = new_a
plt.title("Adaption variable")
plt.plot(np.arange(0, time_in_ms, self.variables["step_size"]), np.array(adaption_values), label="adaption")
plt.plot(np.arange(0, time_in_ms, self.variables["step_size"]), np.array(a_infties), label="a_inf")
plt.plot(np.arange(0, time_in_ms, self.variables["step_size"]), stimulus, label="stimulus")
plt.legend()
plt.xlabel("time in ms")
plt.ylabel("value as contrast?")
plt.show()
plt.close()
return response, spikes
def _calculate_next_step(self, current_v, current_a, input_v):
step_size = self.variables["step_size"]
v_base = self.variables["v_base"]
mem_tau = self.variables["mem_tau"]
return current_v + (step_size * (- current_v + v_base + input_v - current_a)) / mem_tau
def _calculate_adaption_step(self, current_a, stimulus_input):
step_size = self.variables["step_size"]
tau_a = self.adaption.tau_real
f_infty_freq = self.fi_curve.get_f_infinity_frequency_at_stimulus_value(stimulus_input)
a_infinity = stimulus_input - self.fi_curve.get_f_zero_inverse_at_frequency(f_infty_freq)
return current_a + (step_size * (- current_a + a_infinity)) / tau_a, a_infinity
def set_variable(self, key, value):
if key not in self.KEYS:
raise ValueError("Given key is unknown!\n"
"Please check spelling and refer to list NeuronModel.KEYS.")
self.variables[key] = value
def set_variables(self, variables: dict):
self._test_given_variables(variables)
for k in variables.keys():
self.variables[k] = variables[k]
def _calculate_input_for_base_frequency(self):
return - self.variables["threshold"] / (
np.e ** (-1 / (self.cell_data.get_base_frequency()/1000 * self.variables["mem_tau"])) - 1)
def _test_given_variables(self, variables: dict):
for k in variables.keys():
if k not in self.KEYS:
raise ValueError("Unknown key in given model variables. \n"
"Please check spelling and refer to list NeuronModel.KEYS.")
def _add_standard_variables(self):
for i in range(len(self.KEYS)):
if self.KEYS[i] not in self.variables:
self.variables[self.KEYS[i]] = self.VALUES[i]

models/__init__.py Normal file

stimuli/AbstractStimulus.py Normal file
@@ -0,0 +1,8 @@
class AbstractStimulus:
def value_at_time_in_ms(self, time_point):
raise NotImplementedError("This is an abstract class!")
def value_at_time_in_s(self, time_point):
raise NotImplementedError("This is an abstract class!")

stimuli/StepStimulus.py Normal file
@@ -0,0 +1,26 @@
from stimuli.AbstractStimulus import AbstractStimulus
class StepStimulus(AbstractStimulus):
def __init__(self, start, duration, value, base_value=0, seconds=True):
self.start = 0
self.duration = 0
self.base_value = base_value
self.value = value
if seconds:
self.start = start
self.duration = duration
else:
self.start = start / 1000
self.duration = duration / 1000
def value_at_time_in_ms(self, time_point):
return self.value_at_time_in_s(time_point/1000)
def value_at_time_in_s(self, time_point):
if self.start <= time_point <= self.start + self.duration:
return self.value
else:
return self.base_value
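Usage sketch: a 15 mV step from 0.5 s to 1.5 s on a zero baseline.

from stimuli.StepStimulus import StepStimulus

stimulus = StepStimulus(0.5, 1, 15)
print(stimulus.value_at_time_in_s(0.25))   # 0, before the step
print(stimulus.value_at_time_in_s(1.0))    # 15, inside the step
print(stimulus.value_at_time_in_ms(1800))  # 0, after the step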

stimuli/__init__.py Normal file