# gpgrewe2024/code/useful_functions.py
#
# 549 lines
# 17 KiB
# Python

import glob
import pathlib
import numpy as np
import matplotlib.pyplot as plt
import rlxnix as rlx
from IPython import embed
from scipy.signal import welch
def all_coming_together(freq_array, power_array, points_list, categories, num_harmonics_list, colors, delta=2.5, threshold=0.5):
    """
    Validate candidate frequencies and gather the harmonics of the valid ones.

    For every entry of ``points_list`` the integral around the point is
    computed with ``calculate_integral`` and checked with
    ``valid_integrals``. Only for points that pass the check are harmonics
    generated with ``prepare_harmonic`` and merged into the results.

    Parameters
    ----------
    freq_array : np.array
        Frequencies corresponding to the power values.
    power_array : np.array
        Power spectral density values.
    points_list : list
        Frequency points to process.
    categories : list
        Category of each point (same order as ``points_list``).
    num_harmonics_list : list
        Number of harmonics to generate for each point.
    colors : list
        Color of each point's category.
    delta : float, optional
        Radius of the integration range around each point (default is 2.5).
    threshold : float, optional
        Threshold handed to ``valid_integrals`` (default is 0.5).

    Returns
    -------
    valid_points : list
        Flat list with the harmonics of all valid points.
    color_mapping : dict
        Maps each valid category to its color.
    category_harmonics : dict
        Maps each valid category to its list of harmonics.
    messages : list
        One message per point stating whether it was valid or not.
    """
    accepted_harmonics = []
    colors_by_category = {}
    harmonics_by_category = {}
    report = []

    for idx, point in enumerate(points_list):
        # look up the per-point settings by position
        category = categories[idx]
        num_harmonics = num_harmonics_list[idx]
        color = colors[idx]

        # integral around the point and of its surroundings
        integral, local_mean, _ = calculate_integral(freq_array, power_array, point, delta)

        # rejected points only leave a message behind
        if not valid_integrals(integral, local_mean, point, threshold):
            report.append(f"The point {point} is not valid.")
            continue

        # accepted points contribute their harmonics and mappings
        harmonics, color_map, category_harm = prepare_harmonic(point, category, num_harmonics, color)
        accepted_harmonics.extend(harmonics)
        colors_by_category.update(color_map)
        harmonics_by_category.update(category_harm)
        report.append(f"The point {point} is valid.")

    return accepted_harmonics, colors_by_category, harmonics_by_category, report
def AM(EODf, stimulus):
    """
    Compute the amplitude modulation frequency and the Nyquist frequency.

    The Nyquist frequency is taken as half of the EODf; the amplitude
    modulation is the remainder of the stimulus frequency divided by that
    Nyquist frequency.

    Parameters
    ----------
    EODf : float or int
        The current EODf.
    stimulus : float or int
        The absolute frequency of the stimulus.

    Returns
    -------
    AM : float
        ``stimulus`` modulo the Nyquist frequency.
    nyquist : float
        Half of the EODf.
    """
    half_eodf = 0.5 * EODf
    return np.mod(stimulus, half_eodf), half_eodf
def binary_spikes(spike_times, duration, dt):
    """
    Convert spike times to a binary representation.

    Parameters
    ----------
    spike_times : np.array
        The spike times.
    duration : float
        The trial duration.
    dt : float
        The temporal resolution.

    Returns
    -------
    binary : np.array
        Array of length ``round(duration / dt)`` that is 1 at the bins
        containing a spike and 0 elsewhere. Spikes that fall outside of
        ``[0, duration)`` after rounding are ignored.
    """
    n_bins = int(np.round(duration / dt))
    binary = np.zeros(n_bins)
    # convert the spike times to bin indices
    spike_indices = np.round(np.asarray(spike_times, dtype=float) / dt).astype(int)
    # A spike at (or rounding to) the trial end would index one past the
    # array, and a negative index would silently wrap around to the end of
    # the array, so only indices inside the array are kept.
    in_range = (spike_indices >= 0) & (spike_indices < n_bins)
    binary[spike_indices[in_range]] = 1
    return binary
def calculate_integral(freq, power, point, delta = 2.5):
    """
    Calculate the integral of the power around a single specified point.

    Parameters
    ----------
    freq : np.array
        An array of frequencies corresponding to the power values.
    power : np.array
        An array of power spectral density values.
    point : float
        The frequency at which to calculate the integral.
    delta : float, optional
        Radius of the range for integration around the point. The default is 2.5.

    Returns
    -------
    integral : float
        The integral of the power over ``[point - delta, point + delta]``.
    local_mean : float
        Mean of the two integrals over the adjacent ranges
        ``[point - 5 * delta, point - delta)`` and
        ``(point + delta, point + 5 * delta]``.
    p_power : float
        The maximum power inside the central range, or NaN if no frequency
        sample falls into it.
    """
    freq = np.asarray(freq)
    power = np.asarray(power)
    # np.trapz was renamed to np.trapezoid in NumPy 2.0; use whichever exists
    integrate = np.trapezoid if hasattr(np, "trapezoid") else np.trapz

    # central window around the point
    indices = (freq >= point - delta) & (freq <= point + delta)
    integral = integrate(power[indices], freq[indices])
    # np.max fails on an empty selection (point outside of the frequency axis)
    p_power = np.max(power[indices]) if np.any(indices) else np.nan

    # flanking windows left and right of the central window
    # NOTE(review): each flank spans 4 * delta while the central window
    # spans 2 * delta, so the compared integrals cover different widths.
    left_indices = (freq >= point - 5 * delta) & (freq < point - delta)
    right_indices = (freq > point + delta) & (freq <= point + 5 * delta)
    l_integral = integrate(power[left_indices], freq[left_indices])
    r_integral = integrate(power[right_indices], freq[right_indices])
    local_mean = np.mean([l_integral, r_integral])
    return integral, local_mean, p_power
def contrast_sorting(sams, con_1 = 20, con_2 = 10, con_3 = 5, stim_count = 3, stim_dur = 2):
    """
    Sort the SAMs into three contrast groups.

    Parameters
    ----------
    sams : ReproRuns
        The SAMs to be sorted.
    con_1 : int, optional
        The first contrast. The default is 20.
    con_2 : int, optional
        The second contrast. The default is 10.
    con_3 : int, optional
        The third contrast. The default is 5.
    stim_count : int, optional
        Minimum number of stimuli a SAM needs to be kept. The default is 3.
    stim_dur : int, optional
        The stimulus duration. SAMs whose average duration is below 80 % of
        this value are skipped. The default is 2.

    Returns
    -------
    contrast_sams : dictionary
        Maps each of the three contrasts to the list of SAMs with that contrast.
    """
    groups = {con_1: [], con_2: [], con_3: []}
    shortest_allowed = stim_dur * 0.8

    for sam in sams:
        avg_dur, contrast, _, _, _, _, _ = sam_data(sam)

        # skip SAMs without a contrast value
        if np.isnan(contrast):
            continue
        # skip aborted SAMs with too few stimuli
        if sam.stimulus_count < stim_count:
            continue
        # skip SAMs whose stimuli were too short
        if avg_dur < shortest_allowed:
            continue

        # contrasts are compared as integers; unknown contrasts are dropped
        level = int(contrast)
        if level in groups:
            groups[level].append(sam)

    return groups
def extract_stim_data(stimulus):
    """
    Extract the metadata needed for the analysis of a single stimulus.

    Parameters
    ----------
    stimulus : Stimulus object or rlxnix.base.repro module
        The stimulus from which the data is needed.

    Returns
    -------
    amplitude : float
        Value stored under 'Contrast' (relative signal amplitude in percent).
    df : float
        Value stored under 'DeltaF' (distance of the stimulus to the EODf).
    eodf : float
        Rounded value stored under 'EODf'.
    stim_freq : float
        Rounded value stored under 'Frequency'.
    stim_dur : float
        The stimulus duration.
    amp_mod : float
        The amplitude modulation returned by ``AM(eodf, stim_freq)``.
    ny_freq : float
        The Nyquist frequency returned by ``AM(eodf, stim_freq)``.
    """
    def first_entry(key):
        # the top-level metadata key is the stimulus name, which changes
        # with every stimulus; each value is stored as a nested sequence
        return stimulus.metadata[stimulus.name][key][0][0]

    amplitude = first_entry('Contrast')
    df = first_entry('DeltaF')
    eodf = round(first_entry('EODf'))
    stim_freq = round(first_entry('Frequency'))
    stim_dur = stimulus.duration

    # amplitude modulation and Nyquist frequency from the rounded values
    amp_mod, ny_freq = AM(eodf, stim_freq)
    return amplitude, df, eodf, stim_freq, stim_dur, amp_mod, ny_freq
def find_exceeding_points(frequency, power, points, delta, threshold):
    """
    Find the points whose integral exceeds the local mean by a given threshold.

    Parameters
    ----------
    frequency : np.array
        An array of frequencies corresponding to the power values.
    power : np.array
        An array of power spectral density values.
    points : list
        A list of frequencies to evaluate.
    delta : float
        Half-width of the range for integration around the point.
    threshold : float
        Threshold value to compare integrals with local mean.

    Returns
    -------
    exceeding_points : list
        The points that ``valid_integrals`` accepted, in input order.
    """
    exceeding_points = []
    for point in points:
        # calculate_integral returns (integral, local_mean, p_power);
        # the peak power is not needed here
        integral, local_mean, _ = calculate_integral(frequency, power, point, delta)
        # valid_integrals takes (integral, local_mean, point, threshold)
        # and returns a single boolean
        if valid_integrals(integral, local_mean, point, threshold=threshold):
            exceeding_points.append(point)
    return exceeding_points
def firing_rate(binary_spikes, dt = 0.000025, box_width = 0.01):
    """
    Calculate the firing rate from binary spikes.

    The binary spike train is convolved with a box kernel whose integral
    (sum times ``dt``) is one.

    Parameters
    ----------
    binary_spikes : np.array
        The binary representation of the spike train.
    dt : float, optional
        Time difference between two datapoints. The default is 0.000025.
    box_width : float, optional
        Time window on which the rate should be computed on. The default is 0.01.

    Returns
    -------
    rate : np.array
        Array of firing rates with the same length as ``binary_spikes``
        (as long as the kernel is not longer than the input).
    """
    # Floor division of floats is off by one when the quotient lands just
    # below an integer (0.01 // 0.000025 == 399.0), so round instead. At
    # least one sample is used because np.convolve rejects empty kernels.
    n_box = max(1, int(round(box_width / dt)))
    box = np.ones(n_box)
    box /= np.sum(box) * dt  # normalize the box kernel to an integral of one
    rate = np.convolve(binary_spikes, box, mode = 'same')
    return rate
def power_spectrum(stimulus):
    """
    Compute the power spectrum of the spike train of a stimulus.

    Parameters
    ----------
    stimulus : Stimulus object or rlxnix.base.repro module
        The stimulus for which the data is needed.

    Returns
    -------
    freq : np.array
        All the frequencies of the power spectrum.
    power : np.array
        Power of the frequencies calculated.
    """
    spikes, duration, dt = spike_times(stimulus)
    # binarize the spikes
    binary = binary_spikes(spikes, duration, dt)
    # The spectrum is estimated from the binary spike train. The firing
    # rate that used to be computed here was never used, so that
    # convolution has been dropped.
    # NOTE(review): if the spectrum was meant to be taken from the firing
    # rate instead, pass firing_rate(binary, dt=dt) to welch.
    freq, power = welch(binary, fs = 1/dt, nperseg = 2**16, noverlap = 2**15)
    return freq, power
def prepare_harmonic(frequency, category, num_harmonics, color):
    """
    Build the harmonics of one base frequency together with its category mappings.

    Parameters
    ----------
    frequency : float
        Base frequency to generate harmonics from.
    category : str
        Category of the base frequency.
    num_harmonics : int
        Number of harmonics to generate; the base frequency itself is the first one.
    color : str
        Color belonging to the category.

    Returns
    -------
    harmonics : list
        The multiples ``frequency * 1`` up to ``frequency * num_harmonics``.
    color_mapping : dict
        Single-entry dictionary mapping the category to its color.
    category_harmonics : dict
        Single-entry dictionary mapping the category to the harmonics list.
    """
    multiples = [frequency * order for order in range(1, num_harmonics + 1)]
    return multiples, {category: color}, {category: multiples}
def remove_poor(files):
    """
    Filter out the datasets whose recording quality is labelled "poor".

    Parameters
    ----------
    files : list
        List of dataset file paths.

    Returns
    -------
    good_files : list
        The entries of ``files`` whose recording quality is not "poor"
        (compared case-insensitively), in their original order.
    """
    kept = []
    for path in files:
        # opening the dataset takes some time
        dataset = rlx.Dataset(path)
        label = dataset.metadata["Recording"]["Recording quality"][0][0]
        # everything that is not explicitly poor (e.g. good or fair) is kept
        if str.lower(label) == "poor":
            continue
        kept.append(path)
    return kept
def sam_data(sam):
    """
    Average the stimulus metadata over all stimuli of a SAM.

    Parameters
    ----------
    sam : ReproRun object
        The SAM the metadata should be extracted from.

    Returns
    -------
    avg_dur : float
        Average stimulus duration.
    sam_amp : float
        Average amplitude in percent, relative to the fish amplitude.
    sam_am : float
        Average amplitude modulation frequency.
    sam_df : float
        Average difference from the stimulus to the current fish EODf.
    sam_eodf : float
        Average EODf.
    sam_nyquist : float
        Average Nyquist frequency of the EODf.
    sam_stim : float
        Average stimulus frequency.
    """
    # one row per stimulus, in the order returned by extract_stim_data:
    # (amplitude, df, eodf, stim_freq, stim_dur, amp_mod, ny_freq)
    rows = [extract_stim_data(stim) for stim in sam.stimuli]

    def column_mean(position):
        # mean over all stimuli of the value at the given tuple position
        return np.mean([row[position] for row in rows])

    sam_amp = column_mean(0)
    sam_df = column_mean(1)
    sam_eodf = column_mean(2)
    sam_stim = column_mean(3)
    avg_dur = column_mean(4)
    sam_am = column_mean(5)
    sam_nyquist = column_mean(6)
    return avg_dur, sam_amp, sam_am, sam_df, sam_eodf, sam_nyquist, sam_stim
def sam_spectrum(sam):
    """
    Compute the power spectrum of a SAM ReproRun averaged over its stimuli.

    Parameters
    ----------
    sam : ReproRun Object
        The ReproRun the power spectrum should be generated from.

    Returns
    -------
    sam_frequency : np.array
        Element-wise mean of the frequency axes of all stimuli.
    sam_power : np.array
        Element-wise mean of the powers of all stimuli.
    """
    # one (frequency, power) pair per stimulus
    spectra = [power_spectrum(stimulus) for stimulus in sam.stimuli]
    all_frequencies = [frequency for frequency, _ in spectra]
    all_powers = [power for _, power in spectra]

    # average across stimuli, keeping the frequency axis
    sam_frequency = np.mean(all_frequencies, axis = 0)
    sam_power = np.mean(all_powers, axis = 0)
    return sam_frequency, sam_power
def spike_times(stim):
    """
    Read the spike times of a stimulus together with its duration and time step.

    Parameters
    ----------
    stim : Stimulus object or rlxnix.base.repro module
        The stimulus from which the spike times should be read.

    Returns
    -------
    spikes : np.array
        First element returned by ``stim.trace_data('Spikes-1')``.
    stim_dur : float
        The duration of the stimulus.
    dt : float
        Sampling interval of the "V-1" trace.
    """
    # only the first of the two returned values is needed
    spikes, _unused = stim.trace_data('Spikes-1')
    stim_dur = stim.duration
    # the time step is taken from the voltage trace
    dt = stim.trace_info("V-1").sampling_interval
    return spikes, stim_dur, dt
def valid_integrals(integral, local_mean, point, threshold = 0.1):
    """
    Decide whether an integral exceeds the local mean by the given threshold.

    The outcome is also printed, naming the evaluated point.

    Parameters
    ----------
    integral : float
        The calculated integral around the point.
    local_mean : float
        The local mean value (adjacent integrals).
    point : float
        The frequency point being evaluated; only used in the printed message.
    threshold : float, optional
        Relative margin by which the integral has to exceed the local mean.
        The default is 0.1.

    Returns
    -------
    valid : bool
        True if ``integral > local_mean * (1 + threshold)``, otherwise False.
    """
    required = local_mean * (1 + threshold)
    valid = integral > required
    if not valid:
        print(f"The point {point} is not valid, as its integral does not exceed the threshold.")
    else:
        print(f"The point {point} is valid, as its integral exceeds the threshold.")
    return valid
'''TODO Sarah: AM-frequency plot.
What is the meaning of the AM peak in the spectrum? Why is it there, and how does it change with stimulus intensity?
Make a plot with AM 1/2 EODf over stimulus frequency (df + EODf), get the amplitude of the AM peak and plot
that amplitude over the frequency of the peak.'''
"""Example: files = glob.glob("../data/2024-10-16*.nix") gets all the file paths of the recordings from 2024-10-16."""