import glob
import pathlib
import numpy as np
import matplotlib.pyplot as plt
import rlxnix as rlx
from IPython import embed
from scipy.signal import welch

def AM(EODf, stimulus):
    """
    Calculates the (aliased) amplitude-modulation frequency and the Nyquist frequency.

    The AM frequency is the distance of the stimulus frequency to the nearest
    integer multiple of the EODf. As a function of the stimulus frequency this
    is a triangle wave: 0 at every multiple of the EODf, rising to the Nyquist
    frequency (EODf / 2) halfway between two multiples.

    Parameters
    ----------
    EODf : float or int
        The current EODf.
    stimulus : float, int or np.ndarray
        The absolute frequency of the stimulus.

    Returns
    -------
    AM : float or np.ndarray
        The amplitude-modulation frequency resulting from the stimulus,
        in the range [0, nyquist].
    nyquist : float
        The maximum frequency possible to resolve with the EODf.

    """
    nyquist = EODf * 0.5
    # Distance to the nearest harmonic of the EODf (triangle-shaped aliasing).
    # np.mod(stimulus, nyquist) would give a sawtooth, which is only correct for
    # stimuli between n*EODf and n*EODf + nyquist.
    AM = np.abs(stimulus - EODf * np.round(np.divide(stimulus, EODf)))
    return AM, nyquist

def binary_spikes(spike_times, duration, dt):
    """
    Converts the spike times to a binary representation.

    Each spike time is mapped to the sample index ``round(t / dt)``. Spikes
    whose index falls outside ``[0, round(duration / dt))`` are ignored, so a
    spike at exactly ``duration`` or a negative spike time neither raises nor
    wraps around to the other end of the array.

    Parameters
    ----------
    spike_times : np.array or sequence of float
        The spike times.
    duration : float
        The trial duration.
    dt : float
        The temporal resolution.

    Returns
    -------
    binary : np.array
        The binary representation of the spike train (1 where a spike occurred).

    """
    n_bins = int(np.round(duration / dt))
    # create the binary array with the same length as the trial
    binary = np.zeros(n_bins)

    # sample index of every spike
    spike_indices = np.asarray(np.round(np.asarray(spike_times) / dt), dtype=int)
    # a spike at exactly `duration` would index one past the end, and negative
    # indices would silently wrap around, so keep only in-range indices
    in_range = (spike_indices >= 0) & (spike_indices < n_bins)
    binary[spike_indices[in_range]] = 1
    return binary

def calculate_integral(freq, power, point, delta):
    """
    Calculate the integral around a single specified point.

    The power is integrated with the trapezoidal rule over the band
    [point - delta, point + delta]. As a local reference, the power is also
    integrated over the two flanking bands [point - 5*delta, point - delta)
    and (point + delta, point + 5*delta], and the mean of those two flanking
    integrals is returned.

    Parameters
    ----------
    freq : np.array
        An array of frequencies corresponding to the power values.
    power : np.array
        An array of power spectral density values (same length as ``freq``).
    point : float
        The harmonic frequency at which to calculate the integral.
    delta : float
        Radius of the range for integration around the point.

    Returns
    -------
    integral : float
        The integral of ``power`` over [point - delta, point + delta]
        (0.0 if fewer than two frequencies fall into the band).
    local_mean : float
        The mean of the left and right flanking integrals.
    """
    # np.trapz is deprecated since NumPy 2.0 in favour of np.trapezoid;
    # fall back to np.trapz on older NumPy versions.
    trapezoid = getattr(np, "trapezoid", None) or np.trapz

    indices = (freq >= point - delta) & (freq <= point + delta)
    integral = trapezoid(power[indices], freq[indices])

    # NOTE: each flanking band is 4*delta wide, i.e. twice as wide as the
    # central band, so local_mean is not on the same scale as `integral`.
    left_indices = (freq >= point - 5 * delta) & (freq < point - delta)
    right_indices = (freq > point + delta) & (freq <= point + 5 * delta)

    l_integral = trapezoid(power[left_indices], freq[left_indices])
    r_integral = trapezoid(power[right_indices], freq[right_indices])

    local_mean = np.mean([l_integral, r_integral])
    return integral, local_mean

def extract_stim_data(stimulus):
    """
    Collect the metadata needed for one stimulus.

    Parameters
    ----------
    stimulus : Stimulus object or rlxnix.base.repro module
        The stimulus whose metadata is read.

    Returns
    -------
    amplitude : float
        The relative signal amplitude in percent ('Contrast' entry).
    df : float
        Distance of the stimulus to the current EODf ('DeltaF' entry).
    eodf : int
        Current EODf, rounded ('EODf' entry).
    stim_freq : int
        The total stimulus frequency (EODf + df), rounded ('Frequency' entry).
    amp_mod : float
        The current amplitude modulation, as computed by ``AM``.
    ny_freq : float
        The current Nyquist frequency, as computed by ``AM``.

    """
    # Each stimulus keeps its metadata under its own name, so that name is the
    # top-level key; look the sub-dictionary up once and reuse it.
    stim_meta = stimulus.metadata[stimulus.name]

    contrast = stim_meta['Contrast'][0][0]
    delta_f = stim_meta['DeltaF'][0][0]
    rounded_eodf = round(stim_meta['EODf'][0][0])
    rounded_freq = round(stim_meta['Frequency'][0][0])

    modulation, nyquist = AM(rounded_eodf, rounded_freq)
    return contrast, delta_f, rounded_eodf, rounded_freq, modulation, nyquist

def firing_rate(binary_spikes, dt = 0.000025, box_width = 0.01):
    '''
    Calculates the firing rate from binary spikes by convolving them with a
    box kernel normalised to an integral of one.

    Parameters
    ----------
    binary_spikes : np.array
        The binary representation of the spike train.
    dt : float, optional
        Time difference between two datapoints. The default is 0.000025.
    box_width : float, optional
        Time window on which the rate should be computed on. The default is 0.01.
        It is rounded to a whole number of samples, and at least one sample
        is always used.

    Returns
    -------
    rate : np.array
        Array of firing rates (spikes per unit of time of ``dt``).

    '''
    # Number of samples covered by the box. Rounding (instead of floor division)
    # avoids losing a sample to floating-point error (e.g. 0.3 // 0.1 == 2.0),
    # and keeping at least one sample prevents an empty kernel, which would make
    # np.convolve raise.
    n_box = max(1, int(np.round(box_width / dt)))
    # normalise the box kernel to an integral of one
    box = np.ones(n_box) / (n_box * dt)
    rate = np.convolve(binary_spikes, box, mode = 'same')
    return rate

def power_spectrum(stimulus, nperseg=2**16):
    '''
    Computes a power spectrum of the firing rate evoked by a stimulus.

    The spike times of the stimulus are binarised, convolved into a firing
    rate and passed to Welch's method with 50 % segment overlap.

    Parameters
    ----------
    stimulus : Stimulus object or rlxnix.base.repro module
        The stimulus for which the data is needed.
    nperseg : int, optional
        Maximum segment length for Welch's method. The default is 2**16.
        It is reduced to the length of the rate if the recording is shorter.

    Returns
    -------
    freq : np.array
        All the frequencies of the power spectrum.
    power : np.array
        Power of the frequencies calculated.

    '''
    spikes, duration, dt = spike_times(stimulus)
    # binarizes spikes
    binary = binary_spikes(spikes, duration, dt)
    # computes firing rates
    rate = firing_rate(binary, dt = dt)
    # welch() shrinks nperseg to the signal length on its own, but a fixed
    # noverlap of 2**15 would then be >= nperseg and raise a ValueError for
    # short recordings. Clamp here and keep the 50 % overlap.
    nperseg = min(nperseg, len(rate))
    # creates power spectrum
    freq, power = welch(rate, fs = 1/dt, nperseg = nperseg, noverlap = nperseg // 2)
    return freq, power

def remove_poor(files):
    """
    Filter out the recordings that are labelled as poor quality.

    Every file is opened with ``rlx.Dataset`` (slow) and its
    "Recording" -> "Recording quality" metadata entry is compared,
    case-insensitively, against "poor".

    Parameters
    ----------
    files : list
        Paths of the recordings to check.

    Returns
    -------
    good_files : list
        The input paths, in their original order, minus those labelled poor.

    """
    kept = []
    for path in files:
        # opening the dataset is the expensive part of this loop
        dataset = rlx.Dataset(path)
        label = str.lower(dataset.metadata["Recording"]["Recording quality"][0][0])
        if label == "poor":
            continue
        # anything not labelled poor (e.g. good or fair) is kept
        kept.append(path)
    return kept

def sam_data(sam):
    """
    Average the per-stimulus metadata over all stimuli of one SAM run.

    Parameters
    ----------
    sam : ReproRun object
        The SAM run whose stimuli are summarised.

    Returns
    -------
    sam_amp : float
        Mean amplitude in percent, relative to the fish amplitude.
    sam_am : float
        Mean amplitude-modulation frequency.
    sam_df : float
        Mean difference of the stimulus to the current fish EODf.
    sam_eodf : float
        Mean EODf.
    sam_nyquist : float
        Mean Nyquist frequency of the EODf.
    sam_stim : float
        Mean stimulus frequency.

    """
    # one 6-tuple (amplitude, df, eodf, stim_freq, amp_mod, ny_freq) per stimulus
    per_stimulus = [extract_stim_data(stim) for stim in sam.stimuli]

    # transpose the rows into one list per quantity (empty lists if no stimuli)
    amplitudes, dfs, eodfs, stim_freqs, amp_mods, ny_freqs = (
        [row[column] for row in per_stimulus] for column in range(6)
    )

    return (
        np.mean(amplitudes),
        np.mean(amp_mods),
        np.mean(dfs),
        np.mean(eodfs),
        np.mean(ny_freqs),
        np.mean(stim_freqs),
    )

def spike_times(stim):
    """
    Read the spike times of a stimulus together with its duration and the
    sampling interval of the voltage trace.

    Parameters
    ----------
    stim : Stimulus object or rlxnix.base.repro module
        The stimulus whose spike times are read.

    Returns
    -------
    spike_times : np.array
        The spike times from the 'Spikes-1' trace of the stimulus.
    stim_dur : float
        The duration of the stimulus.
    dt : float
        Sampling interval of the 'V-1' trace.

    """
    # the second element returned by trace_data is not needed here
    spike_train, _ = stim.trace_data('Spikes-1')
    stimulus_duration = stim.duration
    sampling_interval = stim.trace_info("V-1").sampling_interval
    # the local is not called `spike_times` so it does not shadow this function
    return spike_train, stimulus_duration, sampling_interval


def valid_integrals(integral, local_mean, threshold, point):
    """
    Decide whether a peak integral stands out from its surroundings.

    A point counts as valid when its integral is strictly greater than
    ``local_mean * threshold``.

    Parameters
    ----------
    integral : float
        The calculated integral around the point.
    local_mean : float
        The local mean value (adjacent integrals).
    threshold : float
        Multiplicative factor applied to the local mean.
    point : float
        The harmonic frequency point being evaluated.

    Returns
    -------
    valid : bool
        True if the integral exceeds ``local_mean * threshold``, otherwise False.
    message : str
        A human-readable verdict for the point.
    """
    limit = local_mean * threshold
    valid = integral > limit
    message = (
        f"The point {point} is valid, as its integral exceeds the threshold."
        if valid
        else f"The point {point} is not valid, as its integral does not exceed the threshold."
    )
    return valid, message

'''TODO Sarah: AM-frequency plot:
    What does the AM peak in the spectrum mean? Why is it there, and how does it
    change with stimulus intensity?
    Make a plot of the AM frequency (up to 1/2 EODf) over the stimulus frequency
    (df + EODf); get the amplitude of the AM peak and plot that amplitude over
    the frequency of the peak.'''