# -*- coding: utf-8 -*-
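"""
Detect harmonic peaks (AM, EODf and stimulus frequency) in the power spectrum
of a spike train recorded during a SAM stimulus.

Created on Tue Oct 22 15:21:41 2024

@author: diana
"""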

import glob
import os
import rlxnix as rlx
import numpy as np
import matplotlib.pyplot as plt
import scipy.signal as sig
from scipy.integrate import quad


### FUNCTIONS ###
def binary_spikes(spike_times, duration, dt):
    """Convert spike times to a binary representation.

    The result has one bin per time step of size `dt` covering `duration`:
    zero where there is no spike, one where there is.

    Parameters
    ----------
    spike_times : array_like
        The spike times, in the same time unit as `duration` and `dt`.
    duration : float
        The trial duration.
    dt : float
        The temporal resolution.

    Returns
    -------
    binary : np.array
        The binary representation of the spike times, of length
        ``round(duration / dt)``. Spikes whose rounded bin index falls
        outside ``[0, len(binary))`` are ignored.
    """
    # Vector with one bin per sample of the stimulus time axis.
    n_bins = int(np.round(duration / dt))
    binary = np.zeros(n_bins)
    spike_indices = np.asarray(np.round(np.asarray(spike_times) / dt), dtype=int)
    # A spike at exactly `duration` rounds to index n_bins (out of bounds), and
    # a negative time would silently wrap to the end of the array; drop both.
    in_range = (spike_indices >= 0) & (spike_indices < n_bins)
    binary[spike_indices[in_range]] = 1
    return binary


def firing_rate(binary_spikes, box_width, dt=0.000025):
    """Calculate the firing rate from binary spike data with a box kernel.

    Parameters
    ----------
    binary_spikes : np.array
        A binary array representing spike occurrences.
    box_width : float
        The width of the box filter in seconds. It is rounded to the nearest
        whole number of samples, with a minimum of one sample.
    dt : float, optional
        The temporal resolution (time step) in seconds. Default is 0.000025 seconds.

    Returns
    -------
    rate : np.array
        An array representing the firing rate at each time step.

    Notes
    -----
    ``np.convolve(..., mode="same")`` returns ``max(len(binary_spikes), len(box))``
    samples. A kernel longer than the signal therefore yields an output longer
    than the input.
    """
    # Round instead of floor-dividing: float `//` can be off by one sample
    # (e.g. 0.3 // 0.1 == 2.0). Keep at least one sample so the kernel is never
    # empty, which would otherwise divide by zero below.
    n_box = max(1, int(np.round(box_width / dt)))
    box = np.ones(n_box)
    box /= np.sum(box) * dt  # Normalization of box kernel to an integral of 1
    rate = np.convolve(binary_spikes, box, mode="same")
    return rate


def powerspectrum(rate, dt, nperseg=2**15, noverlap=None):
    """Compute the power spectrum of a given firing rate.

    This function calculates the power spectrum using the Welch method.

    Parameters
    ----------
    rate : np.array
        An array of firing rates (or any 1-D signal sampled every `dt`).
    dt : float
        The temporal resolution (time step) in seconds.
    nperseg : int, optional
        Segment length for Welch's method. Default is 2**15. It is clamped
        to the signal length so that short signals still work.
    noverlap : int or None, optional
        Number of samples shared by consecutive segments. If None (default),
        ``nperseg // 2`` of the (clamped) segment length is used. For signals
        of at least 2**15 samples this equals the previous fixed value 2**14.

    Returns
    -------
    frequency : np.array
        An array of frequencies corresponding to the power values.
    power : np.array
        An array of power spectral density values.
    """
    rate = np.asarray(rate)
    # scipy shrinks nperseg to the signal length on its own, but leaves an
    # explicit noverlap untouched. A fixed noverlap=2**14 then raises for
    # short signals, so clamp here and derive the overlap from the result.
    nperseg = min(nperseg, rate.shape[-1])
    frequency, power = sig.welch(rate, fs=1/dt, nperseg=nperseg, noverlap=noverlap)
    return frequency, power


def calculate_integral(frequency, power, point, delta):
    """
    Calculate the integral around a single specified point.

    The central band is ``[point - delta, point + delta]``. The two flanking
    bands are ``[point - 5*delta, point - delta)`` and
    ``(point + delta, point + 5*delta]``. Each flank is 4*delta wide, twice
    the width of the central band, so `local_mean` is not normalized to the
    central band's width.

    Parameters
    ----------
    frequency : np.array
        An array of frequencies corresponding to the power values.
    power : np.array
        An array of power spectral density values.
    point : float
        The harmonic frequency at which to calculate the integral.
    delta : float
        Half-width of the range for integration around the point.

    Returns
    -------
    integral : float
        The trapezoidal integral of `power` over the central band.
    local_mean : float
        The mean of the two flanking-band integrals. A band with no samples
        integrates to 0.
    """
    # np.trapz is deprecated since NumPy 2.0 in favour of np.trapezoid;
    # fall back to np.trapz on older NumPy versions that lack trapezoid.
    trapezoid = getattr(np, "trapezoid", None) or np.trapz

    indices = (frequency >= point - delta) & (frequency <= point + delta)
    integral = trapezoid(power[indices], frequency[indices])

    # Flanking bands exclude the central band's edges so that no sample is shared.
    left_indices = (frequency >= point - 5 * delta) & (frequency < point - delta)
    right_indices = (frequency > point + delta) & (frequency <= point + 5 * delta)

    l_integral = trapezoid(power[left_indices], frequency[left_indices])
    r_integral = trapezoid(power[right_indices], frequency[right_indices])

    local_mean = np.mean([l_integral, r_integral])
    return integral, local_mean


def valid_integrals(integral, local_mean, threshold, point):
    """
    Decide whether the band integral at `point` stands out from its surroundings.

    The point counts as valid when `integral` is strictly greater than
    ``local_mean * threshold``.

    Parameters
    ----------
    integral : float
        Integral of the power over the band centred on `point`.
    local_mean : float
        Mean of the neighbouring (flanking-band) integrals.
    threshold : float
        Factor by which `integral` must exceed `local_mean`.
    point : float
        The harmonic frequency under test; it is only used in the message.

    Returns
    -------
    valid : bool
        Result of the comparison ``integral > local_mean * threshold``.
    message : str
        Human-readable verdict for `point`.
    """
    exceeds = integral > (local_mean * threshold)
    message = (
        f"The point {point} is valid, as its integral exceeds the threshold."
        if exceeds
        else f"The point {point} is not valid, as its integral does not exceed the threshold."
    )
    return exceeds, message


def prepare_harmonics(frequencies, categories, num_harmonics, colors):
    """
    Build the harmonic series of each base frequency and a category-to-color map.

    Parameters
    ----------
    frequencies : list
        Base (fundamental) frequencies.
    categories : list
        Category label for each base frequency, in the same order.
    num_harmonics : list
        How many multiples (1x, 2x, ...) to generate for each base frequency.
    colors : list
        Color for each category, in the same order as `categories`.

    Returns
    -------
    points : list
        All harmonic frequencies, concatenated category by category.
    color_mapping : dict
        Category -> color.
    points_categories : dict
        Category -> list of its harmonic frequencies.
    """
    points_categories = {
        label: [base * multiple for multiple in range(1, num_harmonics[pos] + 1)]
        for pos, (base, label) in enumerate(zip(frequencies, categories))
    }

    points = []
    for series in points_categories.values():
        points.extend(series)

    color_mapping = {label: colors[pos] for pos, label in enumerate(categories)}

    return points, color_mapping, points_categories


def find_exceeding_points(frequency, power, points, delta, threshold):
    """
    Select the candidate frequencies whose band integral beats the local mean.

    Each candidate is scored with `calculate_integral` and kept when
    `valid_integrals` accepts it.

    Parameters
    ----------
    frequency : np.array
        Frequencies of the power spectrum.
    power : np.array
        Power spectral density values matching `frequency`.
    points : list
        Candidate harmonic frequencies.
    delta : float
        Half-width of the integration band around each candidate.
    threshold : float
        Factor by which the band integral must exceed the local mean.

    Returns
    -------
    exceeding_points : list
        The accepted candidates, in their original order.
    """
    def _is_peak(candidate):
        band_integral, neighbourhood = calculate_integral(frequency, power, candidate, delta)
        # Only the boolean verdict is needed; the message is discarded.
        flag, _unused_message = valid_integrals(band_integral, neighbourhood, threshold, candidate)
        return flag

    return [candidate for candidate in points if _is_peak(candidate)]


def plot_highlighted_integrals(frequency, power, exceeding_points, delta, threshold, color_mapping, points_categories, xlim=(0, 1200)):
    """
    Plot the power spectrum and highlight integrals that exceed the threshold.

    Every point in `exceeding_points` is re-checked with `calculate_integral`
    and `valid_integrals`. Passing points that do not exceed the threshold is
    therefore harmless: they are simply not highlighted.

    Parameters
    ----------
    frequency : np.array
        An array of frequencies corresponding to the power values.
    power : np.array
        An array of power spectral density values.
    exceeding_points : list
        A list of harmonic frequencies to highlight (if they pass the check).
    delta : float
        Half-width of the range for integration around each point.
    threshold : float
        Threshold value to compare integrals with local mean.
    color_mapping : dict
        A dictionary mapping each category to its color.
    points_categories : dict
        A mapping of categories to lists of points.
    xlim : tuple of float, optional
        Frequency range shown on the x axis. Default is (0, 1200).

    Returns
    -------
    fig : matplotlib.figure.Figure
        The created figure object with highlighted integrals.
    """
    fig, ax = plt.subplots()
    ax.plot(frequency, power)  # Plot power spectrum

    for point in exceeding_points:
        integral, local_mean = calculate_integral(frequency, power, point, delta)
        valid, _ = valid_integrals(integral, local_mean, threshold, point)
        if not valid:
            continue

        # Color by the first category whose harmonics contain the point; gray if none.
        color = next(
            (c for cat, c in color_mapping.items() if point in points_categories.get(cat, ())),
            'gray',
        )
        # Shade the region around the point where the integral was calculated
        ax.axvspan(point - delta, point + delta, color=color, alpha=0.3, label=f'{point:.2f} Hz')
        print(f"Integral around {point:.2f} Hz: {integral:.5e}")

        # Outer edges of the flanking bands used for the local mean. A band can
        # be empty (e.g. near 0 Hz or the Nyquist frequency), so only draw a
        # line when it has samples instead of indexing into an empty array.
        left_band = frequency[(frequency >= point - 5 * delta) & (frequency < point - delta)]
        right_band = frequency[(frequency > point + delta) & (frequency <= point + 5 * delta)]
        if left_band.size:
            ax.axvline(x=left_band[0], color="k", linestyle="--")
        if right_band.size:
            ax.axvline(x=right_band[-1], color="k", linestyle="--")

    ax.set_xlim(list(xlim))
    ax.set_xlabel('Frequency (Hz)')
    ax.set_ylabel('Power')
    ax.set_title('Power Spectrum with Highlighted Integrals')
    # Avoid matplotlib's "No artists with labels" warning when nothing was highlighted.
    handles, _ = ax.get_legend_handles_labels()
    if handles:
        ax.legend()

    return fig


### Data retrieval ###
# Guarded so that importing this module (to reuse the functions above) does
# not load data or open figures. Running the file as a script still defines
# all variables below at module level.
if __name__ == "__main__":
    datafolder = "../data"
    example_file = os.path.join(datafolder, "2024-10-16-ad-invivo-1.nix")
    dataset = rlx.Dataset(example_file)
    sams = dataset.repro_runs("SAM")
    sam = sams[2]

    ## Data for functions
    df = sam.metadata["RePro-Info"]["settings"]["deltaf"][0][0]
    stim = sam.stimuli[1]
    potential, time = stim.trace_data("V-1")
    spikes, _ = stim.trace_data("Spikes-1")
    duration = stim.duration
    dt = stim.trace_info("V-1").sampling_interval

    ### Apply Functions to calculate data ###
    b = binary_spikes(spikes, duration, dt)
    rate = firing_rate(b, box_width=0.05, dt=dt)
    # The spectrum is computed from the binary spike train `b`, not from `rate`.
    # The 50-ms box kernel behind `rate` is a low-pass filter (sinc response with
    # zeros at multiples of 20 Hz) and would suppress high-frequency peaks.
    frequency, power = powerspectrum(b, dt)

    ### Important stuff ###
    ## Frequencies
    eodf = stim.metadata[stim.name]["EODf"][0][0]
    stimulus_frequency = eodf + df
    AM = 50  # Hz; NOTE(review): hard-coded -- confirm it matches this recording's AM frequency
    frequencies = [AM, eodf, stimulus_frequency]

    categories = ["AM", "EODf", "Stimulus frequency"]
    num_harmonics = [4, 2, 2]
    colors = ["green", "orange", "red"]

    delta = 2.5  # Hz, half-width of the integration band around each harmonic
    threshold = 10  # required factor of peak integral over the flanking-band mean

    ### Apply functions to make powerspectrum ###
    integral, local = calculate_integral(frequency, power, eodf, delta)
    _, eodf_message = valid_integrals(integral, local, threshold, eodf)
    print(eodf_message)
    # Keep the category list intact; the per-category harmonics get their own name.
    points, color_mapping, points_categories = prepare_harmonics(frequencies, categories, num_harmonics, colors)
    exceeding = find_exceeding_points(frequency, power, points, delta, threshold)
    print(f"{len(exceeding)} of {len(points)} harmonic points exceed the threshold")

    ## Plot power spectrum and highlight integrals
    fig = plot_highlighted_integrals(frequency, power, exceeding, delta, threshold, color_mapping, points_categories)
    plt.show()