From 3ea0083f4cda3f1d8a0f24f48c5c7a0fa6352f3b Mon Sep 17 00:00:00 2001
From: "sarah.eisele" <sarah.eisele@student.uni-tuebingen.de>
Date: Fri, 25 Oct 2024 15:44:08 +0200
Subject: [PATCH] new code for am plots

---
 code/am_plots_modularized.py | 162 +++++++++++++++++++++++++++++++++++
 1 file changed, 162 insertions(+)
 create mode 100644 code/am_plots_modularized.py

diff --git a/code/am_plots_modularized.py b/code/am_plots_modularized.py
new file mode 100644
index 0000000..d56eaef
--- /dev/null
+++ b/code/am_plots_modularized.py
@@ -0,0 +1,162 @@
+import matplotlib.pyplot as plt
+import numpy as np
+import os
+import glob
+import rlxnix as rlx
+from useful_functions import sam_data, sam_spectrum, calculate_integral, contrast_sorting, remove_poor
+from tqdm import tqdm  # Import tqdm for the progress bar
+
+
+def load_files(file_path_pattern):
+    """Load all files matching the pattern and remove poor quality files."""
+    all_files = glob.glob(file_path_pattern)
+    good_files = remove_poor(all_files)
+    return good_files
+
+
+def process_sam_data(sam):
+    """Process data for a single SAM and return necessary frequencies and powers."""
+    _, _, _, _, eodf, nyquist, stim_freq = sam_data(sam)
+    
+    # Skip if stim_freq is NaN
+    if np.isnan(stim_freq):
+        return None
+    
+    # Get power spectrum and frequency index for 1/2 EODf
+    freq, power = sam_spectrum(sam)
+    nyquist_idx = np.searchsorted(freq, nyquist)
+    
+    # Get frequencies and powers before 1/2 EODf
+    freqs_before_half_eodf = freq[:nyquist_idx]
+    powers_before_half_eodf = power[:nyquist_idx]
+    
+    # Get peak frequency and power
+    am_peak_f = freqs_before_half_eodf[np.argmax(powers_before_half_eodf)]
+    _, _, peak_power = calculate_integral(freq, power, am_peak_f)
+    
+    return stim_freq, am_peak_f, peak_power
+
+
+def plot_contrast_data(contrast_dict, file_tag, axs1, axs2):
+    """Loop over all contrasts and plot AM Frequency and AM Power."""
+    for idx, contrast in enumerate(contrast_dict):  # contrasts = keys of dict
+        ax1 = axs1[idx]  # First figure (AM Frequency vs Stimulus Frequency)
+        ax2 = axs2[idx]  # Second figure (AM Power vs Stimulus Frequency)
+        contrast_sams = contrast_dict[contrast]
+
+        # store all stim_freq and peak_power/nyquist_freq for this contrast
+        stim_freqs = []
+        am_freqs = []
+        peak_powers = []
+
+        # loop over all sams of one contrast
+        for sam in contrast_sams:
+            processed_data = process_sam_data(sam)
+            if processed_data is None:
+                continue
+            stim_freq, am_peak_f, peak_power = processed_data
+            stim_freqs.append(stim_freq)
+            am_freqs.append(am_peak_f)
+            peak_powers.append(peak_power)
+
+        # Plot in the first figure (AM Frequency vs Stimulus Frequency)
+        ax1.plot(stim_freqs, am_freqs, '-', label=file_tag)
+        ax1.set_title(f'Contrast {contrast}%')
+        ax1.grid(True)
+        ax1.legend(loc='upper right')
+
+        # Plot in the second figure (AM Power vs Stimulus Frequency)
+        ax2.plot(stim_freqs, peak_powers, '-', label=file_tag)
+        ax2.set_title(f'Contrast {contrast}%')
+        ax2.grid(True)
+        ax2.legend(loc='upper right')
+
+
+def process_file(file, axs1, axs2):
+    """Process a single file: extract SAMs and plot data for each contrast."""
+    dataset = rlx.Dataset(file)
+    sam_list = dataset.repro_runs('SAM')
+    
+    # Extract the file tag (first part of the filename) for the legend
+    file_tag = '-'.join(os.path.basename(file).split('-')[0:4])
+    
+    # Sort SAMs by contrast
+    contrast_dict = contrast_sorting(sam_list)
+    
+    # Plot the data for each contrast
+    plot_contrast_data(contrast_dict, file_tag, axs1, axs2)
+
+
+def loop_over_files(files, axs1, axs2):
+    """Loop over all good files, process each file, and plot the data."""
+    for file in tqdm(files, desc="Processing files"):
+        process_file(file, axs1, axs2)
+
+
+
+def main():
+    # Load files
+    file_path_pattern = '../data/16-10-24/*.nix'
+    good_files = load_files(file_path_pattern)
+
+    # Initialize figures
+    fig1, axs1 = plt.subplots(3, 1, constrained_layout=True, sharex=True)  # For AM Frequency vs Stimulus Frequency
+    fig2, axs2 = plt.subplots(3, 1, constrained_layout=True, sharex=True)  # For AM Power vs Stimulus Frequency
+
+    # Loop over files and process data
+    loop_over_files(good_files, axs1, axs2)
+
+    # Add labels to figures
+    fig1.supxlabel('Stimulus Frequency (df + EODf) [Hz]')
+    fig1.supylabel('AM Frequency [Hz]')
+    fig2.supxlabel('Stimulus Frequency (df + EODf) [Hz]')
+    fig2.supylabel('AM Power')
+    
+    # Show plots
+    plt.show()
+    
+
+
+# Run the main function
+if __name__ == '__main__':
+    main()
+    
+'''
+Function that gets eodf and 1/2 eodf per contrast:
+
+def calculate_mean_eodf(sams):
+    """
+    Calculate mean EODf and mean 1/2 EODf for the given SAM data.
+    
+    Args:
+        sams (list): List of SAM objects.
+    
+    Returns:
+        mean_eodf (float): Mean EODf across all SAMs.
+        mean_half_eodf (float): Mean 1/2 EODf (Nyquist frequency) across all SAMs.
+    """
+    eodfs = []
+    nyquists = []
+    
+    for sam in sams:
+        _, _, _, _, eodf, nyquist, _ = sam_data(sam)
+        
+        # Add to list only if valid
+        if not np.isnan(eodf):
+            eodfs.append(eodf)
+            nyquists.append(nyquist)
+    
+    # Calculate mean EODf and 1/2 EODf
+    mean_eodf = np.mean(eodfs)
+    mean_half_eodf = np.mean(nyquists)
+    
+    return mean_eodf, mean_half_eodf
+'''
+
+# TODO:
+    # display eodf values in plot for one cell, one intensity - integrate function for this
+    # lowpass with gaussian kernel for amplitude plot(0.5 sigma in frequency spectrum (dont filter too narrowly))
+    # fix legends (only for the cells that are being displayed)
+    # save figures
+    # plot remaining 3 plots, make 1 function for every option and put that in main code
+    # push files to git