import matplotlib.pyplot as plt
import numpy as np
from filters import bandpass_filter, inst_freq, instantaneous_frequency
from fish_signal import chirps, wavefish_eods
from IPython import embed


def switch_test(test, defaultparams, testparams):
    """Swap one default parameter array for its test sweep.

    The short test name ("width", "size", "kurtosis" or "contrast") selects
    the matching ``chirp_*`` key. The array stored under that key in
    ``testparams`` is written into ``defaultparams``. Note that this mutates
    ``defaultparams`` in place, and the very same dict object is returned.

    Args:
        test: Short name of the parameter to sweep.
        defaultparams: Dict of default parameter arrays; modified in place.
        testparams: Dict of sweep arrays keyed like ``defaultparams``.

    Returns:
        Tuple ``(key, defaultparams)`` with the ``chirp_*`` key that was
        replaced and the (mutated) ``defaultparams`` dict.

    Raises:
        ValueError: if ``test`` is not one of the four known names.
    """
    name_to_key = (
        ("width", "chirp_width"),
        ("size", "chirp_size"),
        ("kurtosis", "chirp_kurtosis"),
        ("contrast", "chirp_contrast"),
    )
    for name, param_key in name_to_key:
        if test == name:
            defaultparams[param_key] = testparams[param_key]
            return param_key, defaultparams
    raise ValueError("Test not recognized")


def extract_dict(dict, index):
    """Return a new dict that maps each key to ``value[index]``.

    The parameter is named ``dict``, which shadows the builtin. The name is
    kept so that keyword callers keep working.
    """
    picked = {}
    for name, values in dict.items():
        picked[name] = values[index]
    return picked


def test(test1, test2, resolution=10):
    """Plot a grid of instantaneous-frequency traces for a 2-D chirp sweep.

    Two chirp parameters are swept over ``resolution`` values each:
    ``test1`` runs along the subplot rows and ``test2`` along the subplot
    columns. All other parameters stay at their defaults. For every
    combination, a wave-fish EOD signal (``fish="Alepto"``) with three chirps
    is simulated. It is band-pass filtered once with a wide and once with a
    tight band around the EOD frequency, and the instantaneous frequency of
    both filtered signals is drawn into the matching subplot. The figure is
    displayed with ``plt.show()``.

    Args:
        test1: Parameter varied along the rows; one of "width", "size",
            "kurtosis", "contrast".
        test2: Parameter varied along the columns; same choices as ``test1``.
        resolution: Number of sweep values per parameter, so the figure has
            ``resolution x resolution`` subplots.

    Raises:
        AssertionError: if ``test1`` or ``test2`` is not a recognized name.
    """
    valid_tests = ["width", "size", "kurtosis", "contrast"]
    assert test1 in valid_tests, "Test1 not recognized"
    assert test2 in valid_tests, "Test2 not recognized"

    # Define the parameters for the chirp simulations
    ntest = resolution

    defaultparams = dict(
        chirp_size=np.ones(ntest) * 100,
        chirp_width=np.ones(ntest) * 0.1,
        chirp_kurtosis=np.ones(ntest) * 1.0,
        chirp_contrast=np.ones(ntest) * 0.5,
    )

    testparams = dict(
        chirp_width=np.linspace(0.01, 0.2, ntest),
        chirp_size=np.linspace(50, 300, ntest),
        chirp_kurtosis=np.linspace(0.5, 1.5, ntest),
        chirp_contrast=np.linspace(0.01, 1.0, ntest),
    )

    # Replace the defaults of the two swept parameters with their test ranges.
    # switch_test mutates defaultparams in place and returns it.
    key1, chirp_params = switch_test(test1, defaultparams, testparams)
    key2, chirp_params = switch_test(test2, chirp_params, testparams)

    # Simulation settings for the chirp trace
    eodf = 500
    samplerate = 20000
    duration = 2
    chirp_times = [0.5, 1, 1.5]
    nchirps = len(chirp_times)

    # Half-widths of the pass bands, placed symmetrically around eodf
    wide_cutoffs = 200
    tight_cutoffs = 10

    # squeeze=False keeps axs a 2-D array even when resolution == 1. Otherwise
    # plt.subplots returns a bare Axes, which has no .flatten().
    fig, axs = plt.subplots(
        ntest,
        ntest,
        figsize=(10, 10),
        sharex=True,
        sharey=True,
        squeeze=False,
    )
    # The figure-level axis labels do not depend on the loop; set them once.
    fig.supylabel(key1)
    fig.supxlabel(key2)

    for iter1, test1_param in enumerate(chirp_params[key1]):
        for iter2, test2_param in enumerate(chirp_params[key2]):
            # Parameters for this grid cell: the defaults, with the two
            # swept parameters overridden by their current values.
            inner_chirp_params = extract_dict(chirp_params, iter2)
            inner_chirp_params[key1] = test1_param
            inner_chirp_params[key2] = test2_param

            # All chirps in the trace share the same parameters.
            sizes = np.full(nchirps, inner_chirp_params["chirp_size"])
            widths = np.full(nchirps, inner_chirp_params["chirp_width"])
            kurtosis = np.full(nchirps, inner_chirp_params["chirp_kurtosis"])
            contrast = np.full(nchirps, inner_chirp_params["chirp_contrast"])

            # Frequency trace and amplitude modulation for the chirping EOD
            chirp_trace, ampmod = chirps(
                eodf,
                samplerate,
                duration,
                chirp_times,
                sizes,
                widths,
                kurtosis,
                contrast,
            )
            signal = wavefish_eods(
                fish="Alepto",
                frequency=chirp_trace,
                samplerate=samplerate,
                duration=duration,
                phase0=0.0,
                noise_std=0.05,
            )
            signal = signal * ampmod

            # Band-pass filter with a wide and a tight band around eodf
            wide_signal = bandpass_filter(
                signal, samplerate, eodf - wide_cutoffs, eodf + wide_cutoffs
            )
            tight_signal = bandpass_filter(
                signal, samplerate, eodf - tight_cutoffs, eodf + tight_cutoffs
            )

            # Instantaneous frequency of both filtered signals
            wide_frequency = inst_freq(wide_signal, samplerate)
            tight_frequency = inst_freq(tight_signal, samplerate)

            # Plot only the samples where the wide-band estimate is nonzero.
            # NOTE(review): presumably inst_freq marks undefined samples with
            # 0 and returns equal-length arrays for both signals — confirm.
            bool_mask = wide_frequency != 0
            ax = axs[iter1, iter2]
            ax.plot(wide_frequency[bool_mask])
            ax.plot(tight_frequency[bool_mask])

    plt.show()

def main():
    """Run the chirp sweep that varies contrast (rows) against kurtosis (columns)."""
    test(test1="contrast", test2="kurtosis")


if __name__ == "__main__":
    main()