import glob
import matplotlib.pyplot as plt
import numpy as np
import rlxnix as rlx
import scipy as sp
import time
import useful_functions as f

# Use the actual power of the peaks.




# all recordings we want to use; sorted because glob returns paths in an
# arbitrary, file-system dependent order (keeps the processing/figure order
# reproducible)
files = sorted(glob.glob("../data/2024-10-*.nix"))

# keep only the recordings rated good or fair
new_files = f.remove_poor(files)


# loop over all the good files
for file in new_files:
    # per-contrast results, appended in the key order of contrast_sams (20, 10, 5)
    contrast_frequencies = []
    contrast_powers = []
    # load a file
    dataset = rlx.Dataset(file)
    # extract all SAM repro runs of this recording
    sams = dataset.repro_runs('SAM')

    # bucket the valid SAM runs by their integer contrast;
    # contrasts other than these three are ignored
    contrast_sams = {20: [],
                     10: [],
                     5: []}
    for sam in sams:
        avg_dur, contrast, _, _, _, _, _ = f.sam_data(sam)
        # skip invalid runs
        if np.isnan(contrast):
            continue
        if sam.stimulus_count < 3:  # aborted trials
            continue
        if avg_dur < 1.7:  # average stimulus duration too short
            continue
        # round rather than truncate: int(28.999999999999996) would give 28
        contrast = int(round(contrast))
        if contrast in contrast_sams:
            contrast_sams[contrast].append(sam)

    # for every contrast: peak power of each SAM run at its stimulus frequency
    for key in contrast_sams:
        stim_frequencies = np.zeros(len(contrast_sams[key]))
        peak_powers = np.zeros_like(stim_frequencies)

        for i, sam in enumerate(contrast_sams[key]):
            # get the stimulus frequency of this run
            _, _, _, _, _, _, stim_frequency = f.sam_data(sam)
            # power spectrum of every stimulus of this run
            spectra = [f.power_spectrum(stimulus) for stimulus in sam.stimuli]
            frequencies = [frequency for frequency, _ in spectra]
            powers = [power for _, power in spectra]
            # average the spectra over the stimuli
            sam_frequency = np.mean(frequencies, axis=0)
            sam_power = np.mean(powers, axis=0)
            # only the peak power is used; integral and surroundings are discarded
            _, _, peak_power = f.calculate_integral(sam_frequency,
                                                    sam_power, stim_frequency)
            peak_powers[i] = peak_power
            stim_frequencies[i] = stim_frequency

        # replace zeros with NaN so they are not drawn
        # NOTE(review): presumably calculate_integral yields 0 when no peak is
        # found — confirm
        peak_powers = np.where(peak_powers == 0, np.nan, peak_powers)

        # sort by stimulus frequency so the line plot does not zig-zag in
        # acquisition order; frequencies and powers stay paired
        order = np.argsort(stim_frequencies)
        contrast_frequencies.append(stim_frequencies[order])
        contrast_powers.append(peak_powers[order])

    # one line per contrast, labelled so the contrasts can be told apart
    fig, ax = plt.subplots(layout='constrained')
    for key, freqs, pows in zip(contrast_sams, contrast_frequencies,
                                contrast_powers):
        ax.plot(freqs, pows, label=f"contrast {key}")
    ax.set_xlabel('stimulus frequency [Hz]')
    ax.set_ylabel(r' power [$\frac{\mathrm{mV^2}}{\mathrm{Hz}}$]')
    ax.set_title(f"{file}")
    ax.legend()
    # NOTE(review): the figure is neither shown nor saved here; if nothing
    # follows, add fig.savefig(...)/plt.close(fig) or a final plt.show()