import glob
import matplotlib.pyplot as plt
import numpy as np
import os
import rlxnix as rlx
import useful_functions as f
from matplotlib.lines import Line2D
from tqdm import tqdm

# plot the tuning curves for all cells y/n
single_plots = True


def _first_match(pattern):
    """Return the alphabetically first path matching ``pattern``.

    Raises FileNotFoundError (instead of a bare IndexError) when nothing
    matches, so a missing EOD-only recording is reported clearly.
    """
    matches = sorted(glob.glob(pattern))
    if not matches:
        raise FileNotFoundError(f"no file matches {pattern!r}")
    return matches[0]


# all recordings we want to use; sorted because glob returns an arbitrary
# order and the code below derives colours, labels and the recording-day
# checks from the position of each file
files = sorted(glob.glob("../data/2024-10-*.nix"))

# EOD-only reference recording for each recording day (16th and 21st)
eodf_file_w = _first_match('../data/EOD_only/*-16*.nix')
eodf_file_m = _first_match('../data/EOD_only/*-21*.nix')

# keep only the good and fair recordings
new_files = f.remove_poor(files)

# file names without directory or extension, used as plot labels
labels = [os.path.splitext(os.path.basename(file))[0] for file in new_files]

# per-contrast (20 %, 10 %, 5 %) collections of stimulus frequencies and peak
# powers, filled in the per-file loop below; the "norm" variant stores the
# frequencies relative to the awake-fish EODf. Each contrast gets its own
# fresh lists.
contrast_files = {c: {'power': [], 'freq': []} for c in (20, 10, 5)}
norm_contrast_files = {c: {'power': [], 'freq': []} for c in (20, 10, 5)}

# loop over all the good files
for u, file in tqdm(enumerate(new_files), total=len(new_files)):
    # use the EOD-only reference recording from the same day as this cell
    orig_eodf = f.true_eodf(eodf_file_w if "-16" in file else eodf_file_m)

    # load the recording and group its SAM runs by stimulus contrast
    dataset = rlx.Dataset(file)
    sams = dataset.repro_runs('SAM')
    contrast_sams = f.contrast_sorting(sams)

    # (contrast bucket, stimulus frequencies, peak powers) of this file, in
    # the order in which contrast_sorting returned the contrasts
    file_curves = []
    eodfs = []
    for key in contrast_sams:
        key_sams = contrast_sams[key]
        stim_frequencies = np.zeros(len(key_sams))
        peak_powers = np.zeros_like(stim_frequencies)

        for i, sam in enumerate(key_sams):
            # stimulus frequency and the EODf reported for this SAM run
            _, _, _, _, eodf, _, stim_frequency = f.sam_data(sam)
            sam_frequency, sam_power = f.sam_spectrum(sam)
            # power of the spectrum at the stimulus frequency
            _, _, peak_powers[i] = f.calculate_integral(
                sam_frequency, sam_power, stim_frequency)
            stim_frequencies[i] = stim_frequency
            eodfs.append(eodf)

        # zeros are treated as missing values; NaN leaves gaps in the curves
        peak_powers[peak_powers == 0] = np.nan
        # stimulus frequencies relative to the awake fish's EODf
        norm_stim_frequencies = stim_frequencies - orig_eodf

        # 20 % and 10 % have their own bucket, every other contrast goes to 5 %
        bucket = 20 if key == 20 else 10 if key == 10 else 5
        contrast_files[bucket]['freq'].append(stim_frequencies)
        contrast_files[bucket]['power'].append(peak_powers)
        norm_contrast_files[bucket]['freq'].append(norm_stim_frequencies)
        norm_contrast_files[bucket]['power'].append(peak_powers)
        file_curves.append((bucket, stim_frequencies, peak_powers))

    if not single_plots:
        continue

    # mean EODf over all SAM runs of this cell
    curr_eodf = np.mean(eodfs)

    # --- one cell, all contrasts in one axis ---
    fig, ax = plt.subplots()
    for idx, (bucket, freqs, powers) in enumerate(file_curves):
        # skip empty contrasts, but keep each contrast's colour tied to its
        # position so colours stay comparable between cells
        if freqs.size:
            ax.plot(freqs, powers, color=f'C{idx}',
                    label=f'{bucket} % contrast')
    # x range covers every contrast that has data; 2000 Hz if none has
    x_max = max((np.max(freqs) for _, freqs, _ in file_curves if freqs.size),
                default=2000)
    ax.set_xlim(0, x_max)
    ax.axvline(orig_eodf, color='black', linestyle='dashed', alpha=0.8,
               label='EODf of awake fish')
    ax.axvline(2 * curr_eodf, color='black', linestyle='dotted', alpha=0.8,
               label='1st harmonic of current EODf')
    ax.set_ylim(0, 0.00014)
    ax.set_xlabel('stimulus frequency [Hz]')
    ax.set_ylabel(r' power [$\frac{\mathrm{mV^2}}{\mathrm{Hz}}$]')
    ax.set_title(f"{file}")
    # legend built from the artists' own labels, so a missing contrast can
    # no longer shift the remaining labels onto the wrong lines
    fig.legend(loc='lower center', ncol=3)
    fig.tight_layout(rect=[0, 0.06, 1, 1])
    fig.savefig(f'../results/tuning_curve{labels[u]}.svg')
    plt.close(fig)  # avoid accumulating two open figures per file

    # --- one cell, one subplot per contrast ---
    # only this file's data is plotted; a contrast it lacks stays empty
    curves_by_contrast = {bucket: (freqs, powers)
                          for bucket, freqs, powers in file_curves}
    fig, axs = plt.subplots(1, 3, figsize=[10, 6], sharex=True, sharey=True)
    for ax, key in zip(axs, contrast_files):
        if key in curves_by_contrast:
            ax.plot(*curves_by_contrast[key], label='power of stimulus peak')
        ax.set_title(f"{key}")
        ax.axvline(orig_eodf, color='black', linestyle='dashed',
                   label='EODf of awake fish')
        ax.axvline(2 * curr_eodf, color='darkblue', linestyle='dotted',
                   alpha=0.8, label='1st harmonic of current EODf')
    axs[0].set_ylabel(r'power [$\frac{\mathrm{mV^2}}{\mathrm{Hz}}$]', fontsize=12)
    fig.supxlabel('stimulus frequency [Hz]', fontsize=12)
    fig.suptitle(f'{labels[u]}')
    # every subplot carries the same labels; show each entry only once
    legend_entries = {}
    for ax in axs:
        for handle, text in zip(*ax.get_legend_handles_labels()):
            legend_entries.setdefault(text, handle)
    fig.legend(list(legend_entries.values()), list(legend_entries),
               loc='lower center', bbox_to_anchor=(0.5, 0.05), ncol=3)
    fig.tight_layout(rect=[0, 0.06, 1, 1])
    fig.savefig(f'../results/contrast_tuning{labels[u]}.svg')
    plt.close(fig)
        
def _plot_contrast_grid(curves, eodf_x, xlabel, top):
    """Draw every cell's tuning curve, one subplot per contrast.

    Parameters
    ----------
    curves : dict
        Maps contrast -> {'power': [...], 'freq': [...]}. Entry ``i`` of each
        list is drawn with ``labels[i]`` and ``colors[i]``. NOTE(review):
        assumes every file contributed exactly one entry per contrast so the
        indices line up with the file labels; confirm against contrast_sorting.
    eodf_x : float
        x position of the dashed "awake fish EODf" marker.
    xlabel : str
        Shared x-axis label.
    top : float
        Upper edge of the tight_layout rect, leaving room for the legend.

    Returns
    -------
    fig, axs
        The new figure and its three axes.
    """
    fig, axs = plt.subplots(1, 3, figsize=[10, 6], sharex=True, sharey=True)
    legend_handles, legend_labels = [], []
    for p, (ax, key) in enumerate(zip(axs, curves)):
        pairs = zip(curves[key]['freq'], curves[key]['power'])
        for i, (freq, power) in enumerate(pairs):
            line, = ax.plot(freq, power, label=labels[i], color=colors[i])
            # one legend entry per file, taken from the first subplot
            if p == 0:
                legend_handles.append(line)
                legend_labels.append(labels[i])
        ax.set_title(f"{key}")
        ax.axvline(eodf_x, color='black', linestyle='dashed')
    fig.supxlabel(xlabel, fontsize=12)
    fig.supylabel(r'power [$\frac{\mathrm{mV^2}}{\mathrm{Hz}}$]', fontsize=12)

    # single legend above the plots with 3 columns, plus a proxy entry for
    # the dashed EODf marker
    legend_handles.append(Line2D([0], [0], color='black', linestyle='--'))
    legend_labels.append("Awake fish EODf")
    fig.legend(legend_handles, legend_labels, loc='upper center', ncol=3,
               fontsize=10)
    fig.tight_layout(rect=[0, 0, 1, top])  # make space for the legend
    return fig, axs


cmap = plt.get_cmap('viridis')
colors = cmap(np.linspace(0, 1, len(new_files)))
plt.close('all')

if not new_files:
    print('no good or fair recordings found; skipping the summary plots')
elif len(new_files) < 10:
    # NOTE(review): orig_eodf is the value left over from the last file of the
    # per-file loop, so this assumes all files come from one recording day.
    fig, _ = _plot_contrast_grid(contrast_files, orig_eodf,
                                 'stimulus frequency [Hz]', 0.85)
    if "-16" in new_files[-1]:
        fig.savefig('../results/tuning_curves_10_16.svg')
    elif "-21" in new_files[0]:
        fig.savefig('../results/tuning_curves_10_21.svg')
    else:
        # neither day matched: save under a generic name instead of
        # silently dropping the figure
        fig.savefig('../results/tuning_curves.svg')
else:
    # frequencies are relative to the awake-fish EODf of each cell's
    # recording day, so the EODf marker sits at 0
    fig, axs = _plot_contrast_grid(norm_contrast_files, 0,
                                   'relative stimulus frequency [Hz]', 0.82)
    # sharex=True: the limits set on one axis apply to all three
    axs[0].set_xlim(-600, 2100)
    fig.savefig('../results/tuning_curves_norm.svg')
    axs[0].set_xlim(-600, 600)
    fig.savefig('../results/tuning_curves_norm_zoom.svg')
#plt.close('all')