import numpy as np
import matplotlib.pyplot as plt
from IPython import embed
import helper_functions as hf

def overlap_case(t00, t0n, t10, t1n):
    """Classify how the intervals [t00, t0n] and [t10, t1n] overlap.

    The cases are tested in order, so ties (e.g. identical intervals) resolve
    to the lowest matching case number.

    Parameters
    ----------
    t00, t0n : float
        Start and end of the first interval.
    t10, t1n : float
        Start and end of the second interval.

    Returns
    -------
    int or None
        1 if the first interval encloses the second, 2 if the second encloses
        the first, 3 if they overlap partially with the first starting
        earlier, 4 if they overlap partially with the second starting earlier,
        and None if none of these cases applies (no overlap).
    """
    if t00 <= t10 and t00 <= t1n and t0n >= t1n:
        return 1
    if t10 <= t00 and t10 <= t0n and t1n >= t0n:
        return 2
    if t00 <= t10 <= t0n <= t1n:
        return 3
    if t10 <= t00 <= t1n <= t0n:
        return 4
    return None


def median_near(times, values, t_ref, tolerance):
    """Return the median of `values` sampled close to `t_ref`.

    Samples are selected with ``np.isclose(times, t_ref, atol=tolerance)``;
    note that the default relative tolerance of ``np.isclose`` still applies
    on top of `tolerance`.

    Parameters
    ----------
    times : array_like
        Sample times, same length as `values`.
    values : array_like
        Sampled values (here: frequencies).
    t_ref : float
        Reference time around which samples are collected.
    tolerance : float
        Absolute tolerance around `t_ref`.

    Returns
    -------
    float
        Median of the selected values, or NaN if no sample was selected.
    """
    selected = np.asarray(values)[np.isclose(np.asarray(times), t_ref, atol=tolerance)]
    if selected.size == 0:
        # np.median would also give NaN here, but only after emitting a
        # RuntimeWarning; return it explicitly instead.
        return np.nan
    return np.median(selected)


def boundary_freq_diff(t0, f0, t1, f1, window, tolerance):
    """Absolute frequency difference of two traces at a window boundary.

    The window boundaries are tried in order. The first boundary at which BOTH
    traces provide a median frequency decides the result; later boundaries
    only serve as a fallback when an estimate is missing.

    Parameters
    ----------
    t0, f0 : array_like
        Times and frequencies of the first trace.
    t1, f1 : array_like
        Times and frequencies of the second trace.
    window : sequence of float
        Boundary times to probe, in order of preference (start, end).
    tolerance : float
        Absolute time tolerance passed on to `median_near`.

    Returns
    -------
    float
        ``abs(median0 - median1)`` at the first usable boundary, or NaN if no
        boundary has data from both traces.
    """
    for t_ref in window:
        median0 = median_near(t0, f0, t_ref, tolerance)
        median1 = median_near(t1, f1, t_ref, tolerance)
        # A median is a scalar, so testing `.size` can never detect a missing
        # estimate; NaN is the only reliable marker for "no samples".
        if not (np.isnan(median0) or np.isnan(median1)):
            return abs(median0 - median1)
    return np.nan


if __name__ == '__main__':

    ###################################################################################################################
    # load data
    ###################################################################################################################
    # load all the data of one day
    # NOTE(review): ident/freq/timeidx are indexed by channel below and are
    # presumably one array per channel -- confirm against the files.
    ident = np.load('../../../data/2019-10-07-18_28/all_ident_v.npy', allow_pickle=True)
    # power = np.load('../../../data/2019-10-07-18_28/all_sign_v.npy', allow_pickle=True)
    freq = np.load('../../../data/2019-10-07-18_28/all_fund_v.npy', allow_pickle=True)
    timeidx = np.load('../../../data/2019-10-07-18_28/all_idx_v.npy', allow_pickle=True)

    aifl = np.load('../data/aifl.npy', allow_pickle=True)
    # rows of oofl are used as (channel, identity) pairs
    oofl = np.load('../data/oofl2.npy', allow_pickle=True)
    faifl = np.load('../data/faifl.npy', allow_pickle=True)
    faifl = np.delete(faifl, [0], axis=0)  # drop the first row of faifl

    # unique, non-NaN identities of every channel
    uni_ident_list = [np.unique(ch_ident[~np.isnan(ch_ident)]) for ch_ident in ident]

    # parameters
    jitter_time = 5       # margin added on both ends of every trace's time span
    tolerance_time = 120  # absolute time tolerance for sampling the frequency at a window boundary
    max_fdiff = 10        # largest frequency difference for which a pair is shown

    # Only the rows before the first channel-15 entry are processed; every
    # processed row is compared against channel + 1.
    # NOTE(review): this presumably relies on oofl being ordered by channel
    # and on 15 being the last channel -- confirm.
    n_rows = np.where(oofl[:, 0] == 15)[0][0]

    for k in range(n_rows):
        channel = int(oofl[k, 0])
        id0 = oofl[k, 1]

        # trace of the identity in the given channel (invariant for all candidates)
        mask0 = ident[channel] == id0
        t0 = timeidx[channel][mask0]
        f0 = freq[channel][mask0]
        t00 = t0[0] - jitter_time
        t0n = t0[-1] + jitter_time

        # counters are only kept for inspection in the embedded shell
        false_count = 0
        true_count = 0

        # compare against every identity of the next channel
        for id1 in uni_ident_list[channel + 1]:
            mask1 = ident[channel + 1] == id1
            t1 = timeidx[channel + 1][mask1]
            f1 = freq[channel + 1][mask1]
            t10 = t1[0] - jitter_time
            t1n = t1[-1] + jitter_time

            case = overlap_case(t00, t0n, t10, t1n)
            if case is None:
                false_count += 1
                continue
            true_count += 1

            # the two middle values of the four boundaries delimit the overlap
            window_times = sorted([t00, t0n, t10, t1n])[1:3]

            fdiff = boundary_freq_diff(t0, f0, t1, f1, window_times, tolerance_time)
            if np.isnan(fdiff) or fdiff > max_fdiff:
                continue

            # show the candidate pair and drop into a shell for manual inspection
            plt.plot(t0, f0, label=id0)
            plt.plot(t1, f1, label=id1)
            plt.legend()
            plt.show()
            embed()
            # hf.fill_aifl(id0, id1, aifl, channel, 4, timeidx, freq, ident, faifl)

    # rebuild oofl: every (channel, identity) pair that does not occur in aifl
    oofl = []
    for idx, channel_idents in enumerate(uni_ident_list):
        for identity in channel_idents:
            if not np.any(aifl[idx] == identity):
                oofl.append([idx, identity])
    counter = len(oofl)

    embed()
    quit()