from stimuli.SinusAmplitudeModulation import SinusAmplitudeModulationStimulus as SAM
from models.LIFACnoise import LifacNoiseModel
import numpy as np
import matplotlib.pyplot as plt
import helperFunctions as hF
from CellData import CellData
from ModelFit import get_best_fit
import os


def main():
    run_sam_analysis_for_all_cells("results/final_2")

    # sam_analysis("results/final_2/2011-10-25-ad-invivo-1/")

    # plot_traces_with_spiketimes()
    # plot_mean_of_cuts()

    # the code below is currently disabled by quit(); kept for quick manual testing
    quit()
    modelfit = get_best_fit("results/final_2/2011-10-25-ad-invivo-1/")
    cell_data = CellData(modelfit.get_cell_path())

    eod_freq = cell_data.get_eod_frequency()
    model = modelfit.get_model()

    test_model_response(model, eod_freq, 0.1, np.arange(5, 2500, 5))


def run_sam_analysis_for_all_cells(folder):
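    """Run the SAM analysis for every cell folder in 'folder' whose recordings include
    SAM stimuli and print how many cells were analysed."""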
    count = 0
    for item in os.listdir(folder):
        cell_folder = os.path.join(folder, item)
        if not os.path.isdir(cell_folder):
            continue
        fit = get_best_fit(cell_folder, use_comparable_error=False)
        cell_data = fit.get_cell_data()

        if cell_data.has_sam_recordings():
            count += 1
            # print("Fit quality:", fit.get_fit_routine_error())
            sam_analysis(cell_folder)
    print("Analysed {} cells with SAM recordings.".format(count))



def test_model_response(model: LifacNoiseModel, eod_freq, contrast, modulation_frequencies):
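    """Simulate the model response to SAM stimuli for a range of modulation frequencies,
    save a firing-rate (PDF) plot per frequency and finally plot the standard deviation of the
    rate against the modulation frequency relative to the EOD frequency."""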

    stds = []

    for m_freq in modulation_frequencies:
        # make sure the integration step size resolves at least ten samples per modulation period
        if (1 / m_freq) / 10 <= model.parameters["step_size"]:
            model.parameters["step_size"] = (1 / m_freq) / 10
        step_size = model.parameters["step_size"]
        print("mod_freq:", m_freq, "- step size:", step_size)
        stimulus = SAM(eod_freq, contrast / 100, m_freq)
        duration = 30
        v1, spikes_model = model.simulate(stimulus, duration)
        prob_density_function_model = spiketimes_calculate_pdf(spikes_model, step_size, kernel_width=0.005)

        fig, ax = plt.subplots(1, 1)
        ax.plot(prob_density_function_model)
        ax.set_title("pdf with m_freq: {}".format(int(m_freq)))

        plt.savefig("figures/sam/pdf_mfreq_{}.png".format(m_freq))
        plt.close()
        stds.append(np.std(prob_density_function_model))

    plt.plot((np.array(modulation_frequencies)) / eod_freq, stds)
    plt.show()
    plt.close()


def plot_traces_with_spiketimes():
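    """Plot the recorded voltage trace of every SAM trial of one example cell together with the
    detected spike times, plus the corresponding local EOD trace in a second panel."""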
    modelfit = get_best_fit("results/final_2/2011-10-25-ad-invivo-1/")
    cell_data = modelfit.get_cell_data()

    traces = cell_data.parser.__get_traces__("SAM")
    # [time_traces, v1_traces, eod_traces, local_eod_traces, stimulus_traces]
    sam_spiketimes = cell_data.get_sam_spiketimes()
    for i in range(len(traces[0])):
        fig, axes = plt.subplots(2, 1, sharex=True)
        axes[0].plot(traces[0][i], traces[1][i])
        axes[0].plot(sam_spiketimes[i], [max(traces[1][i])] * len(sam_spiketimes[i]), 'o')
        axes[1].plot(traces[0][i], traces[3][i])

        plt.show()
        plt.close()


def plot_mean_of_cuts():
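    """For one example cell, compare the period-averaged firing rate of the model and the cell
    for each recorded modulation frequency and plot both together with their difference."""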
    modelfit = get_best_fit("results/final_2/2018-05-08-ac-invivo-1/")

    if not os.path.exists(os.path.join(modelfit.get_cell_path(), "samallspikes1.dat")):
        print("Cell: {} \n Has no measured sam stimuli.")
        return
    cell_data = CellData(modelfit.get_cell_path())

    eod_freq = cell_data.get_eod_frequency()
    model = modelfit.get_model()

    durations = cell_data.get_sam_durations()
    u_durations = np.unique(durations)
    mean_duration = np.mean(durations)
    contrasts = cell_data.get_sam_contrasts()
    u_contrasts = np.unique(contrasts)
    contrast = contrasts[0]  # are all the same in this test case
    spiketimes = cell_data.get_sam_spiketimes()
    delta_freqs = cell_data.get_sam_delta_frequencies()
    step_size = cell_data.get_sampling_interval()

    # group recorded spiketime trials by their modulation (delta) frequency
    spikes_dictionary = {}
    for i, m_freq in enumerate(delta_freqs):
        if m_freq in spikes_dictionary:
            spikes_dictionary[m_freq].append(spiketimes[i])
        else:
            spikes_dictionary[m_freq] = [spiketimes[i]]

    for m_freq in sorted(spikes_dictionary.keys()):
        if mean_duration < 2 * (1 / float(m_freq)):
            print("skipped mod freq {}: mean duration shorter than two modulation periods".format(m_freq))
            continue
        stimulus = SAM(eod_freq, contrast / 100, m_freq)
        v1, spikes_model = model.simulate(stimulus, 4)
        prob_density_function_model = spiketimes_calculate_pdf(spikes_model, step_size)

        fig, axes = plt.subplots(1, 4)
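        # discard the first 2 s of the 4 s simulation before cutting the rate into periods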
        start_idx = int(2 / step_size)
        cuts = cut_pdf_into_periods(prob_density_function_model[start_idx:], 1 / float(m_freq), step_size)
        for c in cuts:
            axes[0].plot(c, color="gray", alpha=0.2)
        axes[0].set_title("model")
        mean_model = np.mean(cuts, axis=0)
        axes[0].plot(mean_model, color="black")

        means_cell = []
        for spikes_cell in spikes_dictionary[m_freq]:
            prob_density_cell = spiketimes_calculate_pdf(spikes_cell[0], step_size)

            cuts_cell = cut_pdf_into_periods(prob_density_cell, 1 / float(m_freq), step_size)
            if len(cuts_cell) == 0:
                continue
            for c in cuts_cell:
                axes[1].plot(c, color="gray", alpha=0.15)
            print(np.shape(cuts_cell))
            means_cell.append(np.mean(cuts_cell, axis=0))
        if len(means_cell) == 0:
            print("means cell length zero")
            continue
        means_cell = np.array(means_cell)
        total_mean_cell = np.mean(means_cell, axis=0)
        axes[1].set_title("cell")
        axes[1].plot(total_mean_cell, color="black")

        axes[2].set_title("difference")
        diff = [(total_mean_cell[i] - mean_model[i]) for i in range(len(total_mean_cell))]
        axes[2].plot(diff)

        axes[3].plot(total_mean_cell)
        axes[3].plot(mean_model)

        plt.show()
        plt.close()


def sam_analysis(fit_path):
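    """Compare cell and model responses to SAM stimuli for one fitted cell: for every modulation
    frequency, compute the period-averaged firing rates of cell and model and plot the standard
    deviation of both (response modulation depth) against the modulation frequency.
    The figure is saved to figures/sam/."""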
    modelfit = get_best_fit(fit_path)

    if not os.path.exists(os.path.join(modelfit.get_cell_path(), "samallspikes1.dat")):
        print("Cell: {} \n Has no measured sam stimuli.")
        return
    cell_data = CellData(modelfit.get_cell_path())
    model = modelfit.get_model()

    # parameters = {'delta_a': 0.08820130374685671, 'refractory_period': 0.0006, 'a_zero': 15, 'step_size': 5e-05,
    #               'v_base': 0, 'noise_strength': 0.03622523883042496, 'v_zero': 0, 'threshold': 1,
    #               'input_scaling': 77.75367190909581, 'tau_a': 0.07623731247799118, 'v_offset': -10.546875,
    #               'mem_tau': 0.008741976196676469, 'dend_tau': 0.0012058986118892773}
    # model = LifacNoiseModel(parameters)
    # cell_data = CellData("./data/test_data/2012-12-13-an-invivo-1/")

    eod_freq = cell_data.get_eod_frequency()
    step_size = cell_data.get_sampling_interval()

    durations = cell_data.get_sam_durations()
    u_durations = np.unique(durations)
    contrasts = cell_data.get_sam_contrasts()
    u_contrasts = np.unique(contrasts)
    spiketimes = cell_data.get_sam_spiketimes()
    delta_freqs = cell_data.get_sam_delta_frequencies()
    u_delta_freqs = np.unique(delta_freqs)

    cell_stds = []
    model_stds = []
    analysed_mod_freqs = []

    approx_offset = approximate_axon_delay_in_idx(cell_data, model)
    print("Approx offset idx:", approx_offset)
    print("Approx offset ms:", (approx_offset * step_size) * 1000)

    for mod_freq in sorted(u_delta_freqs):
        # TODO problem of cutting the pdf as in some cases the pdf is shorter than 1 modulation frequency period!
        #  length info wrong ? always at least one period?

        # if 1/mod_freq > durations[0] / 4:
        #     print("skipped mod_freq: {}".format(mod_freq))
        #     print("Duration: {} while mod_freq period: {:.2f}".format(durations[0], 1/mod_freq))
        #     print("Maybe long enough duration? unique durations:", u_durations)
        #     continue
        mfreq_data = {}
        cell_means = []
        model_means = []
        for c in u_contrasts:
            mfreq_data[c] = []

        for i in range(len(delta_freqs)):
            if delta_freqs[i] != mod_freq:
                continue
            if len(spiketimes[i]) == 0:
                print("No spiketimes found at index!")
                continue
            if len(spiketimes[i]) > 1:
                print("There are more spiketimes in one 'point'! Only the first was used!")

            spikes = spiketimes[i][0]

            cell_pdf = spiketimes_calculate_pdf(spikes, step_size)

            cell_cuts = cut_pdf_into_periods(cell_pdf, 1/mod_freq, step_size)
            if len(cell_cuts) == 0:
                continue
            cell_mean = np.mean(cell_cuts, axis=0)
            cell_means.append(cell_mean)

            stimulus = SAM(eod_freq, contrasts[i] / 100, mod_freq)
            v1, spikes_model = model.simulate(stimulus, 10)
            model_pdf = spiketimes_calculate_pdf(spikes_model, step_size)
            model_cuts = cut_pdf_into_periods(model_pdf, 1/mod_freq, step_size)
            model_mean = np.mean(model_cuts, axis=0)
            model_means.append(model_mean)

        if len(cell_means) == 0 or len(model_means) == 0:
            print("skipped mod freq {}: no usable cell or model responses".format(mod_freq))
            continue
        min_length = min(min([len(cm) for cm in cell_means]), min([len(mm) for mm in model_means]))
        for i in range(len(cell_means)):
            cell_means[i] = cell_means[i][:min_length]
            model_means[i] = model_means[i][:min_length]
        final_cell_mean = np.mean(cell_means, axis=0)
        final_model_mean = np.mean(model_means, axis=0)
        analysed_mod_freqs.append(mod_freq)
        cell_stds.append(np.std(final_cell_mean))
        model_stds.append(np.std(final_model_mean))
        # offset, final_model_mean_phase_corrected = correct_phase(final_cell_mean, final_model_mean, step_size)
        # print("Offset:", offset)
        # print("modfreq:", mod_freq)

        final_model_mean_phase_corrected = np.roll(final_model_mean, approx_offset)

        # PLOT EVERY MOD FREQ
        # fig, axes = plt.subplots(1, 5, figsize=(15, 5), sharex=True)
        # for c in cell_means:
        #     axes[0].plot(c, color="grey", alpha=0.2)
        # axes[0].plot(np.mean(cell_means, axis=0), color="black")
        # axes[0].set_title("Cell response")
        # axis_cell = axes[0].axis()
        #
        # for m in model_means:
        #     axes[1].plot(m, color="grey", alpha=0.2)
        # axes[1].plot(np.mean(model_means, axis=0), color="black")
        # axes[1].set_title("Model response")
        # axis_model = axes[1].axis()
        # ylim_top = max(axis_cell[3], axis_model[3])
        # axes[1].set_ylim(0, ylim_top)
        # axes[0].set_ylim(0, ylim_top)
        # axes[2].set_ylim(0, ylim_top)
        #
        # axes[2].plot(final_cell_mean, label="cell")
        # axes[2].plot(final_model_mean, label="model")
        # axes[2].plot(final_model_mean_phase_corrected, label="model p-cor")
        # axes[2].legend()
        # axes[2].set_title("cell-model overlapped")
        # axes[3].plot((final_model_mean - final_cell_mean) / final_cell_mean, label="normal")
        # axes[3].plot((final_model_mean_phase_corrected- final_cell_mean) / final_cell_mean, label="phase cor")
        # axes[3].set_title("rel. error")
        # axes[3].legend()
        # axes[4].plot(final_model_mean - final_cell_mean, label="normal")
        # axes[4].plot(final_model_mean_phase_corrected - final_cell_mean, label="phase cor")
        # axes[4].set_title("abs. error (Hz)")
        # axes[4].legend()
        #
        # fig.suptitle("modulation frequency: {}".format(mod_freq))
        #
        # # plt.tight_layout()
        # # plt.show()
        # plt.close()

    fig, ax = plt.subplots(1, 1)

    ax.plot(analysed_mod_freqs, cell_stds, label="cell stds")
    ax.plot(analysed_mod_freqs, model_stds, label="model stds")
    ax.set_title("response modulation depth")
    ax.set_xlabel("Modulation frequency")
    ax.set_ylabel("STD")
    ax.legend()
    plt.savefig("figures/sam/" + cell_data.get_cell_name() + ".png")
    # plt.show()
    plt.close()


def correct_phase(cell_mean, model_mean, step_size):
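    """Find the roll (in samples) of model_mean that minimizes the summed absolute difference
    to cell_mean, testing shifts in 0.2 ms steps over the full trace length.
    Returns the best roll index and the accordingly rolled model_mean."""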

    # test for every 0.2 ms roll in the total time:
    lowest_err = np.inf
    roll_idx = 0
    for i in range(int(len(cell_mean) * step_size * 1000) * 5):
        roll_by = int((i / 5 / 1000) / step_size)
        rolled = np.roll(model_mean, roll_by)
        # rms = np.sqrt(np.mean(np.power((cell_mean - rolled), 2)))
        abs_err = np.sum(np.abs(cell_mean - rolled))
        if abs_err < lowest_err:
            lowest_err = abs_err
            roll_idx = roll_by

    return roll_idx, np.roll(model_mean, roll_idx)


def approximate_axon_delay_in_idx(cell_data, model):
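    """Estimate the delay (in samples at the cell's sampling interval) between the measured cell
    response and the simulated model response by phase-correcting the period-averaged responses
    for 'stable' modulation frequencies (80-150 Hz) and averaging the resulting offsets.
    Returns 0 if no suitable modulation frequencies were found."""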

    lowest_mod_freq = 80
    highest_mod_freq = 150

    eod_freq = cell_data.get_eod_frequency()
    step_size = cell_data.get_sampling_interval()

    durations = cell_data.get_sam_durations()
    contrasts = cell_data.get_sam_contrasts()
    u_contrasts = np.unique(contrasts)
    spiketimes = cell_data.get_sam_spiketimes()
    delta_freqs = cell_data.get_sam_delta_frequencies()
    u_delta_freqs = np.unique(delta_freqs)

    used_mod_freqs = []
    axon_delays = []

    for mod_freq in sorted(u_delta_freqs):
        # Only use "stable" mod_freqs to approximate the axon delay
        if not lowest_mod_freq <= mod_freq <= highest_mod_freq:
            continue

        if 1/mod_freq > durations[0] / 4:
            print("skipped mod_freq: {}".format(mod_freq))
            print("Duration: {} while mod_freq period: {:.2f}".format(durations[0], 1/mod_freq))
            continue
        mfreq_data = {}
        cell_means = []
        model_means = []
        for c in u_contrasts:
            mfreq_data[c] = []

        for i in range(len(delta_freqs)):
            if delta_freqs[i] != mod_freq:
                continue

            if len(spiketimes[i]) == 0:
                print("No spiketimes found at index!")
                continue
            if len(spiketimes[i]) > 1:
                print("There are more spiketimes in one 'point'! Only the first was used!")
            spikes = spiketimes[i][0]

            cell_pdf = spiketimes_calculate_pdf(spikes, step_size)

            cell_cuts = cut_pdf_into_periods(cell_pdf, 1/mod_freq, step_size)
            if len(cell_cuts) == 0:
                continue
            cell_mean = np.mean(cell_cuts, axis=0)
            cell_means.append(cell_mean)

            stimulus = SAM(eod_freq, contrasts[i] / 100, mod_freq)
            v1, spikes_model = model.simulate(stimulus, durations[i] * 4)
            model_pdf = spiketimes_calculate_pdf(spikes_model, step_size)
            model_cuts = cut_pdf_into_periods(model_pdf, 1/mod_freq, step_size)
            model_mean = np.mean(model_cuts, axis=0)
            model_means.append(model_mean)

        if len(cell_means) == 0 or len(model_means) == 0:
            continue
        final_cell_mean = np.mean(cell_means, axis=0)
        final_model_mean = np.mean(model_means, axis=0)

        offset, final_model_mean_phase_corrected = correct_phase(final_cell_mean, final_model_mean, step_size)

        used_mod_freqs.append(mod_freq)
        axon_delays.append(offset)

    mean_delay = np.mean(axon_delays)
    if np.isnan(mean_delay):
        return 0
    else:
        return int(round(mean_delay))


def generate_pdf(model, stimulus, trials=4, sim_length=3, kernel_width=0.005):
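    """Simulate the model 'trials' times for the given stimulus and return the trial-averaged
    firing-rate estimate, obtained by convolving the binary spike trains with a Gaussian kernel."""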

    trials_rate_list = []
    step_size = model.get_parameters()["step_size"]
    for _ in range(trials):
        v1, spikes = model.simulate_slow(stimulus, total_time_s=sim_length)

        binary = np.zeros(int(sim_length / step_size))
        spike_indices = [int(s / step_size) for s in spikes]
        for s_idx in spike_indices:
            if s_idx < len(binary):
                binary[s_idx] = 1

        kernel = gaussian_kernel(kernel_width, step_size)
        rate = np.convolve(binary, kernel, mode='same')
        trials_rate_list.append(rate)

    times = [np.arange(0, sim_length, step_size) for _ in range(trials)]
    t, mean_rate = hF.calculate_mean_of_frequency_traces(times, trials_rate_list, step_size)

    return mean_rate


def spiketimes_calculate_pdf(spikes, step_size, kernel_width=0.001):
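    """Convolve a spike train (spike times in seconds) with a Gaussian kernel to obtain a
    firing-rate estimate sampled at step_size."""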
    if len(spikes) == 0:
        return np.zeros(0)
    length = int(spikes[-1] / step_size) + 1
    binary = np.zeros(length)
    spikes = [int(s / step_size) for s in spikes]
    for s_idx in spikes:
        binary[s_idx] = 1

    kernel = gaussian_kernel(kernel_width, step_size)
    rate = np.convolve(binary, kernel, mode='same')

    return rate


def cut_pdf_into_periods(pdf, period, step_size, factor=0.0):
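    """Cut the firing-rate trace into consecutive windows of one stimulus period (optionally
    lengthened by 'factor'), correcting for the fractional number of samples per period.
    The first period is skipped; returns an empty list if the trace is shorter than ~one period
    and [pdf] if it does not contain at least one full cut."""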

    if period < 0:
        # print("cut_pdf_into_periods(): Period was negative! Absolute value taken to continue")
        period = abs(period)

    idx_period_length = int(period / float(step_size))
    offset_per_step = period / float(step_size) - idx_period_length
    cut_length = idx_period_length + int(factor * idx_period_length)
    num_of_cuts = int(len(pdf) / (idx_period_length + offset_per_step))

    if len(pdf) - (num_of_cuts * idx_period_length + (num_of_cuts * offset_per_step)) < cut_length - idx_period_length:
        num_of_cuts -= 1

    if idx_period_length * 0.9 > len(pdf):
        return []
        # raise RuntimeError("SAM stimulus is too short for the given mod freq period.")

    if cut_length > len(pdf) or num_of_cuts < 1:
        return [pdf]

    cuts = np.zeros((num_of_cuts-1, cut_length))
    for i in np.arange(1, num_of_cuts, 1):
        offset_correction = int(offset_per_step * i)
        start_idx = i*idx_period_length + offset_correction
        end_idx = (i*idx_period_length)+cut_length + offset_correction
        cut = np.array(pdf[start_idx: end_idx])
        cuts[i-1] = cut

    if len(cuts.shape) < 2:
        print("cut_pdf_into_periods(): unexpected cuts shape:", cuts.shape)
    return cuts


def gaussian_kernel(sigma, dt):
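    """Return a normalized Gaussian kernel with standard deviation sigma, sampled at dt
    over +-4 sigma."""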
    x = np.arange(-4. * sigma, 4. * sigma, dt)
    y = np.exp(-0.5 * (x / sigma) ** 2) / np.sqrt(2. * np.pi) / sigma
    return y


if __name__ == '__main__':
    main()