import matplotlib.pyplot as plt
import numpy as np
import os
import rlxnix as rlx
from useful_functions import sam_data, sam_spectrum, calculate_integral, contrast_sorting

# Start from a clean slate: close every figure that is still open before
# this script creates its own.
plt.close(fig='all')

def plot_am_vs_frequency_single_intensity(file, contrast=20):
    """
    Plot AM peak frequency and AM peak power against stimulus frequency for
    one contrast of one cell (file).

    For every SAM run at the requested contrast, the power spectrum from
    ``sam_spectrum`` is searched below the Nyquist frequency returned by
    ``sam_data`` for its highest peak (the AM peak). The frequency of that
    peak is shown in the upper panel, and the peak power returned by
    ``calculate_integral`` in the lower panel, both against the stimulus
    frequency. Points are sorted by stimulus frequency before plotting.

    Parameters:
        file (str): Path to the .nix file of one cell.
        contrast (int): Contrast (in %) whose SAM runs are plotted; used as
            the key into the mapping returned by ``contrast_sorting``.
            Defaults to 20.

    Raises:
        KeyError: If looking up ``contrast`` in the mapping returned by
            ``contrast_sorting`` fails (no SAM runs at that contrast).
    """
    # Load the dataset for the given file and get all SAM runs of the recording
    dataset = rlx.Dataset(file)
    sam_list = dataset.repro_runs('SAM')

    # Short cell identifier (first four '-'-separated parts of the file name)
    file_tag = '-'.join(os.path.basename(file).split('-')[0:4])

    # Group SAM runs by contrast and pick the requested one
    contrast_dict = contrast_sorting(sam_list)
    try:
        sams = contrast_dict[contrast]
    except KeyError:
        # Same exception type as before, but tell the user what exists
        raise KeyError(
            f"No SAM runs with contrast {contrast}% in {file_tag}; "
            f"available contrasts: {list(contrast_dict)}"
        ) from None

    # Per-SAM results collected for plotting
    stim_freqs = []
    peak_powers = []
    am_freqs = []

    for sam in sams:
        # Only the Nyquist and stimulus frequencies are needed here
        _, _, _, _, _eodf, nyquist, stim_freq = sam_data(sam)

        # Skip over empty SAMs (no stimulus frequency available)
        if np.isnan(stim_freq):
            continue

        # Power spectrum of this SAM; searchsorted assumes freq is ascending
        freq, power = sam_spectrum(sam)
        nyquist_idx = np.searchsorted(freq, nyquist)

        # No spectral bins below the Nyquist frequency: np.argmax would raise
        # on the empty slice, so this run cannot contribute a point.
        if nyquist_idx == 0:
            continue

        # Frequency of the highest peak below the Nyquist frequency
        am_peak_f = freq[:nyquist_idx][np.argmax(power[:nyquist_idx])]

        # Power of that peak as computed by calculate_integral
        _, _, peak_power = calculate_integral(freq, power, am_peak_f)

        stim_freqs.append(stim_freq)
        peak_powers.append(peak_power)
        am_freqs.append(am_peak_f)

    # SAM runs are not necessarily recorded in order of stimulus frequency;
    # sort so the connecting lines run along the x axis instead of zig-zagging.
    order = np.argsort(stim_freqs, kind='stable')
    stim_freqs = np.asarray(stim_freqs)[order]
    am_freqs = np.asarray(am_freqs)[order]
    peak_powers = np.asarray(peak_powers)[order]

    # Two rows, one column: AM frequency on top, AM power below
    fig, axs = plt.subplots(2, 1, layout='constrained')

    # Upper panel: AM peak frequency vs stimulus frequency
    ax = axs[0]
    ax.plot(stim_freqs, am_freqs, '-')
    ax.set_ylabel('AM Frequency [Hz]')
    ax.grid(True)

    # Lower panel: AM peak power vs stimulus frequency
    ax = axs[1]
    ax.plot(stim_freqs, peak_powers, '-')
    ax.set_ylabel('AM Power')
    ax.grid(True)

    # Figure settings
    fig.suptitle(f"Cell: {file_tag}, Contrast: {contrast}%")
    fig.supxlabel("Stimulus Frequency (df + EODf) [Hz]")
    plt.show()


if __name__ == '__main__':
    # Recording (one cell) to analyse; path is relative to the working directory
    file = '../data/16-10-24/2024-10-16-ad-invivo-1.nix'

    # Plot AM frequency and AM power vs stimulus frequency for one contrast
    # (the function's default) of this cell. Guarded so importing this module
    # does not load data or open plot windows.
    plot_am_vs_frequency_single_intensity(file)