P-unit_model/introduction/introductionFICurve.py
2019-12-20 13:33:34 +01:00

299 lines
9.3 KiB
Python

import numpy as np
import matplotlib.pyplot as plt
import pyrelacs.DataLoader as dl
import os
import helperFunctions as hf
from IPython import embed
from scipy.optimize import curve_fit
import warnings
# Timing of the FI-curve protocol; all values are in seconds.
SAMPLING_INTERVAL = 1 / 20000  # one sample every 50 us, i.e. 20 kHz
STIMULUS_START = 0  # stimulus onset offset; onset index is (PRE_DURATION + STIMULUS_START) / SAMPLING_INTERVAL
STIMULUS_DURATION = 0.4  # length of the stimulus
PRE_DURATION = 0.25  # recorded time before onset; trace time axes start at -PRE_DURATION
TOTAL_DURATION = 1.25  # end of the trace time axis
def main():
    """Run the FI-curve analysis for every recording folder below ``data/``.

    For each folder the spike times stored in ``fispikes1.dat`` are grouped
    per stimulus intensity, merged and sorted by intensity, turned into mean
    frequency traces, and used to write the exponential-fit and
    mean-frequency figures into ``figures/<folder name>/``.
    """
    control_header = '----- Control --------------------------------------------------------'
    for folder in hf.get_subfolder_paths("data/"):
        recording = folder + "/fispikes1.dat"
        set_savepath("figures/" + folder.split('/')[1] + "/")
        print("Folder:", folder)
        if not os.path.exists(get_savepath()):
            os.makedirs(get_savepath())

        intensities = []
        spiketimes = []
        for metadata, _key, data in dl.iload(recording):
            # A non-empty metadata block opens a new intensity; every data
            # block (including this one) is appended to the newest intensity.
            if len(metadata) > 0:
                # Some files prepend a "Control" section, which shifts the
                # stimulus metadata to the second entry.
                has_control = control_header in metadata[0].keys()
                print(metadata)
                stimulus_meta = metadata[1 if has_control else 0]
                # Drop the two-character unit suffix (presumably "mV") before parsing.
                intensities.append(float(stimulus_meta['intensity'][:-2]))
                spiketimes.append([])
            # Divide by 1000 — presumably converting ms to s.
            spiketimes[-1].append(data[:, 0] / 1000)

        intensities, spiketimes = hf.merge_similar_intensities(intensities, spiketimes)

        # Sort both lists together so that intensities are increasing.
        ordered_pairs = sorted(zip(intensities, spiketimes), key=lambda pair: pair[0])
        columns = [list(column) for column in zip(*ordered_pairs)]
        intensities = columns[0]
        spiketimes = columns[1]

        mean_frequencies = calculate_mean_frequencies(intensities, spiketimes)
        _popts, _pcovs = fit_exponential(intensities, mean_frequencies)
        plot_frequency_curve(intensities, mean_frequencies)
        f_baseline = calculate_f_baseline(mean_frequencies)
        f_infinity = calculate_f_infinity(mean_frequencies)
        f_zero = calculate_f_zero(mean_frequencies)
        # TODO: re-enable once the FI-curve plot is finished:
        # plot_fi_curve(intensities, f_baseline, f_zero, f_infinity)
def fit_exponential(intensities, mean_frequencies):
    """Fit ``exponential_function`` to the onset response of each mean frequency trace.

    The fit window runs from 5 ms to 100 ms after stimulus onset. For every
    trace whose fit converges, the trace and the fitted curve are plotted and
    saved to ``<savepath>exponential_fits/``.

    Args:
        intensities: stimulus intensity per trace (only used in file names).
        mean_frequencies: frequency traces sampled every SAMPLING_INTERVAL,
            whose first sample corresponds to time -PRE_DURATION.

    Returns:
        (popts, pcovs): fitted parameters (a, b, c, d) and their covariance
        matrices. Traces whose fit raises RuntimeError are skipped, so these
        lists can be shorter than ``intensities``.
    """
    onset = PRE_DURATION + STIMULUS_START
    start_idx = int((onset + 0.005) / SAMPLING_INTERVAL)
    end_idx = int((onset + 0.1) / SAMPLING_INTERVAL)
    # Build x from integer sample indices so it has exactly one entry per
    # sample of y_values (a float-step np.arange may be one element off, which
    # makes curve_fit raise a ValueError).
    x_values = np.arange(start_idx, end_idx + 1) * SAMPLING_INTERVAL
    time = np.arange(-PRE_DURATION, TOTAL_DURATION, SAMPLING_INTERVAL)

    save_path = get_savepath() + "exponential_fits/"
    if not os.path.exists(save_path):
        os.makedirs(save_path)

    popts = []
    pcovs = []
    for intensity, freq in zip(intensities, mean_frequencies):
        y_values = np.asarray(freq[start_idx:end_idx + 1])
        try:
            # NOTE(review): the first start value was written as
            # 1 / np.power(1, 10), which equals 1.0 — confirm whether 1e-10 was meant.
            popt, pcov = curve_fit(exponential_function, x_values, y_values,
                                   p0=(1.0, .5, 50, 180), maxfev=10000)
        except RuntimeError:
            print("RuntimeError happened in fit_exponential.")
            continue
        popts.append(popt)
        pcovs.append(pcov)

        plt.plot(time, freq)
        # x_values are measured from the trace start; shift them onto the plotted time axis.
        plt.plot(x_values - PRE_DURATION, exponential_function(x_values, *popt))
        plt.savefig(save_path + "fit_intensity:" + str(round(intensity, 4)) + ".png")
        plt.close()
    return popts, pcovs
def calculate_mean_frequency(freqs):
    """Average several frequency traces sample by sample.

    Traces are aligned by position; the result is as long as the shortest
    trace (``zip`` semantics). An empty input gives an empty list.
    """
    def _average(samples):
        return sum(samples) / len(samples)

    return list(map(_average, zip(*freqs)))
def gaussian_kernel(sigma, dt):
    """Return a normalised Gaussian sampled every ``dt`` on [-4*sigma, 4*sigma).

    The samples integrate (sum * dt) to approximately one.
    """
    support = np.arange(-4. * sigma, 4. * sigma, dt)
    normalisation = np.sqrt(2. * np.pi)
    return np.exp(-0.5 * (support / sigma) ** 2) / normalisation / sigma
def calculate_kernel_frequency(spiketimes, time, sampling_interval, kernel_width=0.01):
    """Estimate a firing rate by convolving the spike train with a Gaussian kernel.

    Args:
        spiketimes: spike times in the same units and on the same axis as ``time``.
        time: 1-D array of sample times; ``time[0]`` is the first bin.
        sampling_interval: spacing of ``time``.
        kernel_width: standard deviation of the Gaussian kernel (same units
            as ``time``); larger values give a smoother rate.

    Returns:
        numpy array with the same shape as ``time``. Spikes outside the
        ``time`` range are ignored.
    """
    dt = sampling_interval
    spike_counts = np.zeros(time.shape)
    # floor (not truncation toward zero) so spikes slightly before time[0]
    # get a negative index and are dropped instead of landing in bin 0.
    spike_indices = np.floor((np.asarray(spiketimes) - time[0]) / dt).astype(int)
    in_range = spike_indices[(spike_indices >= 0) & (spike_indices < len(spike_counts))]
    # np.add.at accumulates repeated indices, so every spike in a shared bin counts.
    np.add.at(spike_counts, in_range, 1)
    kernel = gaussian_kernel(kernel_width, dt)
    return np.convolve(spike_counts, kernel, mode='same')
def calculate_isi_frequency(spiketimes, time):
    """Convert one trial's spike times into a step-wise 1/ISI frequency trace.

    Every sample between two consecutive spikes gets the value 1/ISI of that
    interval. The stretch from the window start (-PRE_DURATION) to the first
    spike, and from the last spike to TOTAL_DURATION, are treated as intervals
    too.

    Args:
        spiketimes: sorted spike times in seconds, on the axis that starts at
            -PRE_DURATION (the axis of ``time``).
        time: sample times of the output trace; only its length is used.

    Returns:
        list of floats with exactly ``len(time)`` entries.

    Raises:
        IndexError: if ``spiketimes`` is empty.
        ValueError: if a spike time is NaN, or the spikes are unsorted or fall
            outside the window. For the "12-13-af" recording the latter case
            instead warns and returns a constant trace of ones.
    """
    spikes = np.asarray(spiketimes, dtype=float)
    if np.isnan(spikes).any():
        raise ValueError("NaN spike time in calculate_isi_frequency: {}".format(spikes[:10]))

    n_samples = len(time)
    # Interval lengths in seconds: window start -> first spike,
    # spike -> spike, last spike -> window end.
    isis = np.concatenate((
        [spikes[0] - (-PRE_DURATION)],
        np.diff(spikes),
        [TOTAL_DURATION - spikes[-1]],
    ))

    # Round each spike's absolute sample position instead of rounding each ISI
    # separately. Per-interval rounding errors accumulate and let the trace
    # length drift from len(time). With fixed outer edges the segment lengths
    # always sum to n_samples.
    spike_idx = np.rint((spikes + PRE_DURATION) / SAMPLING_INTERVAL).astype(int)
    edges = np.concatenate(([0], spike_idx, [n_samples]))
    lengths = np.diff(edges)

    if np.any(lengths < 0):
        # Spikes lie outside [-PRE_DURATION, TOTAL_DURATION] or are out of order.
        # NOTE(review): this recording previously fell back on *any* length
        # mismatch; confirm its spikes really fall outside the window.
        if "12-13-af" in get_savepath():
            warnings.warn("preStimulus duration > 0 still not supported")
            return [1]*len(time)
        raise ValueError(
            "Spike times do not fit into the {}-sample window: {}".format(n_samples, spikes[:10]))

    zero_isi = isis == 0
    if zero_isi.any():
        print("probably a problem")
        # A zero-length interval spans no samples; the tiny ISI only avoids 1/0.
        isis = np.where(zero_isi, 0.0000000001, isis)

    rate = np.repeat(1 / isis, lengths)
    return rate.tolist()
def calculate_mean_frequencies(intensities, spiketimes):
    """Return one trial-averaged 1/ISI frequency trace per intensity.

    ``spiketimes[i]`` holds the trials recorded at ``intensities[i]``. Trials
    with fewer than two spikes are left out of the average.
    """
    time = np.arange(-PRE_DURATION, TOTAL_DURATION, SAMPLING_INTERVAL)
    result = []
    for index, _intensity in enumerate(intensities):
        trial_rates = [calculate_isi_frequency(trial, time)
                       for trial in spiketimes[index]
                       if len(trial) >= 2]
        result.append(calculate_mean_frequency(trial_rates))
    return result
def calculate_f_baseline(mean_frequencies):
    """Return the mean pre-stimulus frequency of each trace.

    Each trace is averaged from ``buffer_time`` (50 ms) after its first
    sample up to ``buffer_time`` before stimulus onset. The buffers keep the
    recording start and the onset transient out of the estimate.

    Args:
        mean_frequencies: traces sampled every SAMPLING_INTERVAL, whose first
            sample corresponds to time -PRE_DURATION.

    Returns:
        list with one mean value per trace (NaN for an empty window).
    """
    buffer_time = 0.05
    start_idx = int(buffer_time / SAMPLING_INTERVAL)
    # Stimulus onset lies PRE_DURATION + STIMULUS_START after the first
    # sample (as in the other onset-relative windows); the baseline must end
    # buffer_time before it.
    end_idx = int((PRE_DURATION + STIMULUS_START - buffer_time) / SAMPLING_INTERVAL)
    return [np.mean(freq[start_idx:end_idx]) for freq in mean_frequencies]
def calculate_f_infinity(mean_frequencies):
    """Return the late-stimulus mean frequency of each trace.

    Averages a 150 ms window that ends ``buffer_time`` (50 ms) before
    stimulus offset. Going by the name, this is meant as the adapted
    (steady-state) response.
    """
    buffer_time = 0.05
    stimulus_end = PRE_DURATION + STIMULUS_START + STIMULUS_DURATION
    first = int((stimulus_end - 0.15 - buffer_time) / SAMPLING_INTERVAL)
    last = int((stimulus_end - buffer_time) / SAMPLING_INTERVAL)
    return [np.mean(trace[first:last]) for trace in mean_frequencies]
def calculate_f_zero(mean_frequencies):
    """Return the largest excursion around stimulus onset for each trace.

    The reference value is the sample ``buffer_time`` (0.1 s) before onset.
    The candidate starts as the mean of the 500 samples preceding that
    reference. The samples up to ``buffer_time`` after onset are then scanned,
    and the candidate is replaced by any sample whose absolute deviation from
    the reference is strictly larger. Ties keep the earlier value.
    """
    buffer_time = 0.1
    onset = PRE_DURATION + STIMULUS_START
    first = int((onset - buffer_time) / SAMPLING_INTERVAL)
    last = int((onset + buffer_time) / SAMPLING_INTERVAL)

    peaks = []
    for trace in mean_frequencies:
        best = np.mean(trace[first - 500:first])
        reference = trace[first]
        best_distance = abs(best - reference)
        for idx in range(first + 1, last):
            distance = abs(trace[idx] - reference)
            if distance > best_distance:
                best = trace[idx]
                best_distance = distance
        peaks.append(best)
    return peaks
def plot_fi_curve(intensities, f_baseline, f_zero, f_infinity):
    """Plot baseline, onset and steady-state frequency against intensity.

    ``f_zero`` is additionally fitted with ``fill_boltzmann``; the fitted
    curve is drawn only when the fit succeeds. The figure is saved as
    ``fi_curve.png`` in the current save path.

    Args:
        intensities: stimulus intensities, expected in ascending order (the
            initial slope guess uses the first and last entries).
        f_baseline, f_zero, f_infinity: one frequency (Hz) per intensity.
    """
    plt.plot(intensities, f_baseline, label="f_baseline")
    plt.plot(intensities, f_zero, 'o', label="f_zero")
    plt.plot(intensities, f_infinity, label="f_infinity")

    popt = None
    # curve_fit needs at least as many data points as fit parameters (3).
    if len(intensities) >= 3:
        max_f0 = float(max(f_zero))
        mean_int = float(np.mean(intensities))
        # Initial steepness: 4x the average f_zero slope over the intensity
        # range, normalised by the f_zero at the highest intensity.
        start_k = float(((f_zero[-1] - f_zero[0]) / (intensities[-1] - intensities[0]) * 4) / f_zero[-1])
        try:
            popt, _ = curve_fit(fill_boltzmann, intensities, f_zero,
                                p0=(max_f0, start_k, mean_int), maxfev=10000)
        except RuntimeError:
            warnings.warn("Boltzmann fit of f_zero did not converge; plotting data only.")

    if popt is not None:
        print(popt)
        # linspace includes the highest intensity, unlike arange(min, max, step).
        x_fit = np.linspace(min(intensities), max(intensities), 5000)
        plt.plot(x_fit, fill_boltzmann(x_fit, *popt), label='fit')

    plt.title("FI-Curve")
    plt.ylabel("Frequency in Hz")
    plt.xlabel("Intensity in mV")
    plt.legend()
    plt.savefig(get_savepath() + "fi_curve.png")
    plt.close()
def plot_frequency_curve(intensities, mean_frequencies):
    """Plot every mean frequency trace against time and save the figure.

    Traces are labelled with their intensity. Black vertical lines (0-500 Hz)
    mark stimulus onset and offset. The figure is written to
    ``mean_frequency_curves.png`` in the current save path.
    """
    colors = ["red", "green", "blue", "violet", "orange", "grey"]
    time = np.arange(-PRE_DURATION, TOTAL_DURATION, SAMPLING_INTERVAL)
    for i, (intensity, freq) in enumerate(zip(intensities, mean_frequencies)):
        plt.plot(time, freq, color=colors[i % len(colors)], label=str(intensity))

    # On this time axis the stimulus spans [STIMULUS_START, STIMULUS_START + STIMULUS_DURATION].
    stimulus_end = STIMULUS_START + STIMULUS_DURATION
    plt.plot((STIMULUS_START, STIMULUS_START), (0, 500), color="black")
    plt.plot((stimulus_end, stimulus_end), (0, 500), color="black")

    plt.legend()
    plt.xlabel("Time in seconds")
    plt.ylabel("Frequency in Hz")
    plt.title("Frequency curve")
    plt.savefig(get_savepath() + "mean_frequency_curves.png")
    plt.close()
def exponential_function(x, a, b, c, d):
    """Return ``a * exp(-c * (x - b)) + d`` (works elementwise on arrays)."""
    shifted = x - b
    return a * np.exp(-c * shifted) + d
def upper_boltzmann(x, f_max, k, x_zero):
    """Return the upper half of a Boltzmann curve, rescaled to [0, f_max).

    ``2 / (1 + exp(-k (x - x_zero))) - 1`` rises from 0 at ``x_zero``
    towards 1. Values below ``x_zero`` (where it is negative) are clipped
    to 0 before scaling by ``f_max``.
    """
    decay = np.power(np.e, -k * (x - x_zero))
    symmetric = 2 / (1 + decay) - 1
    return f_max * np.clip(symmetric, 0, None)
def fill_boltzmann(x, f_max, k, x_zero):
    """Return the logistic (Boltzmann) curve ``f_max / (1 + exp(-k (x - x_zero)))``.

    Equals ``f_max / 2`` at ``x_zero``; ``k`` sets the steepness.
    """
    decay = np.power(np.e, -k * (x - x_zero))
    return f_max * (1 / (1 + decay))
# Directory prefix (with trailing "/") for every figure this module saves;
# changed via set_savepath().
SAVEPATH = ""


def get_savepath():
    """Return the current figure directory prefix."""
    return SAVEPATH


def set_savepath(new_path):
    """Replace the figure directory prefix used by the plotting helpers."""
    global SAVEPATH
    SAVEPATH = new_path
# Run the analysis only when executed as a script, not when imported.
if __name__ == '__main__':
    main()