'''This script contains all functions for various plots that could be relevant 
for the presentation or protocol of the Grewe GP 2024'''

import os
import matplotlib.pyplot as plt
import numpy as np
import rlxnix as rlx
from useful_functions import power_spectrum
import sys

# --- IMPORT DATA ---
# Paths are relative to the current working directory:
# ./ = current directory, ../ = one level up, ../../ = two levels up.
# NOTE(review): these relative paths only resolve when the script is started
# from the expected directory — confirm the intended working directory.
datafolder = '../data'
example_file = os.path.join(datafolder, '2024-10-16-ac-invivo-1.nix')

# --- EXTRACT DATA ---
dataset = rlx.Dataset(example_file)

# Get all SAM repro runs and select the third one (index 2).
sams = dataset.repro_runs('SAM')
sam = sams[2]

# Take the last stimulus of the selected run, and the run's stimulus count.
stimulus = sam.stimuli[-1]
stim_count = sam.stimulus_count

# --- PLOTS ---
# One colour per stimulus, sampled evenly across the 'prism' colormap.
colors = plt.cm.prism(np.linspace(0, 1, stim_count))

# Plot the timeline of the whole recording.
dataset.plot_timeline()

# plot voltage over time for whole trace
def plot_vt_spikes(t, v, spike_t, t_max=0.1):
    """Plot the voltage trace and mark spike times, restricted to t < t_max.

    Parameters
    ----------
    t : array_like
        Time points of the voltage trace (same units as ``t_max`` and ``spike_t``).
    v : array_like
        Voltage values, one per entry of ``t``.
    spike_t : array_like
        Spike times.
    t_max : float, optional
        Only samples and spikes with time < ``t_max`` are shown (default 0.1).
    """
    t = np.asarray(t)
    v = np.asarray(v)
    spike_t = np.asarray(spike_t)

    in_window = t < t_max
    t_win = t[in_window]
    v_win = v[in_window]
    spikes_win = spike_t[spike_t < t_max]

    fig, ax = plt.subplots(figsize=(5, 2.5))
    # Voltage over time for the selected window.
    ax.plot(t_win, v_win)
    if v_win.size and spikes_win.size:
        # Mark spikes at the maximum of the *visible* trace so that a large
        # voltage outside the window does not stretch the y-axis.
        ax.scatter(spikes_win, np.full(spikes_win.shape, v_win.max()))
    plt.show()

# plot scatter plot for one sam with all 3 stims
def scatter_plot(colormap, stimuli_list, stimulus_count, title='Spikes of SAM 3'):
    """Draw an event (raster) plot with one row of spike times per loop.

    Parameters
    ----------
    colormap : sequence of colours
        Colours passed to ``Axes.eventplot``; presumably one per row of
        ``stimuli_list`` — TODO confirm.
    stimuli_list : sequence of array_like
        Spike times, one sequence per stimulus loop.
    stimulus_count : int
        Number of loops; used for the y-axis ticks.
    title : str, optional
        Plot title (default ``'Spikes of SAM 3'``).
    """
    fig, ax = plt.subplots()

    ax.eventplot(stimuli_list, colors=colormap)
    # NOTE(review): confirm the spike times really are in ms.
    ax.set_xlabel('Spike Times [ms]')
    ax.set_ylabel('Loop #')
    ax.set_yticks(range(stimulus_count))
    ax.set_title(title)
    plt.show()
    
# Compute the power spectrum of the selected stimulus with
# useful_functions.power_spectrum; the results are stored in the module-level
# names `freq` and `power`.
# NOTE(review): confirm what power_spectrum expects as input (the rlxnix
# stimulus object itself vs. its data arrays).
freq, power = power_spectrum(stimulus)

# plot power spectrum
def power_spectrum_plot(f, p, f_max=1000):
    """Plot a power spectrum.

    Parameters
    ----------
    f : array_like
        Frequencies in Hz.
    p : array_like
        Power values, one per entry of ``f``.
    f_max : float, optional
        Upper limit of the x-axis in Hz (default 1000).
    """
    fig, ax = plt.subplots()
    # Plot the arguments, not module-level globals.
    ax.plot(f, p)
    ax.set_xlabel('Frequency [Hz]')
    ax.set_ylabel('Power [1/Hz]')
    ax.set_xlim(0, f_max)
    plt.show()
    
# Diana's power spectrum plot
# Machine-specific location of useful_functions. It is only added to the
# import path if the directory exists (and isn't already there), so the
# script does not pollute sys.path on other machines.
# `useful_functions` is already imported at the top of this file, so the
# import below reuses that module.
functions_path = r"C:\Users\diana\OneDrive - UT Cloud\Master\GPs\GP1_Grewe\Projekt\gpgrewe2024\code"
if os.path.isdir(functions_path) and functions_path not in sys.path:
    sys.path.append(functions_path)
import useful_functions as u

def plot_highlighted_integrals(frequency, power, points, color_mapping, points_categories, delta=2.5):
    """
    Plot the power spectrum and highlight the regions around valid points.

    For every frequency in ``points``, the integral and local mean are obtained
    from ``u.calculate_integral``. If ``u.valid_integrals`` accepts them, three
    things are drawn:

    * the band ``[point - delta, point + delta]`` is shaded in the colour of
      the point's category;
    * the integral is printed;
    * dashed lines mark the outermost samples of the adjacent regions, which
      reach up to ``5 * delta`` from the point.

    Parameters
    ----------
    frequency : np.array
        An array of frequencies corresponding to the power values.
    power : np.array
        An array of power spectral density values.
    points : list
        Frequencies (e.g. harmonics) to check and, if valid, highlight.
    color_mapping : dict
        A dictionary mapping each category to its color. Points that belong to
        no category are drawn in gray.
    points_categories : dict
        A mapping of categories to lists of points. Categories present in
        ``color_mapping`` but missing here are treated as empty.
    delta : float, optional
        Half-width of the shaded band around each point (default 2.5).
        NOTE(review): ``delta`` is not passed to ``u.calculate_integral`` —
        confirm that helper integrates over the same half-width.

    Returns
    -------
    fig : matplotlib.figure.Figure
        The created figure object with highlighted integrals.
    """
    frequency = np.asarray(frequency)
    fig, ax = plt.subplots()
    ax.plot(frequency, power)  # Plot power spectrum

    for point in points:
        integral, local_mean, _ = u.calculate_integral(frequency, power, point)
        if not u.valid_integrals(integral, local_mean, point):
            continue

        # Colour of the first category containing this point, else gray.
        color = next(
            (c for cat, c in color_mapping.items() if point in points_categories.get(cat, ())),
            'gray',
        )
        # Shade the region around the point where the integral was calculated.
        ax.axvspan(point - delta, point + delta, color=color, alpha=0.3, label=f'{point:.2f} Hz')
        print(f"Integral around {point:.2f} Hz: {integral:.5e}")

        # Samples of the adjacent regions on either side of the shaded band.
        left_region = frequency[(frequency >= point - 5 * delta) & (frequency < point - delta)]
        right_region = frequency[(frequency > point + delta) & (frequency <= point + 5 * delta)]

        # Dashed lines at the outer boundaries; a region can be empty near the
        # edges of the spectrum or at coarse frequency resolution.
        if left_region.size:
            ax.axvline(x=left_region[0], color="k", linestyle="--")
        if right_region.size:
            ax.axvline(x=right_region[-1], color="k", linestyle="--")

    ax.set_xlim([0, 1200])
    ax.set_xlabel('Frequency (Hz)')
    ax.set_ylabel('Power')
    ax.set_title('Power Spectrum with Highlighted Integrals')
    # Only draw a legend if something was highlighted (avoids an empty-legend warning).
    if ax.get_legend_handles_labels()[0]:
        ax.legend()

    return fig