import numpy as np
import matplotlib.pyplot as plt
import matplotlib.dates as mdates
import matplotlib.colors as mcolors
import matplotlib.gridspec as gridspec
from IPython import embed
from scipy import stats
import math
import os
from IPython import embed
import helper_functions as hf
from params import *


N_CHANNELS = 16                   # columns of the location matrix / bars per histogram
POWER_THRESHOLD = -90             # fish whose power_means value is below this are skipped
RECORDING_INDICES = (0, 1, 2, 3)  # positions in the sorted listing of the data directory


def load_recording(data_dir, folder):
    """Load the per-fish arrays of one recording folder.

    Returns the tuple ``(xticks, max_ch, power, species)`` read from
    ``all_xtickses.npy``, ``all_max_ch.npy``, ``power_means.npy`` and
    ``fish_species.npy`` inside ``data_dir/folder``. Element ``i`` of each
    array belongs to fish ``i``.
    """
    def load(name):
        # allow_pickle: the files may hold object arrays (presumably ragged
        # per-fish time-stamp arrays) — np.load refuses those otherwise.
        return np.load(os.path.join(data_dir, folder, name), allow_pickle=True)

    return (load('all_xtickses.npy'), load('all_max_ch.npy'),
            load('power_means.npy'), load('fish_species.npy'))


def combined_time_axis(xticks_per_recording):
    """Return the sorted, de-duplicated union of every fish's time stamps.

    ``xticks_per_recording`` yields, per recording, a sequence of per-fish
    time-stamp arrays (the content of ``all_xtickses.npy``).
    """
    stamps = []
    for all_xticks in xticks_per_recording:
        stamps.extend(np.unique(np.hstack(all_xticks)))
    return np.unique(stamps)


def location_matrix(time_axis, recordings, n_channels=N_CHANNELS,
                    power_threshold=POWER_THRESHOLD):
    """Count how many identified fish sit on each channel in each time bin.

    Parameters
    ----------
    time_axis : array-like
        Time bins that form the rows of the result (see ``combined_time_axis``).
    recordings : iterable of tuple
        ``(xticks, max_ch, power, species)`` per recording, as returned by
        ``load_recording``.
    n_channels : int
        Number of columns of the result.
    power_threshold : float
        Fish with ``power < power_threshold`` are skipped, as are fish whose
        species is ``'unknown'``.

    Returns
    -------
    numpy.ndarray of int, shape (len(time_axis), n_channels)
        Entry ``[t, ch]`` counts the included fish that have time bin ``t``
        among their time stamps and whose ``max_ch`` value is ``ch``. Each
        fish contributes to a single channel for all of its time stamps.
    """
    time_axis = np.asarray(time_axis)
    counts = np.zeros((len(time_axis), n_channels), dtype=int)
    for xticks, max_ch, power, species in recordings:
        for fish_xticks, channel, fish_power, name in zip(xticks, max_ch, power, species):
            if fish_power >= power_threshold and name != 'unknown':
                counts[np.isin(time_axis, fish_xticks), channel] += 1
    return counts


def rows_in_window(times, start, stop):
    """Return the indices of ``times`` with ``start <= t <= stop`` (both ends inclusive)."""
    times = np.asarray(times)
    return np.flatnonzero((times >= start) & (times <= stop))


def channel_occupancy(alm, rows=None):
    """Mean count per channel over selected time bins of a location matrix.

    ``rows`` is an optional integer index array of time bins (all bins when
    None). Returns a float array of length ``alm.shape[1]``. If no bin is
    selected, the result is all-NaN and no divide-by-zero warning is raised.
    """
    selected = alm if rows is None else alm[rows]
    if len(selected) == 0:
        return np.full(alm.shape[1], np.nan)
    return selected.sum(axis=0) / len(selected)


def main():
    """Load the recordings, plot channel occupancy, then open an IPython shell."""
    # plot params: figure sizes below are given in cm and divided by `inch`.
    # NOTE(review): 1 inch = 2.54 cm; 2.45 looks like a digit swap — kept so
    # the thesis figure dimensions do not change silently.
    inch = 2.45
    save_path = '../../thesis/Figures/Results/'
    data_dir = '../data/'

    # ---- load data (directory listed once, every file loaded once) -------
    listing = sorted(os.listdir(data_dir))
    recordings = [load_recording(data_dir, listing[i]) for i in RECORDING_INDICES]

    # ---- alm: all-location matrix (time bin x channel) -------------------
    afx = combined_time_axis(rec[0] for rec in recordings)  # all flat x
    alm = location_matrix(afx, recordings)
    hist = channel_occupancy(alm)

    # ---- figure 1: occupancy over time + time-averaged histogram --------
    fig1 = plt.figure(figsize=[13 / inch, 10 / inch])
    spec = gridspec.GridSpec(ncols=2, nrows=1, figure=fig1, hspace=0.5, wspace=0.05,
                             width_ratios=[7, 1], left=0.12, bottom=0.15, right=0.99, top=0.95)
    ax1 = fig1.add_subplot(spec[0, 0])
    ax11 = fig1.add_subplot(spec[0, 1])

    # Every 20th time bin; channel order reversed before imshow (default
    # origin='upper' draws the first row at the top).
    ax1.imshow(alm[::20].T[::-1], vmin=0.0, vmax=7.0, aspect='auto', interpolation='gaussian',
               extent=[afx[0], afx[-1], -0.5, N_CHANNELS - 0.5])
    # beautimechannelaxis / timeaxis / make_nice_ax are not stock matplotlib
    # methods; presumably attached to Axes by helper_functions — TODO confirm.
    ax1.beautimechannelaxis()
    ax1.timeaxis()

    ax11.barh(np.arange(N_CHANNELS), hist, color=color2[4])
    ax11.make_nice_ax()
    ax11.set_ylim(-0.5, N_CHANNELS - 0.5)
    ax11.set_yticks([])
    ax11.set_xticks([])
    ax11.invert_yaxis()

    fig1.savefig(save_path + 'all_trajectories.pdf')

    # ---- figure 2: occupancy in two datetime_box windows ----------------
    # Left panel: bins in [box[0], box[2]] plus [box[4], box[5]];
    # right panel: bins in [box[2], box[4]]. Bounds are inclusive, so a bin
    # exactly at box[2] or box[4] counts in both panels.
    x = np.asarray(afx)
    rows_a = np.concatenate((rows_in_window(x, datetime_box[0], datetime_box[2]),
                             rows_in_window(x, datetime_box[4], datetime_box[5])))
    rows_b = rows_in_window(x, datetime_box[2], datetime_box[4])
    y_a = channel_occupancy(alm, rows_a)
    y_b = channel_occupancy(alm, rows_b)

    fig2 = plt.figure(figsize=[13 / inch, 10 / inch])
    spec = gridspec.GridSpec(ncols=2, nrows=1, figure=fig2, hspace=0.5, wspace=0.05,
                             width_ratios=[1, 1], left=0.12, bottom=0.15, right=0.99, top=0.95)
    ax2 = fig2.add_subplot(spec[0, 0])
    ax22 = fig2.add_subplot(spec[0, 1])

    # Bars at 0..15 so channel i lines up with tick i (labelled i + 1) and
    # stays inside the -0.5..15.5 limits, as in figure 1.
    ax2.barh(np.arange(N_CHANNELS), y_a, color=color2[4])
    ax22.barh(np.arange(N_CHANNELS), y_b, color=color2[4])

    for ax in (ax2, ax22):
        ax.make_nice_ax()
        ax.set_ylim(-0.5, N_CHANNELS - 0.5)
        # Common value range so the two panels are comparable.
        # NOTE(review): upper bound 6.0 assumed — confirm it covers max(y_a, y_b).
        ax.set_xlim(0, 6.0)
        ax.set_yticks([0, 1, 2, 3, 4, 5, 6, 7, 7.5, 8, 9, 10, 11, 12, 13, 14, 15])
        ax.set_yticklabels(['1', '', '3', '', '5', '', '7', '', 'g', '', '10', '', '12', '', '14', '', '16'],
                           fontsize=9)
        ax.invert_yaxis()
    ax22.set_yticklabels([])

    plt.show()

    # Interactive inspection of the locals above (alm, afx, y_a, y_b, ...).
    embed()


if __name__ == '__main__':
    main()