update test.py & new: useful_functions.py
This commit is contained in:
		
							parent
							
								
									8ccd633f85
								
							
						
					
					
						commit
						2f5a1d2754
					
				
							
								
								
									
										173
									
								
								code/useful_functions.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										173
									
								
								code/useful_functions.py
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,173 @@ | ||||
| import glob | ||||
| import pathlib | ||||
| import numpy as np | ||||
| import matplotlib.pyplot as plt | ||||
| import rlxnix as rlx | ||||
| from IPython import embed | ||||
| from scipy.signal import welch | ||||
| 
 | ||||
| def AM(EODf, stimulus): | ||||
|     """ | ||||
|     Calculates the Amplitude Modulation and Nyquist frequency | ||||
| 
 | ||||
|     Parameters | ||||
|     ---------- | ||||
|     EODf : float or int | ||||
|         The current EODf. | ||||
|     stimulus : float or int | ||||
|         The absolute frequency of the stimulus. | ||||
| 
 | ||||
|     Returns | ||||
|     ------- | ||||
|     AM : float  | ||||
|         The amplitude modulation resulting from the stimulus. | ||||
|     nyquist : float | ||||
|         The maximum frequency possible to resolve with the EODf. | ||||
| 
 | ||||
|     """ | ||||
|     nyquist = EODf * 0.5 | ||||
|     AM = np.mod(stimulus, nyquist) | ||||
|     return AM, nyquist | ||||
| 
 | ||||
| def binary_spikes(spike_times, duration, dt): | ||||
|     """ | ||||
|     Converts the spike times to a binary representations | ||||
| 
 | ||||
|     Parameters | ||||
|     ---------- | ||||
|     spike_times : np.array | ||||
|         The spike times. | ||||
|     duration : float | ||||
|         The trial duration: | ||||
|     dt : float | ||||
|         The temporal resolution. | ||||
| 
 | ||||
|     Returns | ||||
|     ------- | ||||
|     binary : np.array | ||||
|         The binary representation of the spike train. | ||||
| 
 | ||||
|     """ | ||||
|     binary = np.zeros(int(np.round(duration / dt))) #create the binary array with the same length as potential | ||||
|      | ||||
|     spike_indices = np.asarray(np.round(spike_times / dt), dtype = int) # get the indices | ||||
|     binary[spike_indices] = 1 # put the indices into binary | ||||
|     return binary | ||||
| 
 | ||||
| def extract_stim_data(stimulus): | ||||
|     ''' | ||||
|     extracts all necessary metadata for each stimulus | ||||
| 
 | ||||
|     Parameters | ||||
|     ---------- | ||||
|     stimulus : Stimulus object or rlxnix.base.repro module | ||||
|         The stimulus from which the data is needed. | ||||
| 
 | ||||
|     Returns | ||||
|     ------- | ||||
|     amplitude : float | ||||
|         The relative signal amplitude in percent. | ||||
|     df : float | ||||
|         Distance of the stimulus to the current EODf. | ||||
|     eodf : float | ||||
|         Current EODf. | ||||
|     stim_freq : float | ||||
|         The total stimulus frequency (EODF+df). | ||||
|     amp_mod : float | ||||
|         The current amplitude modulation. | ||||
|     ny_freq : float | ||||
|         The current nyquist frequency. | ||||
| 
 | ||||
|     ''' | ||||
|     # extract metadata | ||||
|     # the stim.name adjusts the first key as it changes with every stimulus | ||||
|     amplitude = stimulus.metadata[stimulus.name]['Contrast'][0][0]  | ||||
|     df = stimulus.metadata[stimulus.name]['DeltaF'][0][0] | ||||
|     eodf = round(stimulus.metadata[stimulus.name]['EODf'][0][0]) | ||||
|     stim_freq = round(stimulus.metadata[stimulus.name]['Frequency'][0][0]) | ||||
|     # calculates the amplitude modulation | ||||
|     amp_mod, ny_freq = AM(eodf, stim_freq) | ||||
|     return amplitude, df, eodf, stim_freq, amp_mod, ny_freq | ||||
| 
 | ||||
| def firing_rate(binary_spikes, dt = 0.000025, box_width = 0.01): | ||||
|     ''' | ||||
|     Calculates the firing rate from binary spikes | ||||
| 
 | ||||
|     Parameters | ||||
|     ---------- | ||||
|     binary_spikes : np.array | ||||
|         The binary representation of the spike train. | ||||
|     dt : float, optional | ||||
|         Time difference between two datapoints. The default is 0.000025. | ||||
|     box_width : float, optional | ||||
|         Time window on which the rate should be computed on. The default is 0.01. | ||||
| 
 | ||||
|     Returns | ||||
|     ------- | ||||
|     rate : np.array | ||||
|         Array of firing rates. | ||||
| 
 | ||||
|     ''' | ||||
|     box = np.ones(int(box_width // dt)) | ||||
|     box /= np.sum(box) * dt # normalisierung des box kernels to an integral of one | ||||
|     rate = np.convolve(binary_spikes, box, mode = 'same')  | ||||
|     return rate | ||||
| 
 | ||||
| def power_spectrum(spike_times, duration, dt): | ||||
|     ''' | ||||
|     Computes a power spectrum based on the spike times | ||||
| 
 | ||||
|     Parameters | ||||
|     ---------- | ||||
|     spike_times : np.array | ||||
|         The spike times. | ||||
|     duration : float | ||||
|         The trial duration: | ||||
|     dt : float | ||||
|         The temporal resolution. | ||||
| 
 | ||||
|     Returns | ||||
|     ------- | ||||
|     freq : np.array | ||||
|         All the frequencies of the power spectrum. | ||||
|     power : np.array | ||||
|         Power of the frequencies calculated. | ||||
| 
 | ||||
|     ''' | ||||
|     # binarizes spikes | ||||
|     binary = binary_spikes(spike_times, duration, dt) | ||||
|     # computes firing rates | ||||
|     rate = firing_rate(binary, dt = dt) | ||||
|     # creates power spectrum | ||||
|     freq, power = welch(rate, fs = 1/dt, nperseg = 2**16, noverlap = 2**15) | ||||
|     return freq, power | ||||
| 
 | ||||
| def remove_poor(files): | ||||
|     """ | ||||
|     Removes poor datasets from the set of files for analysis | ||||
| 
 | ||||
|     Parameters | ||||
|     ---------- | ||||
|     files : list  | ||||
|         list of files. | ||||
| 
 | ||||
|     Returns | ||||
|     ------- | ||||
|     good_files : list | ||||
|         list of files without the ones with the label poor. | ||||
| 
 | ||||
|     """ | ||||
|     # create list for good files | ||||
|     good_files = [] | ||||
|     # loop over files | ||||
|     for i in range(len(files)): | ||||
|         # print(files[i]) | ||||
|         # load the file (takes some time) | ||||
|         data = rlx.Dataset(files[i]) | ||||
|         # get the quality | ||||
|         quality = str.lower(data.metadata["Recording"]["Recording quality"][0][0]) | ||||
|         # check the quality | ||||
|         if quality != "poor": | ||||
|             # if its good or fair add it to the good files | ||||
|             good_files.append(files[i]) | ||||
|     return good_files | ||||
		Loading…
	
		Reference in New Issue
	
	Block a user