import glob
import pathlib
import numpy as np
import matplotlib.pyplot as plt
import rlxnix as rlx
from IPython import embed
from scipy.signal import welch

def all_coming_together(freq_array, power_array, points_list, categories, num_harmonics_list, colors, delta=2.5, threshold=0.5):
    """
    Validate candidate base frequencies and collect the harmonics of the valid ones.

    For every entry of ``points_list`` the power integral in a window of
    +/- ``delta`` around the point is computed together with the mean of the
    neighbouring integrals (via ``calculate_integral``). Points accepted by
    ``valid_integrals`` are expanded into their harmonic series via
    ``prepare_harmonic``.

    Parameters
    ----------
    freq_array : np.array
        Frequencies of the power spectrum.
    power_array : np.array
        Power spectral density values matching ``freq_array``.
    points_list : list
        Candidate base frequencies to check.
    categories : list
        Category label for each candidate (indexed in parallel with ``points_list``).
    num_harmonics_list : list
        Number of harmonics to generate for each candidate.
    colors : list
        Color for each candidate's category.
    delta : float, optional
        Half-width of the integration window around each point (default is 2.5).
    threshold : float, optional
        Relative amount by which the integral must exceed the local mean
        (default is 0.5).

    Returns
    -------
    valid_points : list
        Harmonics of all valid points, concatenated in input order.
    color_mapping : dict
        Category -> color, for valid points only.
    category_harmonics : dict
        Category -> harmonic frequencies, for valid points only. A later
        valid point with the same category overwrites an earlier one.
    messages : list
        One validity message per input point, in input order.
    """
    harmonics_flat = []
    category_colors = {}
    harmonics_by_category = {}
    report = []

    for idx, candidate in enumerate(points_list):
        # Fetch the parallel entries up front, before any computation.
        label = categories[idx]
        n_harm = num_harmonics_list[idx]
        shade = colors[idx]

        area, neighbour_mean, _ = calculate_integral(freq_array, power_array, candidate, delta)

        if not valid_integrals(area, neighbour_mean, candidate, threshold):
            report.append(f"The point {candidate} is not valid.")
            continue

        harm, col_map, cat_harm = prepare_harmonic(candidate, label, n_harm, shade)
        harmonics_flat += harm  # flat concatenation, same as list.extend
        category_colors.update(col_map)
        harmonics_by_category.update(cat_harm)
        report.append(f"The point {candidate} is valid.")

    return harmonics_flat, category_colors, harmonics_by_category, report



def AM(EODf, stimulus):
    """
    Calculate the amplitude-modulation (beat) frequency and the Nyquist frequency.

    A stimulus superimposed on the EOD produces a beat whose frequency is the
    distance of the stimulus frequency to the nearest integer multiple of the
    EODf. For a positive EODf this value lies in ``[0, EODf / 2]``. Stimuli
    just below a multiple of the EODf (negative df) therefore yield the same
    AM as stimuli the same distance above it.

    Parameters
    ----------
    EODf : float or int
        The current EODf.
    stimulus : float or int or np.array
        The absolute frequency of the stimulus.

    Returns
    -------
    AM : float or np.array
        The amplitude modulation frequency resulting from the stimulus.
    nyquist : float
        Half the EODf, i.e. the maximum frequency possible to resolve with the EODf.

    Examples
    --------
    >>> am, nyquist = AM(800, 700)   # df = -100 Hz
    >>> float(am), nyquist
    (100.0, 400.0)
    """
    nyquist = EODf * 0.5
    # Fold the stimulus into one EOD period, then take the distance to the
    # closer end of that period. Folding with mod(stimulus, nyquist) instead
    # would give 300 Hz for EODf=800, stimulus=700, although the beat is 100 Hz.
    remainder = np.mod(stimulus, EODf)
    am = np.minimum(remainder, EODf - remainder)
    return am, nyquist

def binary_spikes(spike_times, duration, dt):
    """
    Convert spike times into a binary spike train.

    Each spike time is mapped to the bin ``round(t / dt)``. Spikes whose bin
    falls outside ``[0, len(binary))`` are ignored. This covers spikes at or
    just before ``duration``, which round to ``len(binary)``, and negative
    spike times. Previously the first case raised ``IndexError`` and the
    second silently wrapped to the end of the array.

    Parameters
    ----------
    spike_times : np.array or sequence of float
        The spike times.
    duration : float
        The trial duration.
    dt : float
        The temporal resolution.

    Returns
    -------
    binary : np.array
        The binary representation of the spike train (1 where a spike
        occurred, 0 elsewhere), of length ``round(duration / dt)``.
    """
    n_bins = int(np.round(duration / dt))
    binary = np.zeros(n_bins)

    spike_indices = np.asarray(np.round(np.asarray(spike_times) / dt), dtype=int)
    # Keep only indices that address a real bin.
    in_range = (spike_indices >= 0) & (spike_indices < n_bins)
    binary[spike_indices[in_range]] = 1
    return binary

def calculate_integral(freq, power, point, delta = 2.5):
    """
    Calculate the integral around a single specified point.

    The power is integrated (trapezoidal rule) over
    ``[point - delta, point + delta]``. The local mean is the mean of two side
    integrals over ``[point - 5*delta, point - delta)`` and
    ``(point + delta, point + 5*delta]``.

    Parameters
    ----------
    freq : np.array
        An array of frequencies corresponding to the power values.
    power : np.array
        An array of power spectral density values.
    point : float
        The harmonic frequency at which to calculate the integral.
    delta : float, optional
        Radius of the range for integration around the point. The default is 2.5.

    Returns
    -------
    integral : float
        The calculated integral around the point.
    local_mean : float
        The local mean value (mean of the two adjacent integrals).
    p_power : float
        The local maximum power inside ``[point - delta, point + delta]``.

    Raises
    ------
    ValueError
        If no frequency bin lies within ``[point - delta, point + delta]``.
    """
    # np.trapezoid replaces np.trapz, which is deprecated since NumPy 2.0;
    # fall back to np.trapz on older NumPy versions.
    integrate = getattr(np, "trapezoid", None) or np.trapz

    indices = (freq >= point - delta) & (freq <= point + delta)
    if not np.any(indices):
        raise ValueError(
            f"No frequency bins within +/- {delta} of {point}; the point may lie "
            "outside the spectrum or delta may be smaller than the frequency resolution."
        )
    integral = integrate(power[indices], freq[indices])
    p_power = np.max(power[indices])

    # Side bands adjacent to the central window, each 4*delta wide.
    left_indices = (freq >= point - 5 * delta) & (freq < point - delta)
    right_indices = (freq > point + delta) & (freq <= point + 5 * delta)

    l_integral = integrate(power[left_indices], freq[left_indices])
    r_integral = integrate(power[right_indices], freq[right_indices])

    local_mean = np.mean([l_integral, r_integral])
    return integral, local_mean, p_power

def extract_stim_data(stimulus):
    '''
    Collect the metadata of one stimulus and derive its AM and Nyquist frequency.

    Parameters
    ----------
    stimulus : Stimulus object or rlxnix.base.repro module
        The stimulus whose metadata is read.

    Returns
    -------
    amplitude : float
        The relative signal amplitude in percent ('Contrast' entry).
    df : float
        Distance of the stimulus to the current EODf ('DeltaF' entry).
    eodf : float
        Current EODf, rounded.
    stim_freq : float
        The total stimulus frequency (EODf + df), rounded.
    stim_dur : float
        The stimulus duration.
    amp_mod : float
        The current amplitude modulation, as returned by ``AM``.
    ny_freq : float
        The current Nyquist frequency, as returned by ``AM``.
    '''
    # The top-level metadata key is the stimulus name, which differs per stimulus.
    section = stimulus.metadata[stimulus.name]

    def first(key):
        # each metadata value is read from position [0][0]
        return section[key][0][0]

    amplitude = first('Contrast')
    df = first('DeltaF')
    eodf = round(first('EODf'))
    stim_freq = round(first('Frequency'))
    stim_dur = stimulus.duration

    amp_mod, ny_freq = AM(eodf, stim_freq)
    return amplitude, df, eodf, stim_freq, stim_dur, amp_mod, ny_freq

def find_exceeding_points(frequency, power, points, delta, threshold):
    """
    Find the points where the integral exceeds the local mean by a given threshold.

    Parameters
    ----------
    frequency : np.array
        An array of frequencies corresponding to the power values.
    power : np.array
        An array of power spectral density values.
    points : list
        A list of harmonic frequencies to evaluate.
    delta : float
        Half-width of the range for integration around the point.
    threshold : float
        Threshold value to compare integrals with local mean.

    Returns
    -------
    exceeding_points : list
        The points (in input order) for which ``valid_integrals`` reports
        that the integral exceeds the local mean by the threshold.
    """
    exceeding_points = []

    for point in points:
        # calculate_integral returns (integral, local_mean, peak_power);
        # the peak power is not needed here.
        integral, local_mean, _ = calculate_integral(frequency, power, point, delta)

        # valid_integrals takes (integral, local_mean, point, threshold) and
        # returns a single boolean.
        if valid_integrals(integral, local_mean, point, threshold):
            exceeding_points.append(point)

    return exceeding_points

def firing_rate(binary_spikes, dt = 0.000025, box_width = 0.01):
    '''
    Calculate the firing rate by convolving a binary spike train with a box kernel.

    Parameters
    ----------
    binary_spikes : np.array
        The binary representation of the spike train.
    dt : float, optional
        Time difference between two datapoints. The default is 0.000025.
    box_width : float, optional
        Time window on which the rate should be computed. The default is 0.01.

    Returns
    -------
    rate : np.array
        Array of firing rates (spikes per unit of ``dt``), centered on each
        sample (``np.convolve`` with ``mode='same'``).
    '''
    # Round instead of floor-dividing: float floor division can drop a sample
    # (e.g. 0.3 // 0.1 == 2.0). Keep at least one sample so the kernel is
    # never empty when box_width < dt.
    n_box = max(1, int(round(box_width / dt)))
    box = np.ones(n_box)
    box /= np.sum(box) * dt  # normalise the box kernel to an integral of one
    rate = np.convolve(binary_spikes, box, mode = 'same')
    return rate

def power_spectrum(stimulus):
    '''
    Compute the power spectrum of the firing rate evoked by a stimulus.

    The spike times of the stimulus are binarized, converted to a firing rate
    (box kernel, see ``firing_rate``), and passed to ``scipy.signal.welch``.

    Parameters
    ----------
    stimulus : Stimulus object or rlxnix.base.repro module
        The stimulus for which the data is needed.

    Returns
    -------
    freq : np.array
        All the frequencies of the power spectrum.
    power : np.array
        Power of the frequencies calculated.
    '''
    spikes, duration, dt = spike_times(stimulus)
    # binarize spikes
    binary = binary_spikes(spikes, duration, dt)
    # compute firing rate
    rate = firing_rate(binary, dt = dt)
    # Segments of 2**16 samples with 50 % overlap. For recordings shorter
    # than one segment, welch would shrink nperseg to the input length but
    # keep noverlap = 2**15 and then reject noverlap >= nperseg, so shrink
    # both here instead.
    nperseg = min(2**16, len(rate))
    freq, power = welch(rate, fs = 1/dt, nperseg = nperseg, noverlap = nperseg // 2)
    return freq, power

def prepare_harmonic(frequency, category, num_harmonics, color):
    """
    Build the harmonic series of one base frequency and tag it with its category.

    Parameters
    ----------
    frequency : float
        Base frequency; the series is frequency * 1, frequency * 2, ...
    category : str
        Category label of the base frequency.
    num_harmonics : int
        How many harmonics to generate (the base frequency counts as the first).
    color : str
        Color assigned to the category.

    Returns
    -------
    harmonics : list
        The harmonic frequencies, lowest first.
    color_mapping : dict
        ``{category: color}``.
    category_harmonics : dict
        ``{category: harmonics}``; the value is the same list object as
        the first return value.
    """
    harmonics = []
    for order in range(1, num_harmonics + 1):
        harmonics.append(frequency * order)

    return harmonics, {category: color}, {category: harmonics}

def remove_poor(files):
    """
    Drop the recordings whose metadata labels the recording quality as "poor".

    Parameters
    ----------
    files : list
        Paths of the datasets to check.

    Returns
    -------
    good_files : list
        The input paths, in their original order, whose
        ``["Recording"]["Recording quality"]`` entry is not "poor"
        (case-insensitive).
    """
    def _is_not_poor(path):
        # Loading the dataset is the slow part; it happens once per file.
        dataset = rlx.Dataset(path)
        label = str.lower(dataset.metadata["Recording"]["Recording quality"][0][0])
        return label != "poor"

    return [path for path in files if _is_not_poor(path)]

def sam_data(sam):
    '''
    Average the per-stimulus metadata of one SAM run.

    Every stimulus of ``sam`` is passed through ``extract_stim_data`` and each
    field is averaged across stimuli with ``np.mean``.

    Parameters
    ----------
    sam : ReproRun object
        The SAM run whose metadata should be extracted.

    Returns
    -------
    avg_dur : float
        Average stimulus duration.
    sam_amp : float
        Mean amplitude in percent, relative to the fish amplitude.
    sam_am : float
        Mean amplitude modulation frequency.
    sam_df : float
        Mean difference between the stimulus and the current fish EODf.
    sam_eodf : float
        Mean EODf.
    sam_nyquist : float
        Mean Nyquist frequency of the EODf.
    sam_stim : float
        Mean stimulus frequency.
    '''
    per_stimulus = []
    for stim in sam.stimuli:
        (amplitude, df, eodf, stim_freq,
         stim_dur, amp_mod, ny_freq) = extract_stim_data(stim)
        per_stimulus.append((amplitude, df, eodf, stim_freq, stim_dur, amp_mod, ny_freq))

    def column_mean(position):
        # average one field across all stimuli (np.mean of an empty list gives NaN)
        return np.mean([record[position] for record in per_stimulus])

    avg_dur = column_mean(4)
    sam_amp = column_mean(0)
    sam_am = column_mean(5)
    sam_df = column_mean(1)
    sam_eodf = column_mean(2)
    sam_nyquist = column_mean(6)
    sam_stim = column_mean(3)
    return avg_dur, sam_amp, sam_am, sam_df, sam_eodf, sam_nyquist, sam_stim

def spike_times(stim):
    """
    Read the spike times, the duration and the sampling interval of a stimulus.

    Parameters
    ----------
    stim : Stimulus object or rlxnix.base.repro module
        The stimulus whose spike times are read.

    Returns
    -------
    spike_times : np.array
        The spike times from the 'Spikes-1' trace.
    stim_dur : float
        The duration of the stimulus.
    dt : float
        Sampling interval of the "V-1" trace.
    """
    # The result is bound to `spike_train`, not `spike_times`, to avoid shadowing this function.
    spike_train, _unused = stim.trace_data('Spikes-1')
    duration = stim.duration
    interval = stim.trace_info("V-1").sampling_interval
    return spike_train, duration, interval


def valid_integrals(integral, local_mean, point, threshold = 0.1):
    """
    Check if the integral exceeds the threshold compared to the local mean and
    print whether the given point is valid or not.

    The point is valid when ``integral > local_mean * (1 + threshold)``.

    Parameters
    ----------
    integral : float
        The calculated integral around the point.
    local_mean : float
        The local mean value (adjacent integrals).
    point : float
        The harmonic frequency point being evaluated (used in the printed message).
    threshold : float, optional
        Relative amount by which the integral must exceed the local mean.
        The default is 0.1.

    Returns
    -------
    valid : bool
        True if the integral exceeds the local mean by the threshold, otherwise False.
    """
    # NOTE(review): this signature had an unresolved merge conflict
    # (HEAD: threshold=0.1, commit 3575361: threshold=0.3), which made the
    # module a SyntaxError. HEAD's 0.1 was kept; confirm the intended default.
    # The callers visible in this file pass threshold explicitly.
    valid = bool(integral > (local_mean * (1 + threshold)))
    if valid:
        print(f"The point {point} is valid, as its integral exceeds the threshold.")
    else:
        print(f"The point {point} is not valid, as its integral does not exceed the threshold.")
    return valid

'''TODO Sarah: AM frequency plot:
    What does the AM peak in the spectrum mean? Why is it there, and how does it change with stimulus intensity?
    Make a plot of the AM (between 0 and 1/2 EODf) over the stimulus frequency (EODf + df), extract the amplitude
    of the AM peak, and plot that amplitude over the frequency of the peak.'''
    
""" files = glob.glob("../data/2024-10-16*.nix") gets all the filepaths from the 16.10"""