import numpy as np
import matplotlib.pyplot as plt
import matplotlib.dates as mdates
from IPython import embed
import helper_functions as hf
import os
import datetime
import time
import pandas as pd

def seconds_of_day(moment):
    """Return the time of day of ``moment`` in seconds since midnight.

    The sub-second part is ``moment.microsecond / 1e6``. Building it by string
    concatenation (``'0.' + str(microsecond)``) is wrong for microsecond counts
    with fewer than six digits, e.g. 5000 us would become 0.5 s instead of 0.005 s.

    Note: the value wraps to 0 at midnight (only hour/minute/second are used).
    """
    return (moment.hour * 60 * 60 + moment.minute * 60 + moment.second
            + moment.microsecond / 1e6)


def valid_span(valid):
    """Return ``(first, last + 1)`` for the True entries of the 1-D mask ``valid``.

    The result can be used as ``array[first:stop]`` to take the contiguous range
    from the first to the last valid sample (inclusive). Returns None when
    ``valid`` contains no True entry.
    """
    idx = np.flatnonzero(valid)
    if idx.size == 0:
        return None
    return int(idx[0]), int(idx[-1]) + 1


def as_saveable_array(items):
    """Convert a list of per-fish results into an array that ``np.save`` can store.

    Lists whose elements all have the same shape become a regular n-d array
    (same as passing the list to ``np.save`` directly). Ragged lists (e.g. a
    different number of time points per fish) raise ValueError in NumPy >= 1.24,
    so they are packed element by element into a 1-D object array instead; such
    files must be read with ``np.load(..., allow_pickle=True)``.
    """
    try:
        return np.asarray(items)
    except ValueError:
        packed = np.empty(len(items), dtype=object)
        for pos, item in enumerate(items):
            # element-wise assignment keeps numpy from trying to broadcast the arrays
            packed[pos] = item
        return packed


if __name__ == '__main__':

    ###################################################################################################################
    # load data
    ###################################################################################################################
    # load all the data of one recording
    load_start = time.time()

    raw_root = '../../../data/'
    out_root = '../data/'
    # NOTE(review): hard-coded choice of the 7th recording folder in sorted order
    filename = sorted(os.listdir(raw_root))[6]
    raw_dir = os.path.join(raw_root, filename)
    out_dir = os.path.join(out_root, filename)

    ident = np.load(os.path.join(raw_dir, 'all_ident_v.npy'), allow_pickle=True)
    power = np.load(os.path.join(raw_dir, 'all_sign_v.npy'), allow_pickle=True)
    freq = np.load(os.path.join(raw_dir, 'all_fund_v.npy'), allow_pickle=True)
    timeidx = np.load(os.path.join(raw_dir, 'all_idx_v.npy'), allow_pickle=True)
    realtime = np.load(os.path.join(raw_dir, 'all_times.npy'), allow_pickle=True)
    temp = pd.read_csv(os.path.join(raw_dir, 'temperatures.csv'), sep=';')
    temp_l = np.array(temp.values.tolist())

    aifl = np.load(os.path.join(out_dir, 'aifl2.npy'), allow_pickle=True)
    print(time.time() - load_start)

    ###################################################################################################################
    # parameter and variables
    ###################################################################################################################
    n_channels = 16

    # fish indices (2nd axis of aifl) that have an ID in the first slot on at least one channel
    fish_in_aifl = list(np.unique(np.where(~np.isnan(aifl[:, :, 0]))[1]))

    # sampling rate in Hz (realtime is in seconds); only used by the commented-out running statistics below
    sampling_rate = 1 / np.diff(realtime)[0]
    # start time of the recording, parsed from the last five characters of the folder name ('HH_MM')
    recording_start = datetime.datetime.strptime(filename[-5:], '%H_%M')

    # per-fish results
    power_means = []
    all_ipp = []
    all_xtickses = []
    all_run_mean = []
    all_run_std = []
    all_Ctime_gesamt = []
    all_max_ch = []
    all_hms = []
    thresholds = np.full(3, np.nan)

    ###################################################################################################################
    # analysis
    ###################################################################################################################
    for fish_number in fish_in_aifl:

        ###############################################################################################################
        # power cube: axis 0 -- channel on which a trace was detected
        #             axis 1 -- time index into realtime
        #             axis 2 -- power of that trace on every channel
        power_cube = np.full([n_channels, len(realtime), n_channels], np.nan)

        for channel in range(n_channels):
            fish_IDs = aifl[channel, fish_number, ~np.isnan(aifl[channel, fish_number])]
            for ID in fish_IDs:
                is_trace = ident[channel] == ID  # build the mask once, use it for time and power
                # presumably timeidx holds integer indices into realtime and power rows of n_channels values -- confirm
                power_cube[channel, timeidx[channel][is_trace]] = power[channel][is_trace]

        ###############################################################################################################
        # interpolated power pancake = ipp: heatmap of the fish over time (axis 0) and channels (axis 1)
        power_pancake = np.nanmax(power_cube, axis=0)
        valid = ~np.isnan(power_pancake[:, 0])
        span = valid_span(valid)
        if span is None:
            # no detection at all for this fish -- nothing to interpolate
            continue
        first, stop = span
        Ctime = realtime[valid]
        all_Ctime = realtime[first:stop]
        Cpancake = power_pancake[valid]

        try:
            ipp = np.column_stack([np.interp(all_Ctime, Ctime, column) for column in Cpancake.T])
        except ValueError as err:
            print('skipping fish %s: interpolation failed (%s)' % (fish_number, err))
            continue
        all_ipp.append(ipp)
        all_Ctime_gesamt.append(all_Ctime)

        ###############################################################################################################
        # trajectory (channel) of the fish over time: channel with the maximum power at each time point
        max_ch = np.argmax(ipp, axis=1)
        all_max_ch.append(max_ch)

        ###############################################################################################################
        # mean power of the fish along its trajectory
        power_means.append(np.mean(ipp[np.arange(len(max_ch)), max_ch]))

        ###############################################################################################################
        # time axis of the fish as datetimes (x ticks for matplotlib) and as seconds of the day
        datetime_v = [recording_start + datetime.timedelta(seconds=seconds) for seconds in realtime[first:stop]]
        all_xtickses.append(mdates.date2num(datetime_v))
        # NOTE(review): seconds_of_day wraps at midnight -- recordings crossing midnight are not monotonic here
        all_hms.append(np.array([seconds_of_day(moment) for moment in datetime_v]))

        ###############################################################################################################
        # running mean and std
        # run_mean, run_std = hf.running_3binsizes(ipp, sampling_rate)
        # TODO: not absolute max, instead average of the three strongest channels
        # bin_size = int(np.floor((Ctime[-1]-Ctime[0])/50))
        #
        # max_ch_mean = np.full([len(max_ch)], np.nan)
        # for i in range(int(len(max_ch)-np.floor(bin_size))):
        #     max_ch_mean[int(i + np.floor(bin_size / 2))] = np.mean(max_ch[i:bin_size + i])

        # all_run_mean.append(run_mean)
        # all_run_std.append(run_std)

    ###################################################################################################################
    # threshold for the activity -- OUTDATED
    # rs = np.array(all_run_std)
    #
    # for idx in range(3):
    #     rs1 = np.hstack(rs[:, idx])
    #     rs1 = rs1[~np.isnan(rs1)]
    #     rs1 = rs1[rs1>0]
    #     print(np.percentile(rs1, [5, 95, 50]))
    #     thresholds[idx] = np.percentile(rs1, 50)
    ###################################################################################################################
    # save
    ###################################################################################################################
    os.makedirs(out_dir, exist_ok=True)

    # per-fish arrays have different lengths, so they are packed as object arrays when ragged
    np.save(os.path.join(out_dir, 'all_Ctime_v.npy'), as_saveable_array(all_Ctime_gesamt))
    np.save(os.path.join(out_dir, 'all_max_ch.npy'), as_saveable_array(all_max_ch))
    # np.save(os.path.join(out_dir, 'all_run_mean.npy'), all_run_mean)
    # np.save(os.path.join(out_dir, 'all_run_std.npy'), all_run_std)
    # np.save(os.path.join(out_dir, 'thresholds.npy'), thresholds)
    np.save(os.path.join(out_dir, 'all_xtickses.npy'), as_saveable_array(all_xtickses))
    np.save(os.path.join(out_dir, 'all_ipp.npy'), as_saveable_array(all_ipp))
    np.save(os.path.join(out_dir, 'power_means.npy'), power_means)
    np.save(os.path.join(out_dir, 'all_hms.npy'), as_saveable_array(all_hms))

    # drop into an interactive shell to inspect the results
    embed()