import os
import warnings

import matplotlib.pyplot as plt
import numpy as np
import pyrelacs.DataLoader as dl
from scipy.optimize import curve_fit

from my_util import helperFunctions as hf

SAMPLING_INTERVAL = 1/20000  # seconds per sample (20 kHz)
STIMULUS_START = 0  # stimulus onset on the trial time axis, in seconds
STIMULUS_DURATION = 0.400  # stimulus length in seconds
PRE_DURATION = 0.250  # traces start at time -PRE_DURATION (seconds)
TOTAL_DURATION = 1.25  # traces end at time TOTAL_DURATION (seconds)


def _load_fispikes(filepath):
    """Read one relacs spike file into parallel lists.

    Every block that carries metadata starts a new stimulus intensity; the
    spike data of that block and of all following blocks without metadata are
    collected as trials of that intensity.

    Returns:
        (intensities, spiketimes): ``intensities[i]`` is a float and
        ``spiketimes[i]`` is a list of spike-time arrays (one per trial).

    Raises:
        IndexError: if spike data appears before the first metadata block.
    """
    intensities = []
    spiketimes = []
    for metadata, key, data in dl.iload(filepath):
        if len(metadata) != 0:
            # Some files put an extra "Control" section first; the stimulus
            # settings are then in the second metadata entry.
            metadata_index = 0
            if '----- Control --------------------------------------------------------' in metadata[0].keys():
                metadata_index = 1
            print(metadata)
            # [:-2] strips a two-character unit suffix (presumably "mV" -- TODO confirm)
            intensities.append(float(metadata[metadata_index]['intensity'][:-2]))
            spiketimes.append([])
        # Column 0 holds the spike times; /1000 presumably converts ms to s -- TODO confirm.
        spiketimes[-1].append(data[:, 0] / 1000)
    return intensities, spiketimes


def main():
    """Analyse the ``fispikes1.dat`` recording in every subfolder of ``data/``.

    For each folder: load spike trains per stimulus intensity, merge similar
    intensities, sort by intensity, compute mean instantaneous-frequency
    traces, fit the onset response and save figures to
    ``figures/<folder name>/``. Folders without any intensity block are
    skipped.
    """
    for folder in hf.get_subfolder_paths("data/"):
        filepath = folder + "/fispikes1.dat"
        set_savepath("figures/" + folder.split('/')[1] + "/")
        print("Folder:", folder)
        # exist_ok avoids the check-then-create race of os.path.exists + makedirs
        os.makedirs(get_savepath(), exist_ok=True)

        intensities, spiketimes = _load_fispikes(filepath)
        if not intensities:
            print("No intensity blocks found in", filepath, "- skipping folder.")
            continue

        intensities, spiketimes = hf.merge_similar_intensities(intensities, spiketimes)

        # Sort both lists together so that intensities are increasing.
        ordered = sorted(zip(intensities, spiketimes), key=lambda pair: pair[0])
        intensities = [intensity for intensity, _ in ordered]
        spiketimes = [trials for _, trials in ordered]

        mean_frequencies = calculate_mean_frequencies(intensities, spiketimes)
        # Saves one plot per successful fit; the parameters are not used further yet.
        popt, pcov = fit_exponential(intensities, mean_frequencies)
        plot_frequency_curve(intensities, mean_frequencies)

        f_baseline = calculate_f_baseline(mean_frequencies)
        f_infinity = calculate_f_infinity(mean_frequencies)
        f_zero = calculate_f_zero(mean_frequencies)
        # TODO: the f-I curve plot is currently disabled.
        # plot_fi_curve(intensities, f_baseline, f_zero, f_infinity)
def fit_exponential(intensities, mean_frequencies):
    """Fit ``exponential_function`` to the onset response of each mean trace.

    The fit window runs from 5 ms to 100 ms after stimulus onset (end sample
    inclusive). For every successful fit, the trace and the fitted curve are
    plotted and saved to ``get_savepath() + "exponential_fits/"``.

    Args:
        intensities: stimulus intensity of each trace (used in the file names).
        mean_frequencies: frequency traces sampled every SAMPLING_INTERVAL,
            starting at time -PRE_DURATION.

    Returns:
        (popts, pcovs): parameter vectors and covariance matrices of the
        successful fits only. Traces that are too short for the window or
        whose fit does not converge are skipped, so the lists can be shorter
        than, and are then no longer index-aligned with, ``intensities``.
    """
    onset = PRE_DURATION + STIMULUS_START  # onset in seconds since trace start
    start_idx = int((onset + 0.005) / SAMPLING_INTERVAL)
    end_idx = int((onset + 0.1) / SAMPLING_INTERVAL)
    # Derive x from integer sample indices so it has exactly as many points as
    # the inclusive slice freq[start_idx:end_idx + 1]. A float-step np.arange
    # over the same range can come out one element short, which makes
    # curve_fit fail on mismatched xdata/ydata shapes.
    x_values = np.arange(start_idx, end_idx + 1) * SAMPLING_INTERVAL
    time = np.arange(-PRE_DURATION, TOTAL_DURATION, SAMPLING_INTERVAL)
    save_path = get_savepath() + "exponential_fits/"

    popts = []
    pcovs = []
    for intensity, freq in zip(intensities, mean_frequencies):
        y_values = freq[start_idx:end_idx + 1]
        if len(y_values) != len(x_values):
            # e.g. an empty mean trace because no trial had two or more spikes
            print("Trace too short for the fit window in fit_exponential; skipping.")
            continue
        try:
            popt, pcov = curve_fit(exponential_function, x_values, y_values,
                                   p0=(1.0, .5, 50, 180), maxfev=10000)
        except RuntimeError:
            # curve_fit raises RuntimeError when no fit is found within maxfev calls
            print("RuntimeError happened in fit_exponential.")
            continue
        popts.append(popt)
        pcovs.append(pcov)

        plt.plot(time, freq)
        # x_values are seconds since trace start; shift them onto the trial time axis
        plt.plot(x_values - PRE_DURATION, exponential_function(x_values, *popt))
        os.makedirs(save_path, exist_ok=True)
        plt.savefig(save_path + "fit_intensity:" + str(round(intensity, 4)) + ".png")
        plt.close()
    return popts, pcovs


def calculate_mean_frequency(freqs):
    """Average several frequency traces sample by sample.

    Returns a list. An empty input gives an empty list; traces of different
    length are truncated to the shortest one (``zip`` semantics).
    """
    return [sum(samples) / len(samples) for samples in zip(*freqs)]


def gaussian_kernel(sigma, dt):
    """Return a Gaussian probability density (std ``sigma``) sampled every ``dt`` over +-4 sigma."""
    x = np.arange(-4. * sigma, 4. * sigma, dt)
    return np.exp(-0.5 * (x / sigma) ** 2) / np.sqrt(2. * np.pi) / sigma


def calculate_kernel_frequency(spiketimes, time, sampling_interval):
    """Estimate a smooth firing rate by convolving the spike train with a Gaussian.

    Args:
        spiketimes: spike times in seconds (same axis as ``time``).
        time: evenly spaced sample times (numpy array).
        sampling_interval: spacing of ``time`` in seconds.

    Returns:
        numpy array with the rate in spikes per second, one value per sample.
    """
    dt = sampling_interval
    kernel_width = 0.01  # std of the Gaussian kernel in seconds; larger = smoother rate
    binary = np.zeros(time.shape)
    spike_indices = ((np.asarray(spiketimes) - time[0]) / dt).astype(int)
    # Spikes outside the time window are dropped; several spikes in one bin count once.
    binary[spike_indices[(spike_indices >= 0) & (spike_indices < len(binary))]] = 1
    return np.convolve(binary, gaussian_kernel(kernel_width, dt), mode='same')


def calculate_isi_frequency(spiketimes, time):
    """Convert one trial's spike times into a step-wise instantaneous rate.

    Each inter-spike interval (ISI) becomes round(isi / SAMPLING_INTERVAL)
    samples holding 1/isi. The trace start (-PRE_DURATION) and end
    (TOTAL_DURATION) act as virtual spikes so the whole trial is covered.

    Args:
        spiketimes: spike times in seconds on the trial time axis, presumably
            sorted ascending (TODO confirm); at least one element.
        time: sample time vector the result must align with.

    Returns:
        list of frequencies in Hz with ``len(time)`` entries. For recordings
        whose save path contains "12-13-af", a length mismatch yields a dummy
        ``[1] * len(time)`` with a warning instead.

    Raises:
        ValueError: if the first spike time is NaN, or if the rate does not
            end up with ``len(time)`` samples.
    """
    first_isi = spiketimes[0] - (-PRE_DURATION)  # from trace start to first spike
    if np.isnan(first_isi):
        raise ValueError(f"First spike time is NaN; first spike times: {spiketimes[:10]}")
    last_isi = TOTAL_DURATION - spiketimes[-1]  # from last spike to trace end
    isis = [first_isi, *np.diff(spiketimes), last_isi]

    samples_per_second = 1 / SAMPLING_INTERVAL
    rate = []
    for isi in isis:
        if isi == 0:
            # Two spikes at the same time: avoid dividing by zero. The step then
            # rounds to zero samples and contributes nothing.
            print("probably a problem")
            isi = 0.0000000001
        rate.extend(int(round(isi * samples_per_second)) * [1 / isi])

    if len(rate) != len(time):
        if "12-13-af" in get_savepath():
            warnings.warn("preStimulus duration > 0 still not supported")
            return [1] * len(time)
        raise ValueError(
            "time and rate aren't the same length: "
            f"len(rate)={len(rate)}, len(time)={len(time)}, "
            f"difference={len(rate) - len(time)}"
        )
    return rate


def calculate_mean_frequencies(intensities, spiketimes):
    """Compute one mean ISI-frequency trace per intensity.

    Args:
        intensities: stimulus intensities (only their count is used).
        spiketimes: ``spiketimes[i]`` is a list of spike-time arrays (trials)
            for ``intensities[i]``.

    Returns:
        list of mean traces (lists). Trials with fewer than two spikes are
        skipped; if no trial remains, the trace for that intensity is ``[]``.
    """
    time = np.arange(-PRE_DURATION, TOTAL_DURATION, SAMPLING_INTERVAL)
    mean_frequencies = []
    for _, trials in zip(intensities, spiketimes):
        freqs = [calculate_isi_frequency(spikes, time) for spikes in trials if len(spikes) >= 2]
        mean_frequencies.append(calculate_mean_frequency(freqs))
    return mean_frequencies


def calculate_f_baseline(mean_frequencies):
    """Mean pre-stimulus rate of each trace.

    The window starts ``buffer_time`` after the trace start and ends
    ``buffer_time`` before stimulus onset.
    """
    buffer_time = 0.05
    # Onset sample is (PRE_DURATION + STIMULUS_START) / SAMPLING_INTERVAL,
    # consistent with the other calculate_f_* functions.
    start_idx = int(buffer_time / SAMPLING_INTERVAL)
    end_idx = int((PRE_DURATION + STIMULUS_START - buffer_time) / SAMPLING_INTERVAL)
    return [np.mean(freq[start_idx:end_idx]) for freq in mean_frequencies]


def calculate_f_infinity(mean_frequencies):
    """Steady-state rate of each trace.

    This is the mean over the 150 ms that end ``buffer_time`` before stimulus offset.
    """
    buffer_time = 0.05
    stimulus_end = PRE_DURATION + STIMULUS_START + STIMULUS_DURATION  # seconds since trace start
    start_idx = int((stimulus_end - 0.15 - buffer_time) / SAMPLING_INTERVAL)
    end_idx = int((stimulus_end - buffer_time) / SAMPLING_INTERVAL)
    return [np.mean(freq[start_idx:end_idx]) for freq in mean_frequencies]


def calculate_f_zero(mean_frequencies):
    """Onset response of each trace.

    The reference value is the sample ``buffer_time`` before onset. The
    candidate starts as the mean of the 500 samples preceding that reference.
    It is replaced by any sample in the window (reference, onset +
    buffer_time) that deviates more from the reference. Deviations in both
    directions count, so this may also return a dip.
    """
    buffer_time = 0.1
    reference_samples = 500
    onset = PRE_DURATION + STIMULUS_START  # seconds since trace start
    start_idx = int((onset - buffer_time) / SAMPLING_INTERVAL)
    end_idx = int((onset + buffer_time) / SAMPLING_INTERVAL)
    f_peaks = []
    for freq in mean_frequencies:
        reference = freq[start_idx]
        # Clamp so a short pre-stimulus period cannot produce a negative (wrapping) slice start.
        fp = np.mean(freq[max(0, start_idx - reference_samples):start_idx])
        for value in freq[start_idx + 1:end_idx]:
            if abs(value - reference) > abs(fp - reference):
                fp = value
        f_peaks.append(fp)
    return f_peaks


def plot_fi_curve(intensities, f_baseline, f_zero, f_infinity):
    """Plot f_baseline, f_zero and f_infinity against intensity and fit a Boltzmann to f_zero.

    The figure is saved as ``get_savepath() + "fi_curve.png"``. The initial
    slope guess uses the first and last points, so intensities are expected in
    increasing order (main sorts them). ``f_zero[-1]`` must be non-zero.
    ``curve_fit`` raises RuntimeError if the fit does not converge.
    """
    plt.plot(intensities, f_baseline, label="f_baseline")
    plt.plot(intensities, f_zero, 'o', label="f_zero")
    plt.plot(intensities, f_infinity, label="f_infinity")

    max_f0 = float(max(f_zero))
    mean_int = float(np.mean(intensities))
    # Initial slope guess: relative change of f_zero across the intensity range, scaled by 4.
    start_k = float(((f_zero[-1] - f_zero[0]) / (intensities[-1] - intensities[0]) * 4) / f_zero[-1])
    popt, pcov = curve_fit(fill_boltzmann, intensities, f_zero,
                           p0=(max_f0, start_k, mean_int), maxfev=10000)
    print(popt)

    min_x = min(intensities)
    max_x = max(intensities)
    step = (max_x - min_x) / 5000
    x_values_boltzmann_fit = np.arange(min_x, max_x, step)
    plt.plot(x_values_boltzmann_fit, fill_boltzmann(x_values_boltzmann_fit, *popt), label='fit')

    plt.title("FI-Curve")
    plt.ylabel("Frequency in Hz")
    plt.xlabel("Intensity in mV")
    plt.legend()
    plt.savefig(get_savepath() + "fi_curve.png")
    plt.close()


def plot_frequency_curve(intensities, mean_frequencies):
    """Plot all mean frequency traces in one figure and save it as mean_frequency_curves.png.

    Black vertical lines mark stimulus onset and offset. Traces whose length
    does not match the time axis (e.g. empty ones) are skipped.
    """
    colors = ["red", "green", "blue", "violet", "orange", "grey"]
    time = np.arange(-PRE_DURATION, TOTAL_DURATION, SAMPLING_INTERVAL)
    for i, (intensity, freq) in enumerate(zip(intensities, mean_frequencies)):
        if len(freq) != len(time):
            continue
        plt.plot(time, freq, color=colors[i % len(colors)], label=str(intensity))

    stimulus_end = STIMULUS_START + STIMULUS_DURATION
    plt.plot((STIMULUS_START, STIMULUS_START), (0, 500), color="black")
    plt.plot((stimulus_end, stimulus_end), (0, 500), color="black")
    plt.legend()
    plt.xlabel("Time in seconds")
    plt.ylabel("Frequency in Hz")
    plt.title("Frequency curve")
    plt.savefig(get_savepath() + "mean_frequency_curves.png")
    plt.close()


def exponential_function(x, a, b, c, d):
    """Return a * exp(-c * (x - b)) + d: amplitude a, shift b, rate c, offset d."""
    return a * np.exp(-c * (x - b)) + d


def upper_boltzmann(x, f_max, k, x_zero):
    """Return the upper half of a Boltzmann rescaled to [0, f_max]; values below x_zero are clipped to 0."""
    return f_max * np.clip((2 / (1 + np.exp(-k * (x - x_zero)))) - 1, 0, None)


def fill_boltzmann(x, f_max, k, x_zero):
    """Return the logistic f_max / (1 + exp(-k * (x - x_zero))), which is f_max / 2 at x_zero."""
    return f_max * (1 / (1 + np.exp(-k * (x - x_zero))))


# Output directory prefix for the figures of the folder currently processed.
SAVEPATH = ""


def get_savepath():
    """Return the current figure output directory prefix."""
    return SAVEPATH


def set_savepath(new_path):
    """Set the figure output directory prefix used by the plotting functions."""
    global SAVEPATH
    SAVEPATH = new_path


if __name__ == '__main__':
    main()