#from matplotlib.colors import LinearSegmentedColormap
#from plot_eod_chirp  import find_times
import matplotlib.pyplot as plt
import numpy as np
from IPython import embed
#embed()
import matplotlib as matplotlib
import math
import scipy.integrate as integrate
from scipy import signal
from scipy.interpolate import interp1d
from scipy.interpolate import CubicSpline
import scipy as sp
import pickle
from scipy.spatial import distance
from myfunctions import *
import time
from matplotlib import gridspec
#from matplotlib_scalebar.scalebar import ScaleBar
import matplotlib.mlab as ml
import scipy.integrate as si
import pandas as pd
from myfunctions import default_settings
from functionssimulation import single_stim
from plot_eod_chirp  import rectify,find_dev, conv,global_maxima,integrate_chirp,find_periods,find_lm,pl_eods
from plot_eod_chirp import find_beats,snip, power_func
import os


def plot_power(beats, results, win, deviation_s, sigma, sampling, d_ms, beat_corr, size, phase_zero,
               delta_t, a_fr, a_fe, eod_fr, eod_fe, deviation, show_figure=False, plot_dist=False,
               save=False, bef_c=1.1, aft_c=-0.1, share=False):
    """Plot frequency estimate and modulation depth for every row of ``results``.

    Creates a figure with one row of axes per row of ``results`` and two
    columns, both plotted against ``beats / eod_fr + 1`` ("EOD multiples"):
    the left column shows ``result_frequency / eod_fr + 1`` and the right
    column ``result_amplitude_max``.

    When ``share`` is True, a mouse click on the figure re-runs ``snip`` at
    the clicked x position and opens a new figure with the power spectrum
    and the time trace of ``maxima_interp_b``.

    Parameters
    ----------
    beats : array-like
        x-axis values (divided by ``eod_fr`` for plotting).
    results : pandas.DataFrame
        Must have at least one row and the columns ``type``,
        ``result_frequency`` and ``result_amplitude_max``. Rows are read
        by position.
    deviation_s, sigma, sampling, beat_corr, size, phase_zero, a_fr, a_fe,
    eod_fr, deviation, bef_c, aft_c :
        Forwarded to ``snip`` by the click handler (``eod_fr`` is also used
        to scale the axes).
    win, d_ms, delta_t, eod_fe, show_figure, plot_dist, save :
        Accepted for call compatibility; not used by this function.
    share : bool
        Connect the click handler described above.

    Returns
    -------
    matplotlib.figure.Figure
    """
    plt.rcParams['figure.figsize'] = (6.27, 8)
    plt.rcParams["legend.frameon"] = False
    colors = ['black', 'blue', 'purple', 'magenta', 'pink', 'orange', 'brown', 'red', 'green', 'lime']
    n_rows = len(results)
    # squeeze=False keeps ``ax`` two-dimensional even for a single result row,
    # so ``ax[i, j]`` indexing works for every n_rows >= 1.
    fig, ax = plt.subplots(nrows=n_rows, ncols=2, sharex=True, squeeze=False)
    x_axis = beats / eod_fr + 1
    for i in range(n_rows):
        row = results.iloc[i]
        color = colors[i % len(colors)]  # cycle instead of IndexError beyond 10 rows
        ax[i, 0].set_ylabel(row['type'], rotation=0, labelpad=40, color=color)
        ax[i, 0].plot(x_axis, np.array(row['result_frequency']) / eod_fr + 1, color=color)
        ax[i, 1].plot(x_axis, np.array(row['result_amplitude_max']), color=color)
        ax[i, 0].set_ylim([1, 1.6])
    plt.subplots_adjust(left=0.25)
    ax[0, 0].set_title('Most popular frequency')
    ax[0, 1].set_title('Modulation depth')
    for axis in ax[-1, :]:
        axis.set_xlabel('EOD multiples')
    for axis in ax.flat:
        axis.spines['right'].set_visible(False)
        axis.spines['top'].set_visible(False)

    def onclick(event):
        """Re-simulate at the clicked x position and plot spectrum and trace."""
        if event.xdata is None:
            return  # click landed outside every axes
        # Inverse of the x-axis transform ``beats / eod_fr + 1``.
        # NOTE(review): this yields a value on the ``beats`` scale that is then
        # passed to snip as ``eod_fe``; confirm that is what snip expects.
        eod_fe = [(event.xdata - 1) * eod_fr]
        left_b = [bef_c * sampling] * len(beat_corr)
        right_b = [aft_c * sampling] * len(beat_corr)
        e = 0
        p = 0
        s = 0
        d = 0
        (time_fish, sampled, cut, cubed, time_b, conv_b, am_corr_b, peaks_interp_b,
         maxima_interp_b, am_corr_ds_b, am_df_ds_b, am_df_b, eod_overlayed_chirp) = snip(
            left_b, right_b, e, e,
            sampling, deviation_s,
            d, eod_fr, a_fr, eod_fe,
            phase_zero, p, size, s,
            sigma, a_fe, deviation,
            beat_corr, chirp=False)
        # Segment length scales with the sampling rate (8192 samples at 10 kHz).
        nfft = int((4096 * sampling / 10000) * 2)
        # noverlap must be an integer number of samples; ``nfft / 2`` is a float.
        pp, f = ml.psd(maxima_interp_b - np.mean(maxima_interp_b), Fs=sampling, NFFT=nfft,
                       noverlap=nfft // 2)
        plt.figure()
        plt.subplot(1, 2, 1)
        plt.plot(f, pp)
        plt.xlim([0, 2000])
        plt.subplot(1, 2, 2)
        plt.plot(time_b, maxima_interp_b)
        plt.show()

    if share == True:
        fig.canvas.mpl_connect('button_press_event', onclick)
    return fig


def onclick_func(e, nfft, beats, results, disc, win, deviation_s, sigma, d_ms, beat_corr, size, phase_zero,
                 delta_t, a_fr, a_fe, eod_fr, eod_fe, deviation, show_figure=False, plot_dist=False,
                 save=False, bef_c=1.1, aft_c=-0.1, sampling=100000, d=0):
    """Run ``snip`` without a chirp and return the PSD of its interpolated maxima.

    Standalone version of the click handler in ``plot_power``. The original
    version referenced ``pp``, ``f`` and ``d`` without ever defining them and
    so always raised NameError.

    Parameters
    ----------
    e : int
        Passed to ``snip`` twice (3rd and 4th positional arguments).
    nfft : int
        Segment length for ``mlab.psd``; segments overlap by half.
    d : int, optional
        Passed to ``snip`` as its 7th positional argument. Defaults to 0, like
        the handler in ``plot_power`` (NOTE(review): the old code silently
        used a module-level ``d`` instead — confirm the intended value).
    beats, results, disc, win, d_ms, delta_t, show_figure, plot_dist, save :
        Accepted for call compatibility; not used.
    Remaining parameters are forwarded to ``snip``.

    Returns
    -------
    tuple
        ``(pp, f, cubed, cubed, time_b, conv_b, am_corr_b, peaks_interp_b,
        maxima_interp_b, am_corr_ds_b, am_df_ds_b, am_df_b,
        eod_overlayed_chirp)``, where ``pp``/``f`` are the PSD of
        ``maxima_interp_b``. ``cubed`` appears twice, as in the original
        return statement.
    """
    left_b = [bef_c * sampling] * len(beat_corr)
    right_b = [aft_c * sampling] * len(beat_corr)
    p = 0
    s = 0
    (time_fish, sampled, cut, cubed, time_b, conv_b, am_corr_b, peaks_interp_b,
     maxima_interp_b, am_corr_ds_b, am_df_ds_b, am_df_b, eod_overlayed_chirp) = snip(
        left_b, right_b, e, e,
        sampling, deviation_s,
        d, eod_fr, a_fr, eod_fe,
        phase_zero, p, size, s,
        sigma, a_fe, deviation,
        beat_corr, chirp=False)
    # Same spectrum as plot_power's click handler: mean-removed maxima trace,
    # half-overlapping segments (noverlap must be an int).
    pp, f = ml.psd(maxima_interp_b - np.mean(maxima_interp_b), Fs=sampling, NFFT=nfft,
                   noverlap=nfft // 2)
    return (pp, f, cubed, cubed, time_b, conv_b, am_corr_b, peaks_interp_b, maxima_interp_b,
            am_corr_ds_b, am_df_ds_b, am_df_b, eod_overlayed_chirp)





# Chirp / beat simulation parameters shared by the rest of the script.
delta_t = 0.014
interest_interval = delta_t * 1.2
sigma = delta_t / math.sqrt((2 * math.log(10)))  # width of the chirp
# phase when the chirp occurred: 10 equally spaced values in [0, 2*pi)
phase_zero = np.arange(0, 2 * np.pi, 2 * np.pi / 10)
# other receiver EOD frequencies that were tried: 637, 537, 1435
# load == True  -> recompute every condition with power_func and save it;
# load == False -> read the previously saved numerical_test.pkl.
load = False
counter = 0
results = pd.DataFrame([])
if load == True:
    eod_fr = [500, 734, 820, 1000, 1492]
    x = [1.5]  # other values tried: 2.5, 0.5
    sampling = [83425, 98232, 100000, 112683]
    start = 5
    end = 3500
    step = [30, 25, 10]
    d = 1
    delta_t = 1
    size = [120]
    a_fr = 1
    a_fe = [0.5, 0.3, 0.2]
    bef_c = -2
    aft_c = -1
    frames = []  # one entry per power_func call
    for ee in range(len(eod_fr)):
        for s in range(len(sampling)):
            deviation_ms, deviation_s, deviation_dp = find_dev(x, sampling[s])
            for ss in range(len(step)):
                eod_fe, beat_corr, beats = find_beats(start, end, step[ss], eod_fr[ee])
                for a in range(len(a_fe)):
                    results1 = power_func(a_fr=a_fr, a_fe=a_fe[a], eod_fr=eod_fr[ee], eod_fe=eod_fe,
                                          win='w2', deviation_s=deviation_s, sigma=sigma,
                                          sampling=sampling[s], deviation_ms=deviation_ms,
                                          beat_corr=beat_corr, size=size, phase_zero=[phase_zero[0]],
                                          delta_t=delta_t, deviation_dp=deviation_dp, show_figure=True,
                                          plot_dist=False, save=False, bef_c=bef_c, aft_c=aft_c)
                    results1['sampling'] = sampling[s]
                    results1['eod_fr'] = eod_fr[ee]
                    results1['amplitude'] = a_fe[a]
                    results1['step'] = step[ss]
                    # DataFrame.append was removed in pandas 2.0: collect the pieces
                    # and concatenate. A non-DataFrame return (dict/Series) becomes a
                    # single row, exactly as DataFrame.append(..., ignore_index=True) did.
                    if isinstance(results1, pd.DataFrame):
                        frames.append(results1)
                    else:
                        frames.append(pd.DataFrame([results1]))
                    counter += 1
                # Checkpoint after every step size so a crash keeps finished work.
                results = pd.concat(frames, ignore_index=True)
                results.to_pickle('numerical_test.pkl')
                np.save('numerical_test.npy', results)
else:
    results = pd.read_pickle('numerical_test.pkl')


results = results.reset_index(drop=True)
present_type = np.unique(results.type)
# error[type] = list of mean squared differences between the expected
# ``beat_corr`` and that type's ``result_frequency``, one value per row of
# ``results`` with that type (in row order). 'samped threshold' is excluded.
error = {}
start = 5
end = 3500
for type_name in present_type:
    if type_name == 'samped threshold':
        continue
    result = results[np.array(results['type'] == type_name)]
    error[type_name] = []
    for r in range(len(result)):
        row = result.iloc[r]
        step = row['step']
        eod_fr = row['eod_fr']
        eod_fe, beat_corr, beats = find_beats(start, end, step, eod_fr)
        error[type_name].append(np.mean((beat_corr - row['result_frequency']) ** 2))

embed()
# Mean error of each evaluated type over all its parameter combinations.
# ``error`` only contains the evaluated types ('samped threshold' is skipped
# when it is built), so iterate over it directly. Indexing ``present_type``
# with ``range(len(error))`` went out of step after the skipped type: it
# dropped the last types and left ``[]`` in maximum/minimum, which made
# nanmin/nanmax fail.
err = {type_name: np.mean(type_errors) for type_name, type_errors in error.items()}
plt.bar(np.array(list(err.keys())), np.array(list(err.values())))
# '..bars' (missing slash) wrote a file named '..bars.png' into the working
# directory; write it one level up like the other outputs of this script.
plt.savefig('../bars')

# Shared colour scale for all heat maps.
maximum = [np.max(type_errors) for type_errors in error.values()]
minimum = [np.min(type_errors) for type_errors in error.values()]
plt.figure()  # separate figure so the heat maps do not cover the bar chart
n_rows = max(3, math.ceil(len(error) / 4))  # 3x4 grid, grown if more than 12 types
for e, (type_name, type_errors) in enumerate(error.items()):
    plt.subplot(n_rows, 4, e + 1)
    plt.title(type_name)
    type_errors = np.array(type_errors)
    # NOTE(review): requires len(type_errors) to be divisible by
    # int(sqrt(len(type_errors))); otherwise reshape raises ValueError.
    plt.imshow(type_errors.reshape(-1, int(np.sqrt(len(type_errors)))),
               vmin=np.nanmin(minimum), vmax=np.nanmax(maximum))
plt.colorbar()
plt.show()
embed()
# Export one figure (PDF + PNG) per (eod_fr, sampling, step, amplitude)
# combination that is present in ``results``.
eod_fr = [500, 639, 734, 820, 952, 1000, 1492]
x = [1.5]  # other values tried: 2.5, 0.5
sampling = [83425, 98232, 100000, 112683]
start = 5
end = 3500
step = [30, 25, 10]
d = 1
delta_t = 1
size = [120]
a_fr = 1
a_fe = [0.5, 0.3, 0.2]
out_dir = '../highbeats_pdf/numerical_test/'
for ee in range(len(eod_fr)):
    for s in range(len(sampling)):
        deviation_ms, deviation_s, deviation_dp = find_dev(x, sampling[s])
        for ss in range(len(step)):
            eod_fe, beat_corr, beats = find_beats(start, end, step[ss], eod_fr[ee])
            for a in range(len(a_fe)):
                selection = ((results['amplitude'] == a_fe[a])
                             & (results['sampling'] == sampling[s])
                             & (results['step'] == step[ss])
                             & (results['eod_fr'] == eod_fr[ee]))
                res = results[np.array(selection)].reset_index(drop=True)
                if res.empty:
                    # Combination not in the loaded results; plot_power needs >= 1 row.
                    continue
                file_stem = ('sampling' + str(sampling[s]) + 'step' + str(step[ss])
                             + 'eod_fr' + str(eod_fr[ee]) + 'amplitude' + str(a_fe[a]))
                try:
                    plot_power(beats, res, 'w2', deviation_s, sigma, sampling[s], deviation_ms, beat_corr,
                               size, [phase_zero[0]], delta_t, a_fr, a_fe[a], eod_fr[ee], eod_fe,
                               deviation_dp, show_figure=True, plot_dist=False, save=True)
                    plt.suptitle('sampling ' + str(sampling[s]) + 'step ' + str(step[ss])
                                 + 'eod_fr ' + str(eod_fr[ee]) + 'amplitude ' + str(a_fe[a]))
                    plt.savefig(out_dir + file_stem + '.pdf')
                    plt.savefig(out_dir + file_stem + '.png')
                except Exception as exc:
                    # Best effort: report the failing combination and continue.
                    print('could not export ' + file_stem + ': ' + repr(exc))
                finally:
                    # Without this every iteration leaves an open figure behind.
                    plt.close('all')





#results = pd.read_pickle('numerical_test.pkl')
# embed()