diff --git a/code/am_plots_modularized.py b/code/am_plots_modularized.py new file mode 100644 index 0000000..d56eaef --- /dev/null +++ b/code/am_plots_modularized.py @@ -0,0 +1,162 @@ +import matplotlib.pyplot as plt +import numpy as np +import os +import glob +import rlxnix as rlx +from useful_functions import sam_data, sam_spectrum, calculate_integral, contrast_sorting, remove_poor +from tqdm import tqdm # Import tqdm for the progress bar + + +def load_files(file_path_pattern): + """Load all files matching the pattern and remove poor quality files.""" + all_files = glob.glob(file_path_pattern) + good_files = remove_poor(all_files) + return good_files + + +def process_sam_data(sam): + """Process data for a single SAM and return necessary frequencies and powers.""" + _, _, _, _, eodf, nyquist, stim_freq = sam_data(sam) + + # Skip if stim_freq is NaN + if np.isnan(stim_freq): + return None + + # Get power spectrum and frequency index for 1/2 EODf + freq, power = sam_spectrum(sam) + nyquist_idx = np.searchsorted(freq, nyquist) + + # Get frequencies and powers before 1/2 EODf + freqs_before_half_eodf = freq[:nyquist_idx] + powers_before_half_eodf = power[:nyquist_idx] + + # Get peak frequency and power + am_peak_f = freqs_before_half_eodf[np.argmax(powers_before_half_eodf)] + _, _, peak_power = calculate_integral(freq, power, am_peak_f) + + return stim_freq, am_peak_f, peak_power + + +def plot_contrast_data(contrast_dict, file_tag, axs1, axs2): + """Loop over all contrasts and plot AM Frequency and AM Power.""" + for idx, contrast in enumerate(contrast_dict): # contrasts = keys of dict + ax1 = axs1[idx] # First figure (AM Frequency vs Stimulus Frequency) + ax2 = axs2[idx] # Second figure (AM Power vs Stimulus Frequency) + contrast_sams = contrast_dict[contrast] + + # store all stim_freq and peak_power/nyquist_freq for this contrast + stim_freqs = [] + am_freqs = [] + peak_powers = [] + + # loop over all sams of one contrast + for sam in contrast_sams: + processed_data = process_sam_data(sam) + if processed_data is None: + continue + stim_freq, am_peak_f, peak_power = processed_data + stim_freqs.append(stim_freq) + am_freqs.append(am_peak_f) + peak_powers.append(peak_power) + + # Plot in the first figure (AM Frequency vs Stimulus Frequency) + ax1.plot(stim_freqs, am_freqs, '-', label=file_tag) + ax1.set_title(f'Contrast {contrast}%') + ax1.grid(True) + ax1.legend(loc='upper right') + + # Plot in the second figure (AM Power vs Stimulus Frequency) + ax2.plot(stim_freqs, peak_powers, '-', label=file_tag) + ax2.set_title(f'Contrast {contrast}%') + ax2.grid(True) + ax2.legend(loc='upper right') + + +def process_file(file, axs1, axs2): + """Process a single file: extract SAMs and plot data for each contrast.""" + dataset = rlx.Dataset(file) + sam_list = dataset.repro_runs('SAM') + + # Extract the file tag (first part of the filename) for the legend + file_tag = '-'.join(os.path.basename(file).split('-')[0:4]) + + # Sort SAMs by contrast + contrast_dict = contrast_sorting(sam_list) + + # Plot the data for each contrast + plot_contrast_data(contrast_dict, file_tag, axs1, axs2) + + +def loop_over_files(files, axs1, axs2): + """Loop over all good files, process each file, and plot the data.""" + for file in tqdm(files, desc="Processing files"): + process_file(file, axs1, axs2) + + + +def main(): + # Load files + file_path_pattern = '../data/16-10-24/*.nix' + good_files = load_files(file_path_pattern) + + # Initialize figures + fig1, axs1 = plt.subplots(3, 1, constrained_layout=True, sharex=True) # For AM Frequency vs Stimulus Frequency + fig2, axs2 = plt.subplots(3, 1, constrained_layout=True, sharex=True) # For AM Power vs Stimulus Frequency + + # Loop over files and process data + loop_over_files(good_files, axs1, axs2) + + # Add labels to figures + fig1.supxlabel('Stimulus Frequency (df + EODf) [Hz]') + fig1.supylabel('AM Frequency [Hz]') + fig2.supxlabel('Stimulus Frequency (df + EODf) [Hz]') + fig2.supylabel('AM Power') + + # Show plots + plt.show() + + + +# Run the main function +if __name__ == '__main__': + main() + +''' +Function that gets eodf and 1/2 eodf per contrast: + +def calculate_mean_eodf(sams): + """ + Calculate mean EODf and mean 1/2 EODf for the given SAM data. + + Args: + sams (list): List of SAM objects. + + Returns: + mean_eodf (float): Mean EODf across all SAMs. + mean_half_eodf (float): Mean 1/2 EODf (Nyquist frequency) across all SAMs. + """ + eodfs = [] + nyquists = [] + + for sam in sams: + _, _, _, _, eodf, nyquist, _ = sam_data(sam) + + # Add to list only if valid + if not np.isnan(eodf): + eodfs.append(eodf) + nyquists.append(nyquist) + + # Calculate mean EODf and 1/2 EODf + mean_eodf = np.mean(eodfs) + mean_half_eodf = np.mean(nyquists) + + return mean_eodf, mean_half_eodf +''' + +# TODO: + # display eodf values in plot for one cell, one intensity - integrate function for this + # lowpass with gaussian kernel for amplitude plot(0.5 sigma in frequency spectrum (dont filter too narrowly)) + # fix legends (only for the cells that are being displayed) + # save figures + # plot remaining 3 plots, make 1 function for every option and put that in main code + # push files to git