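'''Plotting functions for figures that could be relevant for the presentation
or protocol of the Grewe GP 2024. Note that importing this module also loads an
example recording, computes a power spectrum and plots the recording timeline.'''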

import os
import matplotlib.pyplot as plt
import numpy as np
import rlxnix as rlx
from useful_functions import power_spectrum

# --- IMPORT DATA ---
# Relative paths are resolved against the current working directory:
# './' is the current directory, '../' one level up, '../../' two levels up.
datafolder = '../data'
# Build the file path from `datafolder` so the data location is defined once.
example_file = os.path.join(datafolder, '2024-10-16-ac-invivo-1.nix')

# --- EXTRACT DATA ---
dataset = rlx.Dataset(example_file)

# All SAM repro runs of the recording; the third one (index 2) is analysed.
sams = dataset.repro_runs('SAM')
sam = sams[2]

# Only the last stimulus of the selected SAM run is kept, plus the stimulus
# count reported by rlxnix for that run.
stimulus = sam.stimuli[-1]
stim_count = sam.stimulus_count

# --- PLOTS ---
# One color per stimulus, sampled evenly from the 'prism' colormap.
colors = plt.cm.prism(np.linspace(0, 1, stim_count))

# Plot the timeline of the whole recording (executed at import time).
dataset.plot_timeline()

# Plot voltage over time for the start of a trace (t < t_max) and mark spikes.
def plot_vt_spikes(t, v, spike_t, t_max=0.1):
    """Plot voltage against time for the beginning of a trace and mark spikes.

    Parameters
    ----------
    t : array_like
        Sample times of the voltage trace.
    v : array_like
        Voltage samples, one per entry of ``t``.
    spike_t : array_like
        Spike times, in the same unit as ``t``.
    t_max : float, optional
        Only samples and spikes with time < ``t_max`` are shown. Defaults to
        0.1, the window that was previously hard-coded.

    Returns
    -------
    fig : matplotlib.figure.Figure
        The created figure.
    """
    t = np.asarray(t)
    v = np.asarray(v)
    spike_t = np.asarray(spike_t)

    # Restrict trace and spikes to the display window (masks computed once).
    in_window = t < t_max
    spikes_in_window = spike_t[spike_t < t_max]

    fig = plt.figure(figsize=(5, 2.5))
    ax = fig.add_subplot()
    ax.plot(t[in_window], v[in_window])
    # Mark every spike at the maximum voltage of the *whole* trace, so the
    # markers sit at or above the plotted voltage.
    ax.scatter(spikes_in_window, np.ones_like(spikes_in_window) * np.max(v))
    plt.show()
    return fig

# Raster (event) plot of the spike trains of one SAM run, one row per stimulus.
def scatter_plot(colormap, stimuli_list, stimulus_count, title='Spikes of SAM 3'):
    """Draw a spike raster with one row per stimulus presentation.

    Parameters
    ----------
    colormap : sequence of colors
        Colors for the rows, passed to ``Axes.eventplot(colors=...)``
        (e.g. an RGBA array sampled from a colormap).
    stimuli_list : sequence of array_like
        Spike times for each stimulus presentation, drawn in consecutive rows.
    stimulus_count : int
        Number of rows; used to place the y ticks at 0 .. stimulus_count - 1.
    title : str, optional
        Axes title. Defaults to the previously hard-coded 'Spikes of SAM 3'.

    Returns
    -------
    fig : matplotlib.figure.Figure
        The created figure.
    """
    fig = plt.figure()
    ax = fig.add_subplot()
    ax.eventplot(stimuli_list, colors=colormap)
    # NOTE(review): the label claims milliseconds; confirm the unit of the
    # spike times passed in.
    ax.set_xlabel('Spike Times [ms]')
    ax.set_ylabel('Loop #')
    ax.set_yticks(range(stimulus_count))
    ax.set_title(title)
    plt.show()
    return fig
    
# Power spectrum of the selected SAM stimulus, computed at import time with
# useful_functions.power_spectrum; presumably it returns (frequencies, power)
# arrays — confirm against that helper's definition.
# NOTE(review): this passes the rlxnix stimulus object itself; confirm that
# power_spectrum accepts it rather than the underlying trace data.
freq, power = power_spectrum(stimulus)

# Plot a power spectrum (power against frequency).
def power_spectrum_plot(f, p, f_max=1000):
    """Plot power spectral density against frequency.

    Parameters
    ----------
    f : array_like
        Frequencies in Hz.
    p : array_like
        Power at each frequency in ``f``.
    f_max : float, optional
        Upper x-axis limit in Hz. Defaults to 1000, the previously
        hard-coded value.

    Returns
    -------
    fig : matplotlib.figure.Figure
        The created figure.
    """
    fig = plt.figure()
    ax = fig.add_subplot()
    # Plot the arguments; the module-level `freq`/`power` globals must not be
    # used here, otherwise the passed-in spectrum is silently ignored.
    ax.plot(f, p)
    ax.set_xlabel('Frequency [Hz]')
    ax.set_ylabel('Power [1/Hz]')
    ax.set_xlim(0, f_max)
    plt.show()
    return fig
    
####### ADD DIANAS POWER SPECTRUM PLOT
def _adjacent_region_bounds(frequency, point, delta):
    """Return the outer frequency bins of the two regions flanking ``point``.

    The left bound is the first bin in [point - 5*delta, point - delta) and the
    right bound is the last bin in (point + delta, point + 5*delta]. Either
    value is None when no frequency bin falls inside that region.
    """
    left = frequency[(frequency >= point - 5 * delta) & (frequency < point - delta)]
    right = frequency[(frequency > point + delta) & (frequency <= point + 5 * delta)]
    left_bound = left[0] if left.size else None
    right_bound = right[-1] if right.size else None
    return left_bound, right_bound


def plot_highlighted_integrals(frequency, power, exceeding_points, delta, threshold, color_mapping, points_categories):
    """
    Plot the power spectrum and highlight integrals that exceed the threshold.

    For every valid point the band [point - delta, point + delta] is shaded in
    its category color, the integral is printed, and dashed lines mark the
    outer edges of the flanking regions (up to 5 * delta away). A flanking
    edge is skipped when no frequency bin falls inside that region.

    Parameters
    ----------
    frequency : np.array
        An array of frequencies corresponding to the power values.
    power : np.array
        An array of power spectral density values.
    exceeding_points : list
        A list of harmonic frequencies that exceed the threshold.
    delta : float
        Half-width of the range for integration around each point.
    threshold : float
        Threshold value to compare integrals with local mean.
    color_mapping : dict
        A dictionary mapping each category to its color.
    points_categories : dict
        A mapping of categories to lists of points. Categories that appear in
        ``color_mapping`` but not here are treated as empty.

    Returns
    -------
    fig : matplotlib.figure.Figure
        The created figure object with highlighted integrals.
    """
    frequency = np.asarray(frequency)
    fig, ax = plt.subplots()
    ax.plot(frequency, power)  # Plot power spectrum

    for point in exceeding_points:
        # NOTE(review): calculate_integral and valid_integrals are not defined
        # or imported in the visible part of this file; confirm they are in
        # scope before this function is called.
        integral, local_mean = calculate_integral(frequency, power, point, delta)
        valid, _ = valid_integrals(integral, local_mean, threshold, point)
        if not valid:
            continue

        # Color of the first category (in color_mapping order) whose points
        # contain this point; gray if none does.
        color = next(
            (c for cat, c in color_mapping.items() if point in points_categories.get(cat, ())),
            'gray',
        )
        # Shade the region around the point where the integral was calculated
        ax.axvspan(point - delta, point + delta, color=color, alpha=0.3, label=f'{point:.2f} Hz')
        print(f"Integral around {point:.2f} Hz: {integral:.5e}")

        # Dashed lines at the outer boundaries of the adjacent regions; a
        # boundary is omitted when its region contains no frequency bin.
        left_boundary, right_boundary = _adjacent_region_bounds(frequency, point, delta)
        for boundary in (left_boundary, right_boundary):
            if boundary is not None:
                ax.axvline(x=boundary, color="k", linestyle="--")

    ax.set_xlim([0, 1200])
    ax.set_xlabel('Frequency (Hz)')
    ax.set_ylabel('Power')
    ax.set_title('Power Spectrum with Highlighted Integrals')
    # Only draw a legend when something was labelled; otherwise matplotlib
    # warns that no labelled artists were found.
    handles, _ = ax.get_legend_handles_labels()
    if handles:
        ax.legend()

    return fig