import numpy as np
import rlxnix as rlx
from scipy.signal import welch
from scipy import signal
import matplotlib.pyplot as plt

def all_coming_together(freq_array, power_array, points_list, categories, num_harmonics_list, colors, delta=2.5, threshold=0.5):
    """
    Validates candidate frequencies against a power spectrum and collects
    the harmonics of every candidate that passes the check.

    Parameters
    ----------
    freq_array : np.array
        Frequencies of the power spectrum.
    power_array : np.array
        Power values belonging to freq_array.
    points_list : list
        Candidate frequencies to be checked.
    categories : list
        Category of each candidate (same order as points_list).
    num_harmonics_list : list
        Number of harmonics to create for each candidate.
    colors : list
        Color of each candidate.
    delta : float, optional
        Integration radius handed to calculate_integral_2. The default is 2.5.
    threshold : float, optional
        Threshold handed to valid_integrals. The default is 0.5.

    Returns
    -------
    valid_points : list
        Harmonics of all valid candidates, in input order.
    color_mapping : dict
        Color for every category that has a valid candidate.
    category_harmonics : dict
        Harmonics for every category that has a valid candidate.
    messages : list
        One message per candidate stating whether it is valid.
    """
    valid_points, messages = [], []
    color_mapping, category_harmonics = {}, {}

    for idx, point in enumerate(points_list):
        # the per-candidate settings live in parallel lists
        category = categories[idx]
        num_harmonics = num_harmonics_list[idx]
        color = colors[idx]

        # compare the power around the candidate with its surroundings
        integral, local_mean = calculate_integral_2(freq_array, power_array, point, delta)
        if not valid_integrals(integral, local_mean, point, threshold):
            messages.append(f"The point {point} is not valid.")
            continue

        # only the list of harmonics is needed, the two dicts are rebuilt here
        harmonics, _, _ = prepare_harmonic(point, category, num_harmonics, color)
        valid_points.extend(harmonics)
        color_mapping[category] = color
        category_harmonics[category] = harmonics
        messages.append(f"The point {point} is valid.")

    # debugging output
    print("Color Mapping:", color_mapping)
    print("Category Harmonics:", category_harmonics)

    return valid_points, color_mapping, category_harmonics, messages



def AM(EODf, stimulus):
    """
    Computes the Nyquist frequency of the EODf and the stimulus
    frequency taken modulo that Nyquist frequency.

    Parameters
    ----------
    EODf : float or int
        The current EODf.
    stimulus : float or int
        The absolute frequency of the stimulus.

    Returns
    -------
    AM : float
        The stimulus frequency modulo the Nyquist frequency.
    nyquist : float
        Half of the EODf.

    """
    half_eodf = 0.5 * EODf
    folded = np.mod(stimulus, half_eodf)
    return folded, half_eodf

def binary_spikes(spike_times, duration, dt):
    """
    Converts the spike times to a binary representation.

    Each spike sets the sample nearest to its time to one. Spikes whose
    nearest sample lies outside the trial (negative times, or times that
    round to ``duration / dt`` or later) are dropped, so they neither raise
    an IndexError nor wrap around to the end of the array.

    Parameters
    ----------
    spike_times : np.array or array-like
        The spike times.
    duration : float
        The trial duration.
    dt : float
        The temporal resolution.

    Returns
    -------
    binary : np.array
        The binary representation of the spike train with
        ``round(duration / dt)`` samples.

    """
    n_samples = int(np.round(duration / dt))
    binary = np.zeros(n_samples)

    # nearest sample index of every spike
    spike_indices = np.round(np.asarray(spike_times, dtype=float) / dt).astype(int)
    # negative indices would wrap around and indices >= n_samples would raise
    inside = (spike_indices >= 0) & (spike_indices < n_samples)
    binary[spike_indices[inside]] = 1
    return binary

def calculate_integral(freq, power, point, delta = 2.5):   
    """
    Calculate the integral around a single specified point.

    The power is integrated with the trapezoidal rule over
    ``[point - delta, point + delta]``. For comparison the two flanking
    windows ``[point - 5 * delta, point - delta)`` and
    ``(point + delta, point + 5 * delta]`` are integrated as well and averaged.
    Note that each flanking window is 4 * delta wide while the central window
    is 2 * delta wide; the integrals are not normalised by their width.

    Parameters
    ----------
    freq : np.array or array-like
        An array of frequencies corresponding to the power values.
    power : np.array or array-like
        An array of power spectral density values.
    point : float
        The harmonic frequency at which to calculate the integral.
    delta : float, optional
        Radius of the range for integration around the point. The default is 2.5.

    Returns
    -------
    integral : float
        The calculated integral around the point.
    local_mean : float
        The local mean value (mean of the two adjacent integrals).
    p_power : float
        The local maximum power inside the central window.

    Raises
    ------
    ValueError
        If no frequency sample lies inside the central window.
    """
    # np.trapz was renamed to np.trapezoid in NumPy 2.0; use whichever exists
    trapezoid = getattr(np, "trapezoid", None) or np.trapz
    freq = np.asarray(freq)
    power = np.asarray(power)

    center = (freq >= point - delta) & (freq <= point + delta)
    if not center.any():
        # np.max on an empty selection would fail with an unhelpful message
        raise ValueError(f"No frequency samples within {delta} of the point {point}.")
    integral = trapezoid(power[center], freq[center])
    p_power = np.max(power[center])

    left = (freq >= point - 5 * delta) & (freq < point - delta)
    right = (freq > point + delta) & (freq <= point + 5 * delta)

    l_integral = trapezoid(power[left], freq[left])
    r_integral = trapezoid(power[right], freq[right])

    local_mean = np.mean([l_integral, r_integral])
    return integral, local_mean, p_power

def calculate_integral_2(freq, power, point, delta = 2.5):   
    """
    Calculate the integral around a single specified point.

    Same computation as calculate_integral, but without the local maximum
    power. The power is integrated with the trapezoidal rule over
    ``[point - delta, point + delta]`` and over the two flanking windows
    ``[point - 5 * delta, point - delta)`` and
    ``(point + delta, point + 5 * delta]``. Each flanking window is 4 * delta
    wide while the central window is 2 * delta wide; the integrals are not
    normalised by their width. Windows without any frequency sample
    integrate to 0.

    Parameters
    ----------
    freq : np.array or array-like
        An array of frequencies corresponding to the power values.
    power : np.array or array-like
        An array of power spectral density values.
    point : float
        The harmonic frequency at which to calculate the integral.
    delta : float, optional
        Radius of the range for integration around the point. The default is 2.5.

    Returns
    -------
    integral : float
        The calculated integral around the point.
    local_mean : float
        The local mean value (mean of the two adjacent integrals).
    """
    # np.trapz was renamed to np.trapezoid in NumPy 2.0; use whichever exists
    trapezoid = getattr(np, "trapezoid", None) or np.trapz
    freq = np.asarray(freq)
    power = np.asarray(power)

    center = (freq >= point - delta) & (freq <= point + delta)
    integral = trapezoid(power[center], freq[center])

    left = (freq >= point - 5 * delta) & (freq < point - delta)
    right = (freq > point + delta) & (freq <= point + 5 * delta)

    l_integral = trapezoid(power[left], freq[left])
    r_integral = trapezoid(power[right], freq[right])

    local_mean = np.mean([l_integral, r_integral])
    return integral, local_mean

def contrast_sorting(sams, con_1 = 20, con_2 = 10, con_3 = 5, stim_count = 3, stim_dur = 2):
    '''
    Groups the sams by their contrast into three bins.

    A sam is skipped when its contrast is NaN, when it has fewer than
    stim_count stimuli (aborted run), when its average stimulus duration is
    below 80 % of stim_dur, or when its contrast (truncated to an integer)
    matches none of the three contrasts.

    Parameters
    ----------
    sams : ReproRuns
        The sams to be sorted.
    con_1 : int, optional
        The first contrast. The default is 20.
    con_2 : int, optional
        The second contrast. The default is 10.
    con_3 : int, optional
        The third contrast. The default is 5.
    stim_count : int, optional
        The number of stimuli a good sam has at least. The default is 3.
    stim_dur : int, optional
        The stimulus duration. The default is 2.

    Returns
    -------
    contrast_sams : dictionary
        Maps each of the three contrasts to the list of its sams.

    '''
    contrast_sams = {con_1: [], con_2: [], con_3: []}
    shortest_allowed = stim_dur * 0.8

    for sam in sams:
        # only the average duration and the contrast are of interest here
        avg_dur, contrast, *_ = sam_data(sam)

        if np.isnan(contrast):
            continue
        if sam.stimulus_count < stim_count:
            # aborted run
            continue
        if avg_dur < shortest_allowed:
            continue

        # the integer part of the contrast selects the bin, others are dropped
        bin_key = int(contrast)
        if bin_key in contrast_sams:
            contrast_sams[bin_key].append(sam)
    return contrast_sams

def extract_stim_data(stimulus):
    '''
    Reads the metadata of a single stimulus and derives the amplitude
    modulation and the Nyquist frequency from it.

    Parameters
    ----------
    stimulus : Stimulus object or rlxnix.base.repro module
        The stimulus from which the data is needed.

    Returns
    -------
    amplitude : float
        The relative signal amplitude in percent ('Contrast' entry).
    df : float
        Distance of the stimulus to the current EODf ('DeltaF' entry).
    eodf : int
        Current EODf, rounded ('EODf' entry).
    stim_freq : int
        The total stimulus frequency, rounded ('Frequency' entry).
    stim_dur : float
        The stimulus duration.
    amp_mod : float
        The amplitude modulation as returned by find_AM.
    ny_freq : float
        The Nyquist frequency as returned by AM.

    '''
    # the metadata section is named after the stimulus itself
    section = stimulus.metadata[stimulus.name]

    def first_entry(key):
        # every property is stored nested, the value sits at [0][0]
        return section[key][0][0]

    amplitude = first_entry('Contrast')
    df = first_entry('DeltaF')
    eodf = round(first_entry('EODf'))
    stim_freq = round(first_entry('Frequency'))
    stim_dur = stimulus.duration

    # AM() is only used for the Nyquist frequency, the AM comes from find_AM()
    ny_freq = AM(eodf, stim_freq)[1]
    amp_mod = find_AM(eodf, ny_freq, stim_freq)
    return amplitude, df, eodf, stim_freq, stim_dur, amp_mod, ny_freq

def find_AM(eodf, nyquist, stimulus_frequency):
    """
    Looks up the amplitude modulation on a periodic triangular curve.

    The curve is a triangular window with eodf samples, scaled by nyquist
    and repeated ``int(eodf * 10)`` times. The value at the sample nearest to
    stimulus_frequency is returned (ties go to the lower sample; frequencies
    outside the curve use its first or last sample).

    Because the curve is periodic, only one period is created and indexed
    modulo its length instead of tiling ``eodf * int(eodf * 10)`` samples.

    Parameters
    ----------
    eodf : int
        The current EODf; used as the number of samples of one period.
    nyquist : float
        The Nyquist frequency; scales the triangular window.
    stimulus_frequency : float or int
        The absolute frequency of the stimulus.

    Returns
    -------
    AM : float
        The value of the triangular curve at the stimulus frequency. NaN if
        stimulus_frequency is NaN.

    Raises
    ------
    ValueError
        If eodf is not a positive whole number.
    """
    # one period of the triangular curve
    period = signal.windows.triang(eodf) * nyquist
    # number of samples the fully tiled curve would have
    total = len(period) * int(eodf * 10)
    if total == 0:
        raise ValueError("eodf must be a positive whole number.")
    if np.isnan(stimulus_frequency):
        return np.nan

    # nearest sample; ceil(x - 0.5) sends ties to the lower sample
    nearest = int(np.clip(np.ceil(stimulus_frequency - 0.5), 0, total - 1))
    return period[nearest % len(period)]

def firing_rate(binary_spikes, dt = 0.000025, box_width = 0.01):
    '''
    Calculates the firing rate from binary spikes

    The spike train is convolved with a box kernel whose integral is one.
    The kernel length is ``box_width / dt`` rounded to the nearest sample and
    at least one sample. Rounding is used because float floor division can
    fall one sample short (e.g. ``1.0 // 0.1 == 9.0``).

    Parameters
    ----------
    binary_spikes : np.array
        The binary representation of the spike train.
    dt : float, optional
        Time difference between two datapoints. The default is 0.000025.
    box_width : float, optional
        Time window on which the rate should be computed on. The default is 0.01.

    Returns
    -------
    rate : np.array
        Array of firing rates.

    '''
    # an empty kernel would make np.convolve raise, so keep at least one sample
    n_box = max(1, int(np.round(box_width / dt)))
    # normalise the box kernel to an integral of one
    box = np.full(n_box, 1.0 / (n_box * dt))
    rate = np.convolve(binary_spikes, box, mode = 'same')
    return rate

def power_spectrum(stimulus):
    '''
    Computes a power spectrum of the spike train of a stimulus

    The spike times are converted to a binary spike train and the spectrum
    is estimated from it with Welch's method (segments of 2**16 samples
    with 2**15 samples overlap).

    Parameters
    ----------
    stimulus : Stimulus object or rlxnix.base.repro module
        The stimulus for which the data is needed.

    Returns
    -------
    freq : np.array
        All the frequencies of the power spectrum.
    power : np.array
        Power of the frequencies calculated.

    '''
    spikes, duration, dt = spike_times(stimulus)
    # binarizes spikes
    binary = binary_spikes(spikes, duration, dt)
    # the spectrum is taken from the binary spike train itself; the firing
    # rate that used to be computed here was never used and has been dropped
    freq, power = welch(binary, fs = 1/dt, nperseg = 2**16, noverlap = 2**15)
    return freq, power

def prepare_harmonic(frequency, category, num_harmonics, color):
    """
    Build the harmonics of one base frequency together with the lookup
    dictionaries for its category.

    Parameters
    ----------
    frequency : float
        Base frequency; the first harmonic is the frequency itself.
    category : str
        Category the base frequency belongs to.
    num_harmonics : int
        How many harmonics to create.
    color : str
        Color assigned to the category.

    Returns
    -------
    harmonics : list
        The frequency multiplied by 1, 2, ..., num_harmonics.
    color_mapping : dict
        Single-entry dictionary from the category to its color.
    category_harmonics : dict
        Single-entry dictionary from the category to the harmonics list.
    """
    multiples = range(1, num_harmonics + 1)
    harmonics = [frequency * multiple for multiple in multiples]
    return harmonics, {category: color}, {category: harmonics}

def remove_poor(files):
    """
    Filters out the datasets whose recording quality is labelled poor.

    Parameters
    ----------
    files : list
        Paths of the files to check.

    Returns
    -------
    good_files : list
        The paths whose recording quality is anything but "poor"
        (compared case-insensitively), in their original order.

    """
    kept = []
    for path in files:
        # loading the dataset takes some time
        dataset = rlx.Dataset(path)
        label = dataset.metadata["Recording"]["Recording quality"][0][0]
        if str.lower(label) == "poor":
            continue
        # good or fair recordings are kept
        kept.append(path)
    return kept

def sam_data(sam):
    '''
    Averages the stimulus metadata over all stimuli of a SAM

    Parameters
    ----------
    sam : ReproRun object
        The sam the metadata should be extracted from.

    Returns
    -------
    avg_dur : float
        Average stimulus duration.
    sam_amp : float
        Average amplitude in percent, relative to the fish amplitude.
    sam_am : float
        Average amplitude modulation frequency.
    sam_df : float
        Average difference from the stimulus to the current fish eodf.
    sam_eodf : float
        Average EODf.
    sam_nyquist : float
        Average Nyquist frequency of the EODf.
    sam_stim : float
        Average stimulus frequency.

    '''
    # one record per stimulus, ordered as returned by extract_stim_data:
    # (amplitude, df, eodf, stim_freq, stim_dur, amp_mod, ny_freq)
    records = [extract_stim_data(stim) for stim in sam.stimuli]

    def column_mean(position):
        # mean over all stimuli of the value stored at the given position
        return np.mean([record[position] for record in records])

    sam_amp = column_mean(0)
    sam_df = column_mean(1)
    sam_eodf = column_mean(2)
    sam_stim = column_mean(3)
    avg_dur = column_mean(4)
    sam_am = column_mean(5)
    sam_nyquist = column_mean(6)
    return avg_dur, sam_amp, sam_am, sam_df, sam_eodf, sam_nyquist, sam_stim

def sam_spectrum(sam):
    """
    Averages the power spectra of all stimuli of a SAM ReproRun.

    Parameters
    ----------
    sam : ReproRun Object
        The ReproRun the power spectrum should be generated from.

    Returns
    -------
    sam_frequency : np.array
        The frequencies of the power spectrum, averaged over the stimuli.
    sam_power : np.array
        The powers of the frequencies, averaged over the stimuli.

    """
    # one (frequency, power) pair per stimulus
    spectra = [power_spectrum(stimulus) for stimulus in sam.stimuli]
    frequency_rows = [frequency for frequency, _ in spectra]
    power_rows = [power for _, power in spectra]

    # average across the stimuli
    sam_frequency = np.mean(frequency_rows, axis = 0)
    sam_power = np.mean(power_rows, axis = 0)
    return sam_frequency, sam_power

def spike_times(stim):
    """
    Fetches the spike times of a stimulus together with its duration and
    the sampling interval of the recording.

    Parameters
    ----------
    stim : Stimulus object or rlxnix.base.repro module
        The stimulus from which the spike times should be read.

    Returns
    -------
    spike_times : np.array
        The spike times of the stimulus ('Spikes-1' trace).
    stim_dur : float
        The duration of the stimulus.
    dt : float
        Time interval between two data points of the 'V-1' trace.

    """
    # the second value returned with the spikes is not needed
    spikes, _unused = stim.trace_data('Spikes-1')
    stim_dur = stim.duration
    # the sampling interval is taken from the voltage trace
    dt = stim.trace_info("V-1").sampling_interval
    return spikes, stim_dur, dt

def true_eodf(eodf_file):
    '''
    Determines the EODf from the baseline recording of a nix file.

    The EOD trace of the first 'baseline' repro run is turned into a power
    spectrum with Welch's method; the frequency with the largest power,
    rounded to an integer, is returned.

    Parameters
    ----------
    eodf_file : str
        Path to the nix file with the EOD recording.

    Returns
    -------
    orig_eodf : int
        The original eodf.

    '''
    dataset = rlx.Dataset(eodf_file)
    baseline_run = dataset.repro_runs('baseline')[0]
    # only the EOD samples are needed, the time axis is discarded
    eod_trace, _ = baseline_run.trace_data('EOD')
    sampling_interval = baseline_run.trace_info('EOD').sampling_interval
    frequencies, powers = welch(eod_trace, fs = 1/sampling_interval, nperseg = 2**16, noverlap = 2**15)
    # the strongest spectral peak marks the EODf
    peak = np.argmax(powers)
    return round(frequencies[peak])

def valid_integrals(integral, local_mean, point, threshold = 0.1):
    """
    Decide whether the integral around a point stands out from its
    surroundings and print the verdict.

    Parameters
    ----------
    integral : float
        The calculated integral around the point.
    local_mean : float
        The local mean value (adjacent integrals).
    point : float
        The harmonic frequency point being evaluated; only used for the
        printed message.
    threshold : float, optional
        Relative margin by which the integral has to exceed the local mean.
        The default is 0.1.

    Returns
    -------
    valid : bool
        True if the integral is larger than local_mean * (1 + threshold),
        otherwise False.
    """
    required = local_mean * (1 + threshold)
    valid = integral > required
    print(f"The point {point} is valid." if valid else f"The point {point} is not valid.")
    return valid

'''TODO Sarah: AM-frequency plot.
    What does the AM peak in the spectrum mean? Why is it there, and how does it change with stimulus intensity?
    Make a plot with AM 1/2 EODf over the stimulus frequency (df + EODf), get the amplitude of the AM peak, and plot
    that amplitude over the frequency of the peak.'''

""" files = glob.glob("../data/2024-10-16*.nix") gets all the file paths of the recordings from 2024-10-16. """