import matplotlib.pyplot as plt
import numpy as np
import os
import glob
import rlxnix as rlx
from useful_functions import sam_data, sam_spectrum, calculate_integral, contrast_sorting, remove_poor
from tqdm import tqdm  # Import tqdm for the progress bar


def load_files(file_path_pattern):
    """Return the files matching a glob pattern, minus those rejected as poor quality.

    Parameters
    ----------
    file_path_pattern : str
        Glob pattern passed to ``glob.glob``.

    Returns
    -------
    Whatever ``remove_poor`` returns for the matched paths.
    """
    matched_paths = glob.glob(file_path_pattern)
    return remove_poor(matched_paths)


def process_sam_data(sam):
    """Process data for a single SAM and return necessary frequencies and powers.

    Parameters
    ----------
    sam
        One SAM repro run, as produced by ``rlx.Dataset.repro_runs('SAM')``
        in this script.

    Returns
    -------
    tuple or None
        ``(stim_freq, am_peak_f, peak_power)``:

        - ``stim_freq``: the stimulus frequency from ``sam_data``.
        - ``am_peak_f``: the frequency of the largest spectral power below
          the Nyquist frequency (1/2 EODf).
        - ``peak_power``: the third value returned by ``calculate_integral``
          at that frequency.

        ``None`` when the SAM cannot be evaluated: NaN stimulus frequency,
        NaN Nyquist frequency, or no spectrum bins below the Nyquist
        frequency.
    """
    _, _, _, _, _, nyquist, stim_freq = sam_data(sam)

    # Skip SAMs whose stimulus frequency or Nyquist frequency is NaN.
    # A NaN nyquist would otherwise make np.searchsorted return len(freq)
    # (NaN sorts to the end), silently searching the whole spectrum.
    if np.isnan(stim_freq) or np.isnan(nyquist):
        return None

    # Power spectrum, and the index of the first bin at or above 1/2 EODf
    # (searchsorted assumes freq is sorted ascending)
    freq, power = sam_spectrum(sam)
    nyquist_idx = np.searchsorted(freq, nyquist)

    # np.argmax raises ValueError on an empty slice, which happens when
    # nyquist lies at or below the first frequency bin.
    if nyquist_idx == 0:
        return None

    # Frequencies and powers strictly below 1/2 EODf
    freqs_before_half_eodf = freq[:nyquist_idx]
    powers_before_half_eodf = power[:nyquist_idx]

    # Get peak frequency and power
    am_peak_f = freqs_before_half_eodf[np.argmax(powers_before_half_eodf)]
    _, _, peak_power = calculate_integral(freq, power, am_peak_f)

    return stim_freq, am_peak_f, peak_power


def plot_contrast_data(contrast_dict, file_tag, axs1, axs2):
    """Loop over all contrasts and plot AM Frequency and AM Power.

    Each SAM of a contrast is reduced with ``process_sam_data``. The
    resulting points are drawn, ordered by stimulus frequency, as one line
    labelled ``file_tag`` on that contrast's axes.

    Parameters
    ----------
    contrast_dict : dict
        Maps each contrast to the iterable of SAMs recorded at it.
    file_tag : str
        Legend label for this file's lines.
    axs1, axs2 : sequence of matplotlib Axes
        Axes for AM frequency vs. stimulus frequency (``axs1``) and
        AM power vs. stimulus frequency (``axs2``). The i-th contrast is
        drawn on the i-th axes, so each sequence needs at least
        ``len(contrast_dict)`` entries.
    """
    for idx, contrast in enumerate(contrast_dict):  # contrasts = keys of dict
        ax1 = axs1[idx]  # First figure (AM Frequency vs Stimulus Frequency)
        ax2 = axs2[idx]  # Second figure (AM Power vs Stimulus Frequency)
        contrast_sams = contrast_dict[contrast]

        # Evaluate every SAM of this contrast, dropping those that
        # process_sam_data could not evaluate (it returned None)
        results = [
            data for data in map(process_sam_data, contrast_sams)
            if data is not None
        ]

        # Sort by stimulus frequency. The points are joined by a line, so
        # recording order that is not monotonic in frequency would draw
        # a zig-zag back and forth across the axis.
        results.sort(key=lambda data: data[0])
        stim_freqs = [data[0] for data in results]
        am_freqs = [data[1] for data in results]
        peak_powers = [data[2] for data in results]

        # Plot in the first figure (AM Frequency vs Stimulus Frequency)
        ax1.plot(stim_freqs, am_freqs, '-', label=file_tag)
        ax1.set_title(f'Contrast {contrast}%')
        ax1.grid(True)
        ax1.legend(loc='upper right')

        # Plot in the second figure (AM Power vs Stimulus Frequency)
        ax2.plot(stim_freqs, peak_powers, '-', label=file_tag)
        ax2.set_title(f'Contrast {contrast}%')
        ax2.grid(True)
        ax2.legend(loc='upper right')


def process_file(file, axs1, axs2):
    """Open one recording, group its SAM runs by contrast and plot them.

    Parameters
    ----------
    file : str
        Path to the recording passed to ``rlx.Dataset``.
    axs1, axs2 : sequence of matplotlib Axes
        Target axes forwarded to ``plot_contrast_data``.
    """
    # Keep a reference to the dataset for the whole function, so it stays
    # alive while its SAM runs are being plotted
    dataset = rlx.Dataset(file)
    sam_runs = dataset.repro_runs('SAM')

    # Legend label: the first four '-'-separated fields of the base filename
    name_parts = os.path.basename(file).split('-')
    legend_label = '-'.join(name_parts[:4])

    plot_contrast_data(contrast_sorting(sam_runs), legend_label, axs1, axs2)


def loop_over_files(files, axs1, axs2):
    """Call ``process_file`` on each path in ``files``, with a tqdm progress bar.

    Parameters
    ----------
    files : iterable of str
        Paths of the recordings to process.
    axs1, axs2 : sequence of matplotlib Axes
        Target axes forwarded unchanged to ``process_file``.
    """
    progress = tqdm(files, desc="Processing files")
    for path in progress:
        process_file(path, axs1, axs2)



def main(file_path_pattern='../data/16-10-24/*.nix'):
    """Plot AM frequency and AM power vs. stimulus frequency for all good recordings.

    Parameters
    ----------
    file_path_pattern : str, optional
        Glob pattern selecting the .nix recordings to analyse. Defaults to
        the 16-10-24 recording session this script was written for.
    """
    # Load files
    good_files = load_files(file_path_pattern)
    if not good_files:
        # Avoid opening empty figure windows when nothing matched / passed
        print(f'No usable files found for pattern {file_path_pattern!r}')
        return

    # Initialize figures (one row per contrast; plot_contrast_data indexes
    # these by contrast position, so at most 3 contrasts are supported)
    fig1, axs1 = plt.subplots(3, 1, constrained_layout=True, sharex=True)  # For AM Frequency vs Stimulus Frequency
    fig2, axs2 = plt.subplots(3, 1, constrained_layout=True, sharex=True)  # For AM Power vs Stimulus Frequency

    # Loop over files and process data
    loop_over_files(good_files, axs1, axs2)

    # Add labels to figures
    fig1.supxlabel('Stimulus Frequency (df + EODf) [Hz]')
    fig1.supylabel('AM Frequency [Hz]')
    fig2.supxlabel('Stimulus Frequency (df + EODf) [Hz]')
    fig2.supylabel('AM Power')

    # Show plots
    plt.show()
    


# Run the main function only when this file is executed as a script,
# not when it is imported as a module
if __name__ == '__main__':
    main()
    
'''
Unused draft helper (kept for reference): computes the mean EODf and the mean 1/2 EODf (Nyquist frequency) over the SAMs of one contrast.

def calculate_mean_eodf(sams):
    """
    Calculate mean EODf and mean 1/2 EODf for the given SAM data.
    
    Args:
        sams (list): List of SAM objects.
    
    Returns:
        mean_eodf (float): Mean EODf across all SAMs.
        mean_half_eodf (float): Mean 1/2 EODf (Nyquist frequency) across all SAMs.
    """
    eodfs = []
    nyquists = []
    
    for sam in sams:
        _, _, _, _, eodf, nyquist, _ = sam_data(sam)
        
        # Add to list only if valid
        if not np.isnan(eodf):
            eodfs.append(eodf)
            nyquists.append(nyquist)
    
    # Calculate mean EODf and 1/2 EODf
    mean_eodf = np.mean(eodfs)
    mean_half_eodf = np.mean(nyquists)
    
    return mean_eodf, mean_half_eodf
'''

# TODO:
    # display the EODf values in the plot for one cell and one intensity - integrate a function for this (e.g. the calculate_mean_eodf draft above)
    # low-pass filter the amplitude plot with a Gaussian kernel (sigma of 0.5 in the frequency spectrum; don't filter too narrowly)
    # fix legends (show entries only for the cells that are actually displayed)
    # save figures
    # plot the remaining 3 plots; make one function for every option and call it from the main code
    # push files to git