import sys
import os
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt
from thunderfish.dataloader import open_data
from thunderfish.eodanalysis import eod_waveform
from IPython import embed
import matplotlib.gridspec as gridspec
from params import *


def unfilter(data, samplerate, cutoff):
    r"""
    Apply inverse high-pass filter on data.

    Assumes high-pass filter \[ \tau \dot y = -y + \tau \dot x \] has
    been applied on the original data \(x\), where \(\tau=(2\pi
    f_{cutoff})^{-1}\) is the time constant of the filter. To recover \(x\)
    the ODE \[ \tau \dot x = y + \tau \dot y \] is applied on the
    filtered data \(y\).

    After removing the mean of \(y\), the ODE is integrated with one Euler
    step per sample,
    \( x_k = x_{k-1} + (y_k - y_{k-1}) + y_{k-1} / (\tau f_s) \),
    starting from \(x_{-1} = y_{-1} = y_0\). This sums up to the closed form
    \( x_k = y_k + \frac{1}{\tau f_s}\left(y_0 + \sum_{j=0}^{k-1} y_j\right) \),
    which is evaluated here with a cumulative sum.

    Parameters:
    -----------
    data: ndarray
        High-pass filtered original data (1-D, floating point).
    samplerate: float
        Sampling rate \(f_s\) of `data` in Hertz.
    cutoff: float
        Cutoff frequency \(f_{cutoff}\) of the high-pass filter in Hertz.

    Returns:
    --------
    data: ndarray
        Recovered original data. This is the same array object as the input
        `data`, which is modified in place, so `data` must have a
        floating-point dtype. An empty array is returned unchanged.
    """
    tau = 0.5 / np.pi / cutoff
    fac = tau * samplerate
    if len(data) == 0:
        return data
    data -= np.mean(data)
    # Sample "before" each sample: y_0 for k == 0, y_{k-1} otherwise.
    previous = np.concatenate((data[:1], data[:-1]))
    # Accumulate in double precision; the in-place add casts back to data's dtype.
    data += np.cumsum(previous, dtype=float) / fac
    return data


def calc_mean_eod(t0, f, data, dt=10, unfilter=0):
    """
    Average the EOD waveform over a time window on every channel and keep the largest one.

    Parameters:
    -----------
    t0: float
        Start of the analysis window in seconds (converted to a sample index
        via `data.samplerate`).
    f: float
        EOD frequency, presumably in Hertz; EOD times are assumed to be
        regularly spaced by `1/f` starting at the window start.
    data: data loader
        Multi-channel recording providing `.channels`, `.samplerate` and
        `[samples, channel]` slicing.
    dt: float
        Duration of the analysis window in seconds.
    unfilter: float
        Passed to `eod_waveform` as `unfilter_cutoff`.

    Returns:
    --------
    EOD: ndarray
        Mean waveform returned by `eod_waveform` for the channel with the
        largest peak-to-peak amplitude of column 1. Column 0 is used as time
        by callers; columns 1 and 2 are presumably mean and standard
        deviation (TODO confirm against thunderfish).
    samplerate: float
        Sampling rate of `data`.
    """
    samplerate = data.samplerate

    # Slice bounds must be integers; t0 and dt are in seconds and usually
    # map to fractional sample positions.
    start_i = int(t0 * samplerate)
    end_i = int(t0 * samplerate + dt * samplerate + 1)
    # EOD times (relative to the window start) assuming a constant frequency f.
    eod_times = np.arange(0, dt, 1 / f)

    mean_eods = []
    for channel in range(data.channels):
        mean_eod, _ = eod_waveform(data[start_i:end_i, channel], samplerate, eod_times,
                                   unfilter_cutoff=unfilter)
        mean_eods.append(mean_eod)

    # Pick the channel whose mean waveform has the largest peak-to-peak amplitude.
    amplitudes = [np.ptp(eod[:, 1]) for eod in mean_eods]
    EOD = mean_eods[int(np.argmax(amplitudes))]

    return EOD, samplerate


def main(folder, filename):
    """
    Plot the mean EOD waveforms of two example fish side by side and save the figure.

    Parameters:
    -----------
    folder: str
        Directory containing the raw recording `traces-grid1.raw`.
    filename: str
        Recording name; the pre-computed `.npy` results are read from
        `../data/<filename>/`.

    Notes:
    ------
    `inch`, `color_efm` and `save_path` are not defined in this function;
    presumably they come from `from params import *` (TODO confirm).
    """
    # open_data(filepath, channel, buffersize, backsize): -1 presumably selects
    # all channels (calc_mean_eod iterates over data.channels).
    data = open_data(os.path.join(folder, 'traces-grid1.raw'), -1, 60.0, 10.0)

    data_dir = os.path.join('..', 'data', filename)
    all_q10 = np.load(os.path.join(data_dir, 'fish_freq_q10.npy'), allow_pickle=True)
    all_t = np.load(os.path.join(data_dir, 'eod_times_new_new.npy'), allow_pickle=True)
    all_f = np.load(os.path.join(data_dir, 'eod_freq_new_new.npy'), allow_pickle=True)

    fish_numbers = [15, 22]   # fish (first index of all_t / all_f / all_q10) shown in panels A, B
    plot_panel = [16, 0]      # per fish: which entry of all_t[fish] / all_f[fish] to use
    cutoff_value = [200, 0]   # per fish: unfilter cutoff passed on to calc_mean_eod

    ##################################################################################################################
    # figure
    fig = plt.figure(constrained_layout=True, figsize=[15 / inch, 6 / inch])
    gs = gridspec.GridSpec(ncols=2, nrows=1, figure=fig, hspace=0.05, wspace=0.0,
                           left=0.1, bottom=0.15, right=0.95, top=0.98)

    ax2 = fig.add_subplot(gs[0, 1])
    ax1 = fig.add_subplot(gs[0, 0], sharey=ax2)

    for fn_idx, (fish_number, ax) in enumerate(zip(fish_numbers, [ax1, ax2])):
        print(all_q10[fish_number, 2], fish_number)

        t = all_t[fish_number][plot_panel[fn_idx]]
        f = all_f[fish_number][plot_panel[fn_idx]]
        EOD, _ = calc_mean_eod(t, f, data, unfilter=cutoff_value[fn_idx])

        ##############################################################################################################
        # plot: column 1 drawn as a line, band of +/- column 2 around it
        eod_time, eod_mean, eod_spread = EOD.T[0], EOD.T[1], EOD.T[2]
        ax.plot(eod_time, eod_mean, color=color_efm[fn_idx], lw=2)
        ax.fill_between(eod_time, eod_mean + eod_spread, eod_mean - eod_spread,
                        color=color_efm[fn_idx], alpha=0.7)
        # NOTE(review): make_nice_ax is not a stock matplotlib Axes method;
        # presumably patched in by the params module — verify.
        ax.make_nice_ax()

        # Panel letter (A, B, ...) and the fish's rounded all_q10[fish, 2] value.
        ax.text(-0.12, 0.95, chr(ord('A') + fn_idx), transform=ax.transAxes, fontsize='large')
        ax.text(0.8, 0.95, str(np.round(all_q10[fish_number, 2], 1)) + ' Hz',
                transform=ax.transAxes, fontsize=10)

        ax.set_xlabel('Time')
        ax.set_yticks([0])
        ax.set_xticks([])

    ax1.set_ylabel('Amplitude')
    fig.savefig(save_path + 'eod_waves.pdf')

    plt.show()


if __name__ == '__main__':
    base_dir = '../../../data/mount_data/sanmartin/softgrid_1x16/'
    # List the recordings once; they are selected by position in sorted order.
    recordings = sorted(os.listdir(base_dir))
    for filename_idx in [2]:
        filename = recordings[filename_idx]
        folder = os.path.join(base_dir, filename)
        print('new file: ' + filename)
        main(folder, filename)