463 lines
		
	
	
		
			14 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			463 lines
		
	
	
		
			14 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| import glob
 | |
| import pathlib
 | |
| import numpy as np
 | |
| import matplotlib.pyplot as plt
 | |
| import rlxnix as rlx
 | |
| from IPython import embed
 | |
| from scipy.signal import welch
 | |
| 
 | |
def all_coming_together(freq_array, power_array, points_list, categories, num_harmonics_list, colors, delta=2.5, threshold=0.5):
    """
    Process a list of points, calculating integrals, checking validity, and preparing harmonics for valid points.

    Each entry of ``points_list`` is paired with the entry at the same position
    in ``categories``, ``num_harmonics_list`` and ``colors``.

    Parameters
    ----------
    freq_array : np.array
        Array of frequencies corresponding to the power values.
    power_array : np.array
        Array of power spectral density values.
    points_list : list
        List of harmonic frequency points to process.
    categories : list
        List of corresponding categories for each point.
    num_harmonics_list : list
        List of the number of harmonics for each point.
    colors : list
        List of colors corresponding to each point's category.
    delta : float, optional
        Radius of the range for integration around each point (default is 2.5).
    threshold : float, optional
        Threshold value to compare integrals with local mean (default is 0.5).

    Returns
    -------
    valid_points : list
        A continuous list of harmonics for all valid points.
    color_mapping : dict
        A dictionary mapping categories to corresponding colors.
    category_harmonics : dict
        A mapping of categories to their harmonic frequencies. If several valid
        points share a category, the last one processed wins.
    messages : list
        A list of messages for each point, stating whether it was valid or not.

    Raises
    ------
    ValueError
        If ``categories``, ``num_harmonics_list`` or ``colors`` has fewer
        entries than ``points_list``.
    """
    n_points = len(points_list)
    # Fail before doing any work instead of crashing part-way through the loop.
    for name, sequence in (("categories", categories),
                           ("num_harmonics_list", num_harmonics_list),
                           ("colors", colors)):
        if len(sequence) < n_points:
            raise ValueError(
                f"{name} has {len(sequence)} entries but points_list has {n_points}."
            )

    valid_points = []  # A continuous list of harmonics for valid points
    color_mapping = {}
    category_harmonics = {}
    messages = []

    for point, category, num_harmonics, color in zip(points_list, categories, num_harmonics_list, colors):
        # Step 1: integral around the point and the mean of its flanking integrals
        integral, local_mean, _ = calculate_integral(freq_array, power_array, point, delta)

        # Step 2: keep the point only if its integral clearly exceeds the surroundings
        if valid_integrals(integral, local_mean, point, threshold):
            # Step 3: build the harmonics and the category lookups for this point
            harmonics, color_map, category_harm = prepare_harmonic(point, category, num_harmonics, color)
            valid_points.extend(harmonics)  # flat list across all valid points
            color_mapping.update(color_map)
            category_harmonics.update(category_harm)
            messages.append(f"The point {point} is valid.")
        else:
            messages.append(f"The point {point} is not valid.")

    return valid_points, color_mapping, category_harmonics, messages
 | |
| 
 | |
| 
 | |
| 
 | |
def AM(EODf, stimulus):
    """
    Calculate the amplitude-modulation frequency and the Nyquist frequency.

    The stimulus frequency is folded into the band ``[0, EODf / 2]``.
    It is first reduced modulo ``EODf`` and then reflected about ``EODf / 2``.
    For ``|stimulus - EODf| <= EODf / 2`` the result equals
    ``|stimulus - EODf|`` (i.e. ``|df|``), for positive and negative df alike.

    Parameters
    ----------
    EODf : float or int
        The current EODf.
    stimulus : float, int or np.array
        The absolute frequency of the stimulus.

    Returns
    -------
    AM : float or np.array
        The amplitude modulation frequency resulting from the stimulus,
        in the range ``[0, EODf / 2]``.
    nyquist : float
        The maximum frequency possible to resolve with the EODf.
    """
    nyquist = EODf * 0.5
    # Position of the stimulus within one EODf period, in [0, EODf).
    remainder = np.mod(stimulus, EODf)
    # Reflect about the Nyquist frequency. A plain ``np.mod(stimulus, nyquist)``
    # gives ``nyquist - |df|`` instead of ``|df|`` for negative df.
    am = np.minimum(remainder, EODf - remainder)
    return am, nyquist
 | |
| 
 | |
def binary_spikes(spike_times, duration, dt):
    """
    Convert spike times to a binary representation.

    The output has ``round(duration / dt)`` bins. A spike sets the bin at
    index ``round(t / dt)`` to 1. Spikes whose index falls outside the array
    (negative times, or times at or after the end of the trial) are ignored
    rather than wrapping around or raising IndexError.

    Parameters
    ----------
    spike_times : np.array or sequence of float
        The spike times.
    duration : float
        The trial duration.
    dt : float
        The temporal resolution.

    Returns
    -------
    binary : np.array
        The binary representation of the spike train.
    """
    n_bins = int(np.round(duration / dt))
    binary = np.zeros(n_bins)

    # np.asarray lets plain lists work; astype(int) truncates the already-rounded values.
    indices = np.round(np.asarray(spike_times, dtype=float) / dt).astype(int)
    # Negative indices would silently mark bins at the end of the array, and
    # indices >= n_bins would raise, so keep only those inside the trial.
    in_range = (indices >= 0) & (indices < n_bins)
    binary[indices[in_range]] = 1
    return binary
 | |
| 
 | |
def calculate_integral(freq, power, point, delta = 2.5):
    """
    Calculate the integral around a single specified point.

    Parameters
    ----------
    freq : np.array
        An array of frequencies corresponding to the power values.
    power : np.array
        An array of power spectral density values.
    point : float
        The harmonic frequency at which to calculate the integral.
    delta : float, optional
        Radius of the range for integration around the point. The default is 2.5.

    Returns
    -------
    integral : float
        Trapezoidal integral of ``power`` over ``[point - delta, point + delta]``.
    local_mean : float
        Mean of the two flanking integrals, over ``[point - 5*delta, point - delta)``
        and ``(point + delta, point + 5*delta]``. Each flank is 4*delta wide,
        while the central window is 2*delta wide.
    p_power : float
        The maximum power inside the central window.

    Raises
    ------
    ValueError
        If no frequency lies inside the central window.
    """
    # np.trapz is deprecated since NumPy 2.0 in favour of np.trapezoid.
    trapezoid = getattr(np, "trapezoid", None) or np.trapz

    center = (freq >= point - delta) & (freq <= point + delta)
    if not np.any(center):
        raise ValueError(
            f"No frequencies within +/-{delta} of point {point}; cannot integrate."
        )
    integral = trapezoid(power[center], freq[center])
    p_power = np.max(power[center])

    left = (freq >= point - 5 * delta) & (freq < point - delta)
    right = (freq > point + delta) & (freq <= point + 5 * delta)
    l_integral = trapezoid(power[left], freq[left])
    r_integral = trapezoid(power[right], freq[right])

    local_mean = np.mean([l_integral, r_integral])
    return integral, local_mean, p_power
 | |
| 
 | |
def extract_stim_data(stimulus):
    '''
    Collect the metadata needed for one stimulus.

    Parameters
    ----------
    stimulus : Stimulus object or rlxnix.base.repro module
        The stimulus whose metadata is read.

    Returns
    -------
    amplitude : float
        The relative signal amplitude in percent ('Contrast').
    df : float
        Distance of the stimulus to the current EODf ('DeltaF').
    eodf : int
        Current EODf, rounded.
    stim_freq : int
        The total stimulus frequency (EODf + df), rounded.
    stim_dur : float
        The stimulus duration.
    amp_mod : float
        The current amplitude modulation, computed by ``AM``.
    ny_freq : float
        The current Nyquist frequency, computed by ``AM``.
    '''
    # The top-level metadata key is the stimulus' own name, which differs per stimulus.
    meta = stimulus.metadata[stimulus.name]

    def first_entry(key):
        # Every value used here is stored as entry[0][0].
        return meta[key][0][0]

    amplitude = first_entry('Contrast')
    df = first_entry('DeltaF')
    eodf = round(first_entry('EODf'))
    stim_freq = round(first_entry('Frequency'))
    stim_dur = stimulus.duration

    amp_mod, ny_freq = AM(eodf, stim_freq)
    return amplitude, df, eodf, stim_freq, stim_dur, amp_mod, ny_freq
 | |
| 
 | |
def find_exceeding_points(frequency, power, points, delta, threshold):
    """
    Find the points where the integral exceeds the local mean by a given threshold.

    Parameters
    ----------
    frequency : np.array
        An array of frequencies corresponding to the power values.
    power : np.array
        An array of power spectral density values.
    points : list
        A list of harmonic frequencies to evaluate.
    delta : float
        Half-width of the range for integration around the point.
    threshold : float
        Threshold value to compare integrals with local mean.

    Returns
    -------
    exceeding_points : list
        A list of points where the integral exceeds the local mean by the threshold.
    """
    exceeding_points = []

    for point in points:
        # calculate_integral returns (integral, local_mean, peak_power); the
        # peak power is not needed here.
        integral, local_mean, _ = calculate_integral(frequency, power, point, delta)

        # valid_integrals takes (integral, local_mean, point, threshold) and
        # returns a single bool.
        if valid_integrals(integral, local_mean, point, threshold):
            exceeding_points.append(point)

    return exceeding_points
 | |
| 
 | |
def firing_rate(binary_spikes, dt = 0.000025, box_width = 0.01):
    '''
    Calculate the firing rate from binary spikes by convolving with a box kernel.

    Parameters
    ----------
    binary_spikes : np.array
        The binary representation of the spike train.
    dt : float, optional
        Time difference between two datapoints. The default is 0.000025.
    box_width : float, optional
        Time window on which the rate should be computed on. The default is 0.01.

    Returns
    -------
    rate : np.array
        Array of firing rates, with the same length as ``binary_spikes``.
    '''
    # Round the ratio instead of floor-dividing floats. For example,
    # 0.1 // 0.01 == 9.0, and the default 0.01 // 0.000025 gives 399, not 400.
    # At least one sample is used so the normalisation never divides by zero.
    n_box = max(1, int(round(box_width / dt)))
    box = np.ones(n_box) / (n_box * dt)  # kernel integrates to one
    rate = np.convolve(binary_spikes, box, mode = 'same')
    return rate
 | |
| 
 | |
def power_spectrum(stimulus):
    '''
    Compute the power spectrum of the firing rate evoked by a stimulus.

    Spikes are binarised, smoothed into a firing rate, and passed to Welch's
    method. Segments are 2**16 samples long with 50 % overlap, or the full
    signal length if the recording is shorter than one segment.

    Parameters
    ----------
    stimulus : Stimulus object or rlxnix.base.repro module
        The stimulus for which the data is needed.

    Returns
    -------
    freq : np.array
        All the frequencies of the power spectrum.
    power : np.array
        Power of the frequencies calculated.
    '''
    spikes, duration, dt = spike_times(stimulus)
    binary = binary_spikes(spikes, duration, dt)
    rate = firing_rate(binary, dt = dt)

    # For short signals scipy shrinks nperseg to len(rate) but keeps an explicit
    # noverlap, so a fixed noverlap=2**15 raised "noverlap must be less than
    # nperseg". Scale both together; for long signals the values are unchanged.
    nperseg = min(2**16, len(rate))
    freq, power = welch(rate, fs = 1/dt, nperseg = nperseg, noverlap = nperseg // 2)
    return freq, power
 | |
| 
 | |
def prepare_harmonic(frequency, category, num_harmonics, color):
    """
    Build the first ``num_harmonics`` multiples of ``frequency`` and label them.

    Parameters
    ----------
    frequency : float
        Base (first-harmonic) frequency.
    category : str
        Label used as the key of both returned dictionaries.
    num_harmonics : int
        How many multiples to generate (1x, 2x, ..., num_harmonics x).
    color : str
        Color associated with ``category``.

    Returns
    -------
    harmonics : list
        ``[frequency * 1, frequency * 2, ..., frequency * num_harmonics]``.
    color_mapping : dict
        ``{category: color}``.
    category_harmonics : dict
        ``{category: harmonics}``, holding the same list object as ``harmonics``.
    """
    harmonic_list = []
    for order in range(1, num_harmonics + 1):
        harmonic_list.append(frequency * order)

    return harmonic_list, {category: color}, {category: harmonic_list}
 | |
| 
 | |
def remove_poor(files):
    """
    Remove datasets labelled as poor quality from a collection of files.

    Parameters
    ----------
    files : iterable of str
        Paths of the data files to check.

    Returns
    -------
    good_files : list
        The files whose "Recording quality" metadata is not "poor"
        (case-insensitive), in their original order.
    """
    good_files = []
    for path in files:
        # Loading the dataset is slow; only its recording-quality metadata is used.
        dataset = rlx.Dataset(path)
        quality = dataset.metadata["Recording"]["Recording quality"][0][0].lower()
        # Keep everything not explicitly labelled poor (e.g. good or fair).
        if quality != "poor":
            good_files.append(path)
    return good_files
 | |
| 
 | |
def sam_data(sam):
    '''
    Average the per-stimulus metadata over all stimuli of one SAM run.

    Parameters
    ----------
    sam : ReproRun object
        The SAM run whose stimuli are summarised.

    Returns
    -------
    avg_dur : float
        Average stimulus duration.
    sam_amp : float
        Mean amplitude in percent, relative to the fish amplitude.
    sam_am : float
        Mean amplitude modulation frequency.
    sam_df : float
        Mean difference between the stimulus and the current fish EODf.
    sam_eodf : float
        Mean EODf.
    sam_nyquist : float
        Mean Nyquist frequency of the EODf.
    sam_stim : float
        Mean stimulus frequency.
    '''
    # Field names in the order extract_stim_data returns them.
    field_names = ("amplitude", "df", "eodf", "stim_freq", "stim_dur", "amp_mod", "ny_freq")
    records = [dict(zip(field_names, extract_stim_data(stim))) for stim in sam.stimuli]

    def mean_of(field):
        # np.mean of an empty list yields NaN (with a RuntimeWarning).
        return np.mean([record[field] for record in records])

    sam_amp = mean_of("amplitude")
    sam_am = mean_of("amp_mod")
    sam_df = mean_of("df")
    sam_eodf = mean_of("eodf")
    sam_nyquist = mean_of("ny_freq")
    sam_stim = mean_of("stim_freq")
    avg_dur = mean_of("stim_dur")
    return avg_dur, sam_amp, sam_am, sam_df, sam_eodf, sam_nyquist, sam_stim
 | |
| 
 | |
def spike_times(stim):
    """
    Read the spike times, stimulus duration and sampling interval of a stimulus.

    Parameters
    ----------
    stim : Stimulus object or rlxnix.base.repro module
        The stimulus to read from.

    Returns
    -------
    spike_times : np.array
        The spike times from the 'Spikes-1' trace.
    stim_dur : float
        The duration of the stimulus.
    dt : float
        Sampling interval of the 'V-1' trace.
    """
    # trace_data yields a pair; only the first element (the spike times) is kept.
    spikes, _unused = stim.trace_data('Spikes-1')
    duration = stim.duration
    sampling_interval = stim.trace_info("V-1").sampling_interval
    return spikes, duration, sampling_interval
 | |
| 
 | |
def valid_integrals(integral, local_mean, point, threshold = 0.1):
    """
    Decide whether a point's integral stands out from its surroundings.

    The point is valid when ``integral > local_mean * (1 + threshold)``.
    A one-line verdict is printed in either case.

    Parameters
    ----------
    integral : float
        The calculated integral around the point.
    local_mean : float
        The local mean value (adjacent integrals).
    point : float
        The harmonic frequency point being evaluated (used in the message).
    threshold : float, optional
        Relative margin by which ``integral`` must exceed ``local_mean``.
        The default is 0.1.

    Returns
    -------
    valid : bool
        True if the integral exceeds the local mean by the threshold, otherwise False.
    """
    cutoff = local_mean * (1 + threshold)
    is_valid = integral > cutoff
    verdict = (
        f"The point {point} is valid, as its integral exceeds the threshold."
        if is_valid
        else f"The point {point} is not valid, as its integral does not exceed the threshold."
    )
    print(verdict)
    return is_valid
 | |
| 
 | |
'''TODO Sarah: AM-frequency plot:
    What does the AM peak in the spectrum mean? Why is it there, and how does it change with stimulus intensity?
    Make a plot of the AM (and 1/2 EODf) over the stimulus frequency (df + EODf), get the amplitude of the AM peak,
    and plot that amplitude over the frequency of the peak.'''
 | |
|     
 | |
"""files = glob.glob("../data/2024-10-16*.nix") collects all file paths recorded on 2024-10-16."""