import numpy as np
import matplotlib.pyplot as plt
from IPython import embed
import helper_functions as hf
import do_check_for_overlap as cfo
from params import *
import matplotlib.colors as mcolors
import os

if __name__ == '__main__':
    # Script: load the tracking results of one recording and produce several diagnostic plots
    # (oofl traces, overlapping fish, faifl fish pairs, all traces of one channel over time).

    # `datetime` is used below; import it explicitly instead of relying on `from params import *`
    # to re-export it.
    import datetime

    ###################################################################################################################
    # load data
    ###################################################################################################################
    # Load all the data of one recording. The recording is chosen by its position (index 2) in the
    # sorted listing of ../data/, so the result depends on that directory's contents.
    filename = sorted(os.listdir('../data/'))[2]
    raw_dir = os.path.join('../../../data_masterthesis', filename)
    proc_dir = os.path.join('../data', filename)

    ident = np.load(os.path.join(raw_dir, 'all_ident_v.npy'), allow_pickle=True)
    freq = np.load(os.path.join(raw_dir, 'all_fund_v.npy'), allow_pickle=True)
    timeidx = np.load(os.path.join(raw_dir, 'all_idx_v.npy'), allow_pickle=True)
    t = np.load(os.path.join(raw_dir, 'all_times.npy'), allow_pickle=True)

    aifl = np.load(os.path.join(proc_dir, 'aifl2.npy'), allow_pickle=True)
    faifl = np.load(os.path.join(proc_dir, 'faifl2.npy'), allow_pickle=True)
    oofl = np.load(os.path.join(proc_dir, 'oofl.npy'), allow_pickle=True)
    # Drop the first row of faifl.
    # NOTE(review): the reason is not visible here. Presumably it is a placeholder row; confirm
    # against the code that writes faifl2.npy.
    faifl = np.delete(faifl, [0], axis=0)

    # aifl is indexed as aifl[channel, fish_number, k]. Collect the fish numbers whose first
    # entry (k == 0) is not NaN on at least one channel.
    fish_in_aifl = list(np.unique(np.where(~np.isnan(aifl[:, :, 0]))[1]))
    # Columns 1 and 3 of faifl hold fish numbers (see the faifl plotting section below).
    fish_in_faifl = list(np.unique(faifl[:, [1, 3]]))
    # Fish that appear in aifl but are not referenced by faifl.
    correct_aifl = sorted(set(fish_in_aifl) - set(fish_in_faifl))

    # Recording start time, parsed from the last five characters of the folder name ("HH_MM").
    # The date part defaults to 1900-01-01; only the time of day is used for the time axis.
    dt = datetime.datetime.strptime(filename[-5:], '%H_%M')

    # Set to True to drop into an IPython shell after loading and stop there. This used to be
    # unconditional, which made every plotting section below unreachable.
    debug_shell = False
    if debug_shell:
        embed()
        quit()

    ###################################################################################################################
    # plot traces in oofl
    # Each oofl row is (channel, ident). Only traces spanning at least `min_duration` timeidx
    # units are plotted. 0 plots everything; the original code noted 4800 as an alternative.
    min_duration = 0
    oofl = np.asarray(oofl)
    for row in oofl:
        channel = int(row[0])
        mask = ident[channel] == row[1]
        trace_idx = timeidx[channel][mask]
        trace_freq = freq[channel][mask]
        if trace_idx[-1] - trace_idx[0] >= min_duration:
            plt.plot(trace_idx, trace_freq, linewidth=2)
            # A small random offset keeps labels of traces that start at the same point from
            # overlapping. Scalar draws consume the same RNG stream as the former rand(1).
            plt.text(trace_idx[0] + np.random.rand() * 0.3,
                     trace_freq[0] + np.random.rand() * 0.3,
                     str(row[0]) + '_' + str(row[1]), color='blue')
    plt.show()

    ###################################################################################################################
    # plot overlapping traces
    new_sorting = cfo.get_list_of_fishN_with_overlap(aifl, fish_in_aifl, timeidx, ident)
    for fish_number in new_sorting:
        hf.plot_together(timeidx, freq, ident, aifl, int(fish_number), color_vec[0])

    ###################################################################################################################
    # plot fish in faifl: each row names two fish (columns 1 and 3) that are plotted together
    for pair in faifl:
        hf.plot_all_channels(timeidx, freq, ident, aifl, int(pair[1]), fishN2=int(pair[3]))

    ###################################################################################################################
    # plot all traces of the selected channel(s) against wall-clock time

    channels_to_plot = [13]
    fig, ax = plt.subplots(1, 1, figsize=(15 / inch, 8 / inch))
    fig.subplots_adjust(left=0.12, bottom=0.15, right=0.98, top=0.98)
    for color_counter, fish_number in enumerate(fish_in_aifl):
        for channel_idx in channels_to_plot:
            # All non-NaN idents assigned to this fish on this channel.
            fish1 = aifl[channel_idx, fish_number, ~np.isnan(aifl[channel_idx, fish_number])]
            print(fish1)
            for fish_ident in fish1:
                mask = ident[channel_idx] == fish_ident
                zeit = t[timeidx[channel_idx][mask]]
                # Convert seconds since recording start into datetimes for the time axis.
                plot_zeit = [dt + datetime.timedelta(seconds=s) for s in zeit]
                ax.plot(plot_zeit, freq[channel_idx][mask],
                        linewidth=1, label=fish_number, color=color_vec[color_counter + 40])

    ax.set_ylim([450, 1000])
    ax.set_xlabel('Time', fontsize=fs)
    ax.set_ylabel('EOD frequency [Hz]', fontsize=fs)
    # NOTE(review): make_nice_ax / timeaxis are not standard matplotlib Axes methods. They are
    # presumably added by a helper module brought in via `from params import *`; confirm.
    ax.make_nice_ax()
    ax.timeaxis()
    fig.savefig(save_path_pres + 'EOD_sorter.pdf')
    fig.savefig('../../thesis/Figures/Methods/EOD_sorter.pdf')
    plt.show()

    ###################################################################################################################
    # plot every fish referenced in faifl, one color per fish
    for color_idx, fish_id in enumerate(np.unique(faifl[:, [1, 3]])):
        hf.plot_together(timeidx, freq, ident, aifl, int(fish_id), color_vec[color_idx])