import numpy as np
import matplotlib.pyplot as plt
import matplotlib.dates as mdates
import matplotlib.colors as mcolors
import matplotlib.gridspec as gridspec
from mpl_toolkits.axes_grid1.inset_locator import inset_axes
from IPython import embed
from scipy import stats, optimize
import pandas as pd
import math
import os
from IPython import embed

from eventdetection import threshold_crossings, merge_events
import helper_functions as hf
from params import *
from statisitic_functions import significance_bar, cohen_d
import itertools


def get_recording_number_in_time_bins(time_bins, data_dir='../data/', file_indices=(0, 1, 2, 3)):
    """
    Calculates the number of recordings that contain data in each time bin.

    A recording counts for bin k when at least one of its time points lies in the closed interval
    [time_bins[k], time_bins[k + 1]]. A time point lying exactly on a shared border therefore marks the
    recording as present in both neighbouring bins.

    :param time_bins: numpy array (or sequence) with borders of the time bins
    :param data_dir: directory holding one sub-directory per recording, each containing an 'all_hms.npy'
    :param file_indices: positions, in the alphabetically sorted listing of data_dir, of the recordings to use
    :return: time_bins_recordings: numpy array (length len(time_bins) - 1) with the number of recordings
             that have data in that specific time bin
    """
    time_bins = np.asarray(time_bins)
    lower, upper = time_bins[:-1], time_bins[1:]
    time_bins_recordings = np.zeros(len(time_bins) - 1)

    # list the data directory once instead of once per recording
    entries = sorted(os.listdir(data_dir))
    for filename_idx in file_indices:
        time_points = np.load(os.path.join(data_dir, entries[filename_idx], 'all_hms.npy'), allow_pickle=True)

        # np.unique returns a sorted array, which searchsorted requires
        unique_time_points = np.unique(np.hstack(time_points))
        # number of time points with lower <= t <= upper, for every bin at once
        n_in_bin = (np.searchsorted(unique_time_points, upper, side='right')
                    - np.searchsorted(unique_time_points, lower, side='left'))
        time_bins_recordings += n_in_bin > 0

    return time_bins_recordings


def func(x, a, tau, c):
    """
    Exponential decay model a * exp(-x / tau) + c.

    The signature f(x, *params) makes it usable as a model function for scipy.optimize.curve_fit.

    :param x: independent variable (scalar or numpy array)
    :param a: amplitude of the exponential component
    :param tau: decay constant, in the units of x
    :param c: constant offset (the value approached for large x when tau > 0)
    :return: model value(s), broadcast like x
    """
    decay = np.exp(-x / tau)
    return a * decay + c


def calc_movement(cbf, i):
    """
    Combine the two entries along the first axis of cbf for column i into a sum and a difference.

    Assumes cbf is indexable as cbf[k, :, i] with k in {0, 1} (e.g. shape (2, n_bins, n_columns) - TODO confirm).
    NaNs in either result are replaced by 0; cbf itself is not modified.

    :param cbf: numpy array with at least two entries along axis 0
    :param i: index along the last axis of cbf
    :return: tuple (movement, re_mov) with movement = cbf[0] + cbf[1] and re_mov = cbf[0] - cbf[1]
    """
    first, second = cbf[0, :, i], cbf[1, :, i]

    combined_arrays = []
    for combined in (first + second, first - second):
        # the arithmetic above yields fresh arrays, so zeroing NaNs in place leaves cbf untouched
        combined[np.isnan(combined)] = 0
        combined_arrays.append(combined)

    return tuple(combined_arrays)


# Bin widths (in seconds) of the pre-computed movement arrays '../data/cbf<width>.npy' that are plotted
# below the trajectory, one figure row per width. Further resolutions (e.g. 2, 10, 30, 180) can be added
# here as long as the corresponding cbf file exists.
MOVEMENT_BIN_WIDTHS = (5, 15, 60, 150, 300)


def unit_width_bins(values):
    """
    Histogram bin edges starting at 1: np.linspace(1, max(values) + 1, int(max(values) + 1)).

    For an integer maximum M >= 1 these are the unit-width edges [1, 2, ..., M + 1]; values below 1
    fall left of the first edge and are not counted.

    :param values: numpy array, presumably integer-valued movement counts
    :return: numpy array with the bin edges
    """
    top = np.max(values)
    return np.linspace(1, top + 1, int(top + 1))


if __name__ == '__main__':
    ###################################################################################################################
    # parameter and variables
    inch = 2.54  # cm per inch; figure sizes below are given in cm
    time_factor = 60 * 60  # seconds per hour
    day_seconds = 24 * time_factor
    save_dir = '../../../jan_plots'

    # time_edges = np.array([4.5, 6.5, 16.5, 18.5]) * time_factor
    # day = time_bins[:-1][(time_bins[:-1] >= time_edges[1]) & (time_bins[:-1] <= time_edges[2])]
    # dusk = time_bins[:-1][(time_bins[:-1] >= time_edges[2]) & (time_bins[:-1] <= time_edges[3])]
    # night = time_bins[:-1][(time_bins[:-1] <= time_edges[0]) | (time_bins[:-1] >= time_edges[3])]
    # dawn = time_bins[:-1][(time_bins[:-1] >= time_edges[0]) & (time_bins[:-1] <= time_edges[1])]

    ###################################################################################################################
    # load data
    ###################################################################################################################
    # bin borders over 24 h (in seconds) and movement data of one day, for every resolution
    time_bins = {width: np.arange(0, day_seconds + 1, width) for width in MOVEMENT_BIN_WIDTHS}
    cbfs = {width: np.load('../data/cbf%d.npy' % width, allow_pickle=True) for width in MOVEMENT_BIN_WIDTHS}

    names = np.load('../data/n.npy', allow_pickle=True)
    freq = np.load('../data/f.npy', allow_pickle=True)

    trajectories = np.load('../data/trajectories.npy', allow_pickle=True)
    trajec_x = np.load('../data/trajec_x.npy', allow_pickle=True)

    # make sure savefig does not fail after all the plotting work because the folder is missing
    os.makedirs(save_dir, exist_ok=True)

    n_rows = 1 + len(MOVEMENT_BIN_WIDTHS)
    tags = ['trajectory'] + [str(width) for width in MOVEMENT_BIN_WIDTHS]
    # column in the cbf arrays; it only advances for fish that are not named 'unknown'
    cbf_counter = 0
    ###################################################################################################################
    # analysis: one figure per fish
    for i, trajec in enumerate(trajectories):
        if names[i] == 'unknown':
            continue

        # only the summed movement (first return value) is plotted
        movements = [calc_movement(cbfs[width], cbf_counter)[0] for width in MOVEMENT_BIN_WIDTHS]
        cbf_counter += 1

        t_x = trajec_x[i]

        fig = plt.figure(constrained_layout=True, figsize=[20 / inch, 26 / inch])
        gs = gridspec.GridSpec(ncols=2, nrows=n_rows, figure=fig, hspace=0.01, wspace=0.01,
                               height_ratios=[1] * n_rows, width_ratios=[4, 1], left=0.1, bottom=0.15,
                               right=0.95, top=0.95)

        # top row: electrode trajectory; following rows: movement time course (left) and histogram (right)
        ax_traj = fig.add_subplot(gs[0, 0])
        ax_traj.plot(t_x / 60 / 60, trajec)

        time_axes = [ax_traj]
        for row, (width, mov) in enumerate(zip(MOVEMENT_BIN_WIDTHS, movements), start=1):
            ax_time = fig.add_subplot(gs[row, 0], sharex=ax_traj)
            ax_time.plot(time_bins[width][:-1] / 60 / 60, mov)
            ax_time.set_ylabel('n')
            time_axes.append(ax_time)

            ax_hist = fig.add_subplot(gs[row, 1])
            ax_hist.hist(mov, bins=unit_width_bins(mov))

        # x-limits (in hours) taken from the trajectory time stamps, identical for every left-column axis
        xl_min = np.min(t_x) / 60 / 60
        xl_max = np.max(t_x) / 60 / 60
        for ax, tag in zip(time_axes, tags):
            ax.set_xlim([xl_min, xl_max])
            ax.text(0.01, 0.7, tag, transform=ax.transAxes, fontsize='small')

        ax_traj.set_ylim([0, 15])
        ax_traj.invert_yaxis()
        ax_traj.set_ylabel('electrode')
        time_axes[-1].set_xlabel('Time [h]')
        fig.suptitle('EODf ' + str(np.round(freq[i], 2)) + ' ' + names[i], fontsize=12)
        fig.savefig(os.path.join(save_dir, 'trajec' + str(i) + '.pdf'))

        # close exactly this figure to keep memory flat over many fish
        plt.close(fig)




    # ###############################################################################################################
    # # roll time axis
    # start = []
    # stop = []
    # for j in range(len(roaming_events)):
    #     start.extend(roaming_events[j][0])
    #     stop.extend(roaming_events[j][1])
    #
    # N_rec_time_bins = get_recording_number_in_time_bins(time_bins[::int((60 / bin_len) * 60)])
    #
    # # rolled time axis for nicer plot midnight in the middle start noon
    # N_start, bin_edges = np.histogram(np.array(start) * 5, bins=time_bins[::int((60 / bin_len) * 60)])
    # N_stop, bin_edges2 = np.histogram(np.array(stop) * 5, bins=time_bins[::int((60 / bin_len) * 60)])
    # rolled_start = np.roll(N_start / N_rec_time_bins, int(len(N_start) / 2))
    # rolled_stop = np.roll(N_stop / N_rec_time_bins, int(len(N_stop) / 2))
    # rolled_bins = (bin_edges[:-1] / time_factor) + 0.5
    #
    # ###############################################################################################################
    # # figure 1: max_channel_changes per time zone and per duration of the roaming event
    # fig = plt.figure(constrained_layout=True, figsize=[15 / inch, 14 / inch])
    # gs = gridspec.GridSpec(ncols=6, nrows=3, figure=fig, hspace=0.01, wspace=0.01,
    #                        height_ratios=[1, 1, 1], width_ratios=[1, 1, 1, 1, 1, 1], left=0.1, bottom=0.15, right=0.95,
    #                        top=0.95)
    #
    # ax0 = fig.add_subplot(gs[0, :])
    # ax1 = fig.add_subplot(gs[1, :3])
    # ax2 = fig.add_subplot(gs[1, 3:], sharex=ax1)
    # ax3 = fig.add_subplot(gs[2, :2], sharey=ax2)
    # ax4 = fig.add_subplot(gs[2, 2:4], sharey=ax2)
    # ax5 = fig.add_subplot(gs[2, 4:])
    #
    # # axins = inset_axes(ax1, width='30%', height='60%')
    #
    # # bar plot
    # ax0.bar(rolled_bins, rolled_start, color=color2[4])
    # print('bar plot')
    # print('day: mean ', np.round(np.mean([rolled_start[:6], rolled_start[18:]]), 2),
    #       ' std: ', np.round(np.std([rolled_start[:6], rolled_start[18:]]), 2))
    #
    # print('night: mean ', np.round(np.mean(rolled_start[6:18]), 2),
    #       ' std: ', np.round(np.std(rolled_start[6:18]), 2))
    #
    # ax0.plot([16.5, 6.5], [20, 20], color=color_diffdays[0], lw=7)
    # ax0.plot([16.5, 18.5], [20, 20], color=color_diffdays[3], lw=7)
    # ax0.plot([4.5, 6.5], [20, 20], color=color_diffdays[3], lw=7)
    #
    # ###############################################################################################################
    # # curve_fit: tau, std, n
    # curvefit_stat = []
    #
    # xdata = np.linspace(0.0, 10., 500)
    # y_speeds = []
    # for plot_zone, color_zone, day_zone, pos_zone in \
    #         zip([day, dusk, night, dawn], [6, 1, 4, 0], ['day', 'dusk', 'night', 'dawn'], [1, 2, 3, 4]):
    #
    #     # boxplot ax1
    #     props_e = dict(linewidth=2, color=color2[color_zone])
    #     bp = ax1.boxplot(dauer[np.in1d(wann * 5, plot_zone)], positions=[pos_zone], widths=0.7,
    #                      showfliers=False, vert=False,
    #                      boxprops=props_e, medianprops=props_e, capprops=props_e, whiskerprops=props_e)
    #
    #     x_n = [item.get_xdata() for item in bp['whiskers']][1][1]
    #     n = len(dauer[np.in1d(wann * 5, plot_zone)])
    #     ax1.text(x_n + 2, pos_zone, str(n), ha='left', va='center')
    #     print('dauer: ', day_zone, np.median(dauer[np.in1d(wann * 5, plot_zone)]),
    #           ' 25, 75: ', np.percentile(dauer[np.in1d(wann * 5, plot_zone)], [25, 75]))
    #
    #     # curve fit
    #     x_dauer = dauer[dauer <= 10][np.in1d(wann[dauer <= 10] * 5, plot_zone)]
    #     y_speed = speeds[dauer <= 10][np.in1d(wann[dauer <= 10] * 5, plot_zone)]
    #     y_speeds.append(y_speed)
    #
    #     popt, pcov = optimize.curve_fit(func, x_dauer, y_speed)
    #     perr = np.sqrt(np.diag(pcov))
    #     print(day_zone, popt, 'perr', perr[1])
    #     curvefit_stat.append(np.array([popt[1], perr[1], n]))
    #
    #     # plot dauer vs speed
    #     ax2.plot(x_dauer, y_speed, 'o', alpha=0.3, color=color2[color_zone])
    #
    #     ax3.plot(x_dauer, y_speed, 'o', alpha=0.3, color=color2[color_zone])
    #
    #     # plot curve fit
    #     ax4.plot(xdata, func(xdata, *popt), '-', color=color2[color_zone], label=day_zone)
    #     ax4.set_ylim(ax2.get_ylim())
    #
    # curvefit_stat = np.array(curvefit_stat)
    # # plot std of tau
    # ax5.bar([0, 1, 2, 3], curvefit_stat[:, 0], yerr=curvefit_stat[:, 1], color=color2[4])
    #
    # ###############################################################################################################
    # # statistic
    # day_group = [day, dusk, night, dawn]
    # for subset in itertools.combinations([0, 1, 2, 3], 2):
    #     mean1, std1, n1 = curvefit_stat[subset[0]]
    #     mean2, std2, n2 = curvefit_stat[subset[1]]
    #     t, p = stats.ttest_ind_from_stats(mean1, std1, n1, mean2, std2, n2)
    #     d = cohen_d(y_speeds[subset[0]], y_speeds[subset[1]])
    #     print(['day', 'dusk', 'night', 'dawn'][subset[0]], ['day', 'dusk', 'night', 'dawn'][subset[1]], 't: ',
    #           np.round(t, 2), 'p: ', np.round(p, 4), 'd: ', d)
    #
    #     print(stats.mannwhitneyu(dauer[dauer <= 100][np.in1d(wann[dauer <= 100] * 5, day_group[subset[0]])],
    #                        dauer[dauer <= 100][np.in1d(wann[dauer <= 100] * 5, day_group[subset[1]])]))
    #     if subset[0] == 0 and subset[1] == 2:
    #         significance_bar(ax5, p, None, subset[0], subset[1], 4.)
    #
    # ###############################################################################################################
    # # labels
    # ax0.set_ylabel('# Roaming Events', fontsize=fs)
    # ax0.set_xticks([0, 6, 12, 18, 24])
    # ax0.set_xticklabels(['12:00', '18:00', '00:00', '06:00', '12:00'])
    # ax0.set_xlabel('Time', fontsize=fs)
    #
    # ax1.set_yticks([1, 2, 3, 4])
    # ax1.set_yticklabels(['day', 'dusk', 'night', 'dawn'])
    # ax1.set_xlabel('Duration [min]', fontsize=fs)
    # ax1.invert_yaxis()
    #
    # ax2.set_xlabel('Duration [min]', fontsize=fs)
    # ax2.set_ylabel('Speed [m/min]', fontsize=fs)
    # ax2.set_ylim([0, 27])
    #
    # ax3.set_ylabel('Speed [m/min]', fontsize=fs)
    # ax3.set_xlabel('Duration [min]', fontsize=fs)
    # ax3.set_xlim([0, 10])
    #
    # ax4.set_xlabel('Duration [min]', fontsize=fs)
    # ax4.set_xlim([0, 10])
    #
    # ax5.set_xticks([0, 1, 2, 3])
    # ax5.set_xticklabels(['day', 'dusk', 'night', 'dawn'], rotation=45)
    # ax5.set_ylabel(r'$\tau$')
    #
    # tagx = [-0.05, -0.07, -0.07, -0.17, -0.17, -0.17]
    # for idx, ax in enumerate([ax0, ax1, ax2, ax3, ax4, ax5]):
    #     ax.make_nice_ax()
    #     ax.text(tagx[idx], 1.05, chr(ord('A') + idx), transform=ax.transAxes, fontsize='large')
    #
    # # fig.align_ylabels()
    # # fig.savefig(save_path + 'roaming_events.pdf')
    # # fig.savefig(save_path_pres + 'roaming_events.pdf')
    #
    # ###############################################################################################################
    # # figure 2:
    # linregress_stat = []
    # fig2 = plt.figure(constrained_layout=True, figsize=[15 / inch, 10 / inch])
    # gs = gridspec.GridSpec(ncols=1, nrows=2, figure=fig2, hspace=0.05, wspace=0.0,
    #                        height_ratios=[1, 2], left=0.1, bottom=0.15, right=0.95, top=0.95)
    #
    # ax21 = fig2.add_subplot(gs[0, 0])
    # ax23 = fig2.add_subplot(gs[1, 0])
    #
    # for plot_zone, color_zone, day_zone, bar_pos, pos_zone in \
    #         zip([day, dusk, night, dawn], [6, 1, 4, 0], ['day', 'dusk', 'night', 'dawn'], [-0.3, -0.1, 0.1, 0.3],
    #             [0, 1, 2, 3]):
    #     # pdf
    #     N_roam, bin_roam = np.histogram(roam_dist[np.in1d(wann * 5, plot_zone)], bins=np.linspace(0, 15, 16))
    #     N_roam = N_roam / np.sum(N_roam) / (bin_roam[1] - bin_roam[0])
    #     ax21.plot(bin_roam[:-1], N_roam, color=color2[color_zone], label=day_zone)
    #     ax21.set_xlabel('Distance [m]')
    #     ax21.set_ylabel('PDF')
    #     ax21.set_xlim([1, 15])
    #
    #     # duration vs distance
    #     ax23.plot(dauer[np.in1d(wann * 5, plot_zone)], roam_dist[np.in1d(wann * 5, plot_zone)], 'o',
    #               color=color2[color_zone], alpha=0.3)
    #     res = stats.linregress(dauer[np.in1d(wann * 5, plot_zone)], roam_dist[np.in1d(wann * 5, plot_zone)])
    #     print(day_zone, res.slope)
    #     linregress_stat.append(np.array([res.slope, res.stderr, len(dauer[np.in1d(wann * 5, plot_zone)])]))
    #     ax23.set_xlabel('Duration [min]')
    #     ax23.set_ylabel('Distance [m]')
    #     ax23.set_xlim([0, 100])
    #
    # print('linregress')
    # for subset in itertools.combinations([0, 1, 2, 3], 2):
    #     mean1, std1, n1 = linregress_stat[subset[0]]
    #     mean2, std2, n2 = linregress_stat[subset[1]]
    #     t, p = stats.ttest_ind_from_stats(mean1, std1, n1, mean2, std2, n2)
    #     d = cohen_d(y_speeds[subset[0]], y_speeds[subset[1]])
    # #     print(['day', 'dusk', 'night', 'dawn'][subset[0]], ['day', 'dusk', 'night', 'dawn'][subset[1]], 't: ',
    # #           np.round(t, 2), 'p: ', np.round(p, 4), 'd: ', d)
    # # print(np.round(0.05 / 6, 4))
    #
    # for axis in [ax21, ax23]:
    #     axis.make_nice_ax()
    #
    # ax21.legend(loc='best', bbox_to_anchor=(0.5, 0.7, 0.5, 0.5), ncol=2)
    #
    # fig2.savefig(save_path_pres + 'roaming_distance.pdf')
    # fig2.savefig(save_path + 'roaming_distance.pdf')
    #
    # plt.show()
    #
    # # df = pd.DataFrame({'duration': dauer, 'speed': speeds, 'distance': roam_dist})
    # embed()
    # quit()