import glob
import matplotlib.pyplot as plt
import numpy as np
import rlxnix as rlx
import scipy as sp
import time
import useful_functions as f

# TODO: use the actual power of the peaks



# --- analysis parameters -------------------------------------------------
# radius for peak detection
# NOTE(review): delta is not passed to f.calculate_integral further down —
# confirm whether the helper is meant to receive it.
delta = 2.5

# --- input recordings ----------------------------------------------------
# every 2024-10-16 recording whose name starts with "af"
files = glob.glob("../data/2024-10-16-af*.nix")

# only keep the recordings that f.remove_poor does not drop
# (the good and fair ones)
new_files = f.remove_poor(files)


# For every good/fair recording: average the power spectra of the stimuli of
# each SAM run, take the power of the peak at the stimulus frequency, and plot
# peak power against stimulus frequency (one curve per recording).
for file in new_files:
    # load a file
    dataset = rlx.Dataset(file)
    # extract all runs of the SAM repro
    sams = dataset.repro_runs('SAM')
    # One slot per SAM run. Start at NaN rather than 0 so that runs which are
    # skipped, or have no valid peak, are left out of the plot instead of
    # appearing as points at frequency/power 0.
    stim_frequencies = np.full(len(sams), np.nan)
    peak_powers = np.full(len(sams), np.nan)
    # loop over all sams
    for i, sam in enumerate(sams):
        # get the average duration and stimulus frequency of this run
        avg_dur, _, _, _, _, _, stim_frequency = f.sam_data(sam)
        # skip runs for which f.sam_data returned a NaN average duration
        if np.isnan(avg_dur):
            continue
        # record the stimulus frequency of every usable run, even if no valid
        # peak is found below (its peak power then stays NaN)
        stim_frequencies[i] = stim_frequency

        # power spectrum of every stimulus of this SAM run
        frequencies = []
        powers = []
        for stimulus in sam.stimuli:
            frequency, power = f.power_spectrum(stimulus)
            frequencies.append(frequency)
            powers.append(power)
        if not frequencies:
            # no stimuli: nothing to average, avoid feeding NaN into the
            # peak detection
            continue
        # average the spectra over the stimuli
        sam_frequency = np.mean(frequencies, axis=0)
        sam_power = np.mean(powers, axis=0)
        # detect the peak at the stimulus frequency and validate it
        # (validation criteria are defined in f.valid_integrals)
        integral, surroundings, peak_power = f.calculate_integral(
            sam_frequency, sam_power, stim_frequency)
        valid = f.valid_integrals(integral, surroundings, stim_frequency)
        # keep the peak power only for valid peaks
        if valid:
            peak_powers[i] = peak_power

    # Plot once per file so every recording is drawn, not just the last one.
    # Sort by frequency so the connecting line does not zig-zag when the runs
    # were not recorded in frequency order (np.argsort puts NaNs last).
    order = np.argsort(stim_frequencies)
    plt.plot(stim_frequencies[order], peak_powers[order], 'o-', label=file)

plt.xlabel('stimulus frequency')
plt.ylabel('peak power')
plt.legend()
plt.show()