import glob
import pathlib
import numpy as np
import matplotlib.pyplot as plt
import rlxnix as rlx
from IPython import embed
from scipy.signal import welch

def AM(EODf, stimulus):
    """
    Calculates the amplitude modulation (AM) frequency and the Nyquist frequency

    Parameters
    ----------
    EODf : float or int
        The current EODf.
    stimulus : float or int
        The absolute frequency of the stimulus.

    Returns
    -------
    AM : float
        The amplitude modulation frequency resulting from the stimulus.
    nyquist : float
        The maximum frequency that can be resolved at the given EODf (EODf / 2).

    """
    nyquist = EODf * 0.5  # highest frequency that can be resolved at this EODf
    AM = np.mod(stimulus, nyquist)  # wrap the stimulus frequency into the resolvable range
    return AM, nyquist
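
# Minimal usage sketch for AM(); the numbers are purely illustrative and not
# taken from a recording: with an EODf of 800 Hz the Nyquist frequency is
# 400 Hz, and a 1020 Hz stimulus wraps to an amplitude modulation of
# 1020 mod 400 = 220 Hz.
if __name__ == "__main__":
    example_am, example_nyquist = AM(800, 1020)
    print(f"AM: {example_am} Hz, Nyquist: {example_nyquist} Hz")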

def binary_spikes(spike_times, duration, dt):
    """
    Converts the spike times to a binary representation

    Parameters
    ----------
    spike_times : np.array
        The spike times.
    duration : float
        The trial duration.
    dt : float
        The temporal resolution.

    Returns
    -------
    binary : np.array
        The binary representation of the spike train.

    """
    binary = np.zeros(int(np.round(duration / dt)))  # one bin per time step of the trial
    # convert spike times to bin indices; drop spikes that would round past the end of the array
    spike_indices = np.asarray(np.round(spike_times / dt), dtype=int)
    spike_indices = spike_indices[spike_indices < len(binary)]
    binary[spike_indices] = 1  # mark each spike with a 1
    return binary
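
# Minimal usage sketch for binary_spikes(); the three spike times and the 1 s
# trial are made up. With dt = 0.000025 s the resulting array has 40000
# entries, three of which are set to 1.
if __name__ == "__main__":
    example_spikes = np.array([0.1, 0.25, 0.7])
    example_binary = binary_spikes(example_spikes, duration=1.0, dt=0.000025)
    print(f"bins: {len(example_binary)}, spikes: {int(np.sum(example_binary))}")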

def extract_stim_data(stimulus):
    '''
    Extracts all necessary metadata for a given stimulus

    Parameters
    ----------
    stimulus : Stimulus object (rlxnix.base.repro)
        The stimulus from which the metadata is extracted.

    Returns
    -------
    amplitude : float
        The relative signal amplitude in percent.
    df : float
        Frequency difference between the stimulus and the current EODf.
    eodf : float
        The current EODf.
    stim_freq : float
        The absolute stimulus frequency (EODf + df).
    amp_mod : float
        The resulting amplitude modulation frequency.
    ny_freq : float
        The Nyquist frequency (EODf / 2).

    '''
    # extract metadata
    # stimulus.name gives the top-level metadata key, which changes with every stimulus
    amplitude = stimulus.metadata[stimulus.name]['Contrast'][0][0]
    df = stimulus.metadata[stimulus.name]['DeltaF'][0][0]
    eodf = round(stimulus.metadata[stimulus.name]['EODf'][0][0])
    stim_freq = round(stimulus.metadata[stimulus.name]['Frequency'][0][0])
    # calculate the amplitude modulation and the Nyquist frequency
    amp_mod, ny_freq = AM(eodf, stim_freq)
    return amplitude, df, eodf, stim_freq, amp_mod, ny_freq

def firing_rate(binary_spikes, dt=0.000025, box_width=0.01):
    '''
    Calculates the firing rate by convolving a binary spike train with a box kernel

    Parameters
    ----------
    binary_spikes : np.array
        The binary representation of the spike train.
    dt : float, optional
        Time difference between two datapoints. The default is 0.000025.
    box_width : float, optional
        Width of the time window over which the rate is computed. The default is 0.01.

    Returns
    -------
    rate : np.array
        Array of firing rates.

    '''
    box = np.ones(int(box_width // dt))  # box kernel spanning the averaging window
    box /= np.sum(box) * dt  # normalize the box kernel to an integral of one
    rate = np.convolve(binary_spikes, box, mode='same')
    return rate
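
# Minimal usage sketch for firing_rate(), reusing the made-up spike train from
# the binary_spikes() example; the box filter turns the binary train into a
# rate in Hz (roughly 3 Hz on average for 3 spikes in 1 s).
if __name__ == "__main__":
    example_binary = binary_spikes(np.array([0.1, 0.25, 0.7]), 1.0, 0.000025)
    example_rate = firing_rate(example_binary, dt=0.000025, box_width=0.01)
    print(f"mean rate: {np.mean(example_rate):.2f} Hz")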

def power_spectrum(spike_times, duration, dt):
    '''
    Computes a power spectrum based on the spike times

    Parameters
    ----------
    spike_times : np.array
        The spike times.
    duration : float
        The trial duration.
    dt : float
        The temporal resolution.

    Returns
    -------
    freq : np.array
        All the frequencies of the power spectrum.
    power : np.array
        Power of the frequencies calculated.

    '''
    # binarize the spike times
    binary = binary_spikes(spike_times, duration, dt)
    # compute the firing rate
    rate = firing_rate(binary, dt=dt)
    # estimate the power spectrum with Welch's method
    freq, power = welch(rate, fs=1/dt, nperseg=2**16, noverlap=2**15)
    return freq, power
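
# Minimal usage sketch for power_spectrum(), using a made-up, perfectly regular
# spike train at 50 Hz over a 2 s trial; the spectrum should show its strongest
# peak near the 50 Hz spiking frequency.
if __name__ == "__main__":
    example_spikes = np.arange(0.0, 2.0, 0.02)  # one spike every 20 ms
    example_freq, example_power = power_spectrum(example_spikes, 2.0, 0.000025)
    peak = example_freq[np.argmax(example_power)]
    print(f"strongest frequency: {peak:.1f} Hz")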

def remove_poor(files):
    """
    Removes datasets with poor recording quality from the list of files

    Parameters
    ----------
    files : list 
        List of dataset file paths.

    Returns
    -------
    good_files : list
        The files whose recording quality is not labeled "poor".

    """
    # create list for good files
    good_files = []
    # loop over files
    for file in files:
        # load the dataset (takes some time)
        data = rlx.Dataset(file)
        # get the recording quality from the metadata
        quality = data.metadata["Recording"]["Recording quality"][0][0].lower()
        # keep the file only if its quality is not poor (i.e. good or fair)
        if quality != "poor":
            good_files.append(file)
    return good_files
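
# Minimal end-to-end sketch: collect nix files, drop recordings labeled as
# "poor", and report how many datasets remain. The data directory and the file
# pattern are placeholders, not paths from the original analysis.
if __name__ == "__main__":
    nix_files = sorted(glob.glob("../data/*.nix"))
    if nix_files:
        usable = remove_poor(nix_files)
        print(f"{len(usable)} of {len(nix_files)} datasets are usable")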