import numpy as np
import matplotlib.pyplot as plt
import pyrelacs.DataLoader as dl
import os
import helperFunctions as hf
from IPython import embed
from scipy.optimize import curve_fit
import warnings

# Timing of the analysed protocol, in seconds. Every time axis in this file
# runs from -PRE_DURATION (sample index 0) up to TOTAL_DURATION (exclusive).
SAMPLING_RATE = 20000                  # samples per second of the frequency traces
SAMPLING_INTERVAL = 1 / SAMPLING_RATE  # spacing between two samples
STIMULUS_START = 0                     # stimulus onset on the time axis
STIMULUS_DURATION = 0.4                # length of the stimulus
PRE_DURATION = 0.25                    # time before onset covered by the time axis
TOTAL_DURATION = 1.25                  # end of the time axis (exclusive)


def main():
    """Analyse the ``fispikes1.dat`` recording of every subfolder of ``data/``.

    For each folder the spike trains are read with pyrelacs, grouped per
    stimulus intensity (similar intensities merged by helperFunctions) and
    sorted by intensity. They are then turned into mean frequency traces, whose
    exponential onset fits and overview plot are saved under
    ``figures/<folder>/``. Folders without a ``fispikes1.dat`` or without any
    intensity metadata are skipped with a message instead of crashing.
    """
    for folder in hf.get_subfolder_paths("data/"):
        filepath = folder + "/fispikes1.dat"
        set_savepath("figures/" + folder.split('/')[1] + "/")
        print("Folder:", folder)

        if not os.path.exists(filepath):
            print("Skipping folder: no fispikes1.dat found.")
            continue

        os.makedirs(get_savepath(), exist_ok=True)

        spiketimes = []
        intensities = []
        for metadata, _key, data in dl.iload(filepath):
            if len(metadata) != 0:
                # If the first metadata entry is the "Control" section, the
                # intensity is read from the second entry instead.
                metadata_index = 0
                if '----- Control --------------------------------------------------------' in metadata[0].keys():
                    metadata_index = 1

                print(metadata)
                # The stored value carries a two-character unit suffix
                # (presumably "mV") that is stripped before parsing.
                intensity = float(metadata[metadata_index]['intensity'][:-2])
                intensities.append(intensity)
                spiketimes.append([])

            if not spiketimes:
                # Spike data before the first metadata block has no intensity
                # it could belong to.
                warnings.warn("Skipping spike data without intensity metadata in " + filepath)
                continue

            # Divided by 1000 (presumably ms -> s) to match the time constants.
            spiketimes[-1].append(data[:, 0]/1000)

        if len(intensities) == 0:
            print("Skipping folder: no intensities found.")
            continue

        intensities, spiketimes = hf.merge_similar_intensities(intensities, spiketimes)

        # Sort the trials so that intensities are increasing (stable sort).
        pairs = sorted(zip(intensities, spiketimes), key=lambda pair: pair[0])
        intensities = [intensity for intensity, _ in pairs]
        spiketimes = [trials for _, trials in pairs]

        mean_frequencies = calculate_mean_frequencies(intensities, spiketimes)
        # Saves one fit figure per intensity; the fit parameters are not used further yet.
        fit_exponential(intensities, mean_frequencies)
        plot_frequency_curve(intensities, mean_frequencies)

        f_baseline = calculate_f_baseline(mean_frequencies)
        f_infinity = calculate_f_infinity(mean_frequencies)
        f_zero = calculate_f_zero(mean_frequencies)

        # plot_fi_curve(intensities, f_baseline, f_zero, f_infinity)

        # plot_fi_curve(intensities, f_baseline, f_zero, f_infinity)


# TODO !!
def fit_exponential(intensities, mean_frequencies):
    """Fit ``exponential_function`` to the onset of every mean frequency trace.

    The fitted window runs from 5 ms to 100 ms after stimulus onset. For every
    successful fit, the whole trace and the fitted curve are plotted and saved
    as ``<savepath>/exponential_fits/fit_intensity:<intensity>.png``.

    Args:
        intensities: stimulus intensity of each trace (used in the file name).
        mean_frequencies: frequency traces sampled every SAMPLING_INTERVAL,
            where index 0 corresponds to time -PRE_DURATION.

    Returns:
        (popts, pcovs): fitted parameters and covariance matrices of the
        successful fits. Failed fits are skipped, so these lists may be shorter
        than ``intensities`` and are then no longer aligned with it.
    """
    onset = PRE_DURATION + STIMULUS_START
    start_idx = int((onset + 0.005) / SAMPLING_INTERVAL)
    end_idx = int((onset + 0.1) / SAMPLING_INTERVAL)

    # Initial guess (a, b, c, d) for exponential_function. 1.0 is the value the
    # former expression 1/(np.power(1, 10)) evaluated to.
    p0 = (1.0, .5, 50, 180)
    save_path = get_savepath() + "exponential_fits/"

    popts = []
    pcovs = []
    for intensity, freq in zip(intensities, mean_frequencies):
        y_values = np.asarray(freq[start_idx:end_idx + 1], dtype=float)
        # Derive x from the sample indices actually taken, so x and y always have
        # the same length. A float np.arange(start*dt, end*dt, dt) excludes the
        # end point and is usually one sample shorter than y, a mismatch
        # curve_fit cannot handle.
        x_values = (start_idx + np.arange(len(y_values))) * SAMPLING_INTERVAL

        if len(y_values) < len(p0):
            print("Not enough samples to fit the exponential for intensity", intensity)
            continue
        try:
            popt, pcov = curve_fit(exponential_function, x_values, y_values, p0=p0, maxfev=10000)
        except RuntimeError:
            print("RuntimeError happened in fit_exponential.")
            continue
        except ValueError as err:
            # e.g. NaN or inf values in the trace
            print("ValueError happened in fit_exponential:", err)
            continue

        popts.append(popt)
        pcovs.append(pcov)

        # Time axis built from the trace length so plotting never fails on a
        # length mismatch; values equal np.arange(-PRE_DURATION, ..., SAMPLING_INTERVAL).
        trace_time = -PRE_DURATION + np.arange(len(freq)) * SAMPLING_INTERVAL
        plt.plot(trace_time, freq)
        plt.plot(x_values - PRE_DURATION, exponential_function(x_values, *popt))
        os.makedirs(save_path, exist_ok=True)
        plt.savefig(save_path + "fit_intensity:" + str(round(intensity, 4)) + ".png")
        plt.close()

    return popts, pcovs


def calculate_mean_frequency(freqs):
    """Average several frequency traces sample by sample.

    Args:
        freqs: sequence of equally sampled traces. Traces of different length
            are averaged only up to the length of the shortest one.

    Returns:
        list with the mean value at each sample position (empty if ``freqs`` is empty).
    """
    averaged = []
    for samples_at_t in zip(*freqs):
        averaged.append(sum(samples_at_t) / len(samples_at_t))
    return averaged


def gaussian_kernel(sigma, dt):
    """Return a Gaussian kernel with standard deviation ``sigma``, sampled every ``dt`` out to +-4 sigma.

    The kernel has an odd number of samples and its centre sample lies exactly
    at 0. Convolving with ``mode='same'`` therefore does not shift the signal;
    a float ``np.arange(-4*sigma, 4*sigma, dt)`` drops the +4 sigma end and is
    off-centre by half a sample. For ``dt`` much smaller than ``sigma``,
    ``kernel.sum() * dt`` is close to 1.
    """
    half_width = int(round(4. * sigma / dt))
    x = np.arange(-half_width, half_width + 1) * dt
    return np.exp(-0.5 * (x / sigma) ** 2) / np.sqrt(2. * np.pi) / sigma


def calculate_kernel_frequency(spiketimes, time, sampling_interval):
    """Estimate the firing rate by convolving the spike train with a Gaussian kernel.

    Args:
        spiketimes: spike times in seconds (any array-like).
        time: evenly spaced sample times in seconds; ``time[0]`` is the start
            of the first bin.
        sampling_interval: spacing of ``time`` in seconds.

    Returns:
        numpy array with the rate in spikes per second at each sample of ``time``.
    """
    spikes = np.asarray(spiketimes, dtype=float)
    spikes = spikes[np.isfinite(spikes)]
    t = np.asarray(time, dtype=float)
    dt = sampling_interval
    kernel_width = 0.01  # standard deviation of the Gaussian in seconds; larger = smoother rate

    # np.floor instead of astype(int): truncation toward zero would put spikes
    # that occur just before time[0] into the first bin instead of dropping them.
    spike_indices = np.floor((spikes - t[0]) / dt).astype(int)
    in_range = spike_indices[(spike_indices >= 0) & (spike_indices < len(t))]

    # np.add.at accumulates repeated indices, so coincident spikes in one bin
    # are all counted rather than collapsed into a single 1.
    spike_counts = np.zeros(t.shape)
    np.add.at(spike_counts, in_range, 1)

    return np.convolve(spike_counts, gaussian_kernel(kernel_width, dt), mode='same')


def calculate_isi_frequency(spiketimes, time):
    """Return the instantaneous firing frequency (1 / ISI) sampled at ``time``.

    The analysis window runs from ``-PRE_DURATION`` to ``TOTAL_DURATION``. Its
    edges and the spikes split it into intervals, and every sample in ``time``
    gets ``1 / length`` of the interval that contains it. The first and last
    intervals are therefore measured to the window edges rather than to a spike.

    Each sample is looked up directly. The former approach repeated every
    frequency ``round(isi / SAMPLING_INTERVAL)`` times, and rounding each
    interval separately accumulated errors, so the result could end up longer
    or shorter than ``time``.

    Args:
        spiketimes: ascending spike times in seconds on the same axis as ``time``.
        time: sample times in seconds at which the frequency is evaluated.

    Returns:
        list with one frequency per entry of ``time``. If spikes fall outside
        the window and the current save path contains "12-13-af", a warning is
        issued and a list of ones is returned instead.

    Raises:
        ValueError: if a spike time is NaN, or if spikes lie outside the window
            or are not ascending (for any other recording).
    """
    spikes = np.asarray(spiketimes, dtype=float)
    time = np.asarray(time, dtype=float)

    if np.isnan(spikes).any():
        raise ValueError("NaN in spike times: {}".format(spikes[:10]))

    # Interval borders: window start, every spike, window end.
    borders = np.concatenate(([-PRE_DURATION], spikes, [TOTAL_DURATION]))
    isis = np.diff(borders)

    if (isis < 0).any():
        # Spikes before -PRE_DURATION or after TOTAL_DURATION (or unsorted spikes).
        if "12-13-af" in get_savepath():
            warnings.warn("preStimulus duration > 0 still not supported")
            return [1]*len(time)
        raise ValueError(
            "Spike times must be ascending and lie within [{}, {}]; first intervals: {}".format(
                -PRE_DURATION, TOTAL_DURATION, isis[:10]))

    if (isis == 0).any():
        print("probably a problem")

    # The number of spikes at or before a sample is the index of the interval
    # containing it. Zero-length intervals between duplicate spikes are never
    # selected because side="right" counts all equal spikes at once.
    interval_index = np.searchsorted(spikes, time, side="right")
    with np.errstate(divide="ignore"):
        interval_rates = 1.0 / isis
    return interval_rates[interval_index].tolist()


def calculate_mean_frequencies(intensities, spiketimes):
    """Return one trial-averaged ISI frequency trace per intensity.

    Args:
        intensities: stimulus intensities; only their count is used here.
        spiketimes: for each intensity, a list of trials, each holding the
            spike times of one trial.

    Trials with fewer than two spikes are left out of the average.
    """
    time = np.arange(-PRE_DURATION, TOTAL_DURATION, SAMPLING_INTERVAL)

    def _traces_for(trials):
        # ISI frequency trace of every trial that has at least two spikes.
        return [calculate_isi_frequency(spikes, time) for spikes in trials if len(spikes) >= 2]

    return [calculate_mean_frequency(_traces_for(spiketimes[idx])) for idx, _ in enumerate(intensities)]


def calculate_f_baseline(mean_frequencies):
    """Return the mean pre-stimulus (baseline) frequency of every trace.

    Index 0 of a trace corresponds to time -PRE_DURATION. The average covers
    the samples from ``buffer_time`` after the start of the trace up to
    ``buffer_time`` before stimulus onset, keeping clear of both edges.
    """
    buffer_time = 0.05
    # Onset index uses PRE_DURATION + STIMULUS_START, as in every other
    # function of this file (previously "- STIMULUS_START", which was only
    # correct while STIMULUS_START == 0).
    stimulus_onset = PRE_DURATION + STIMULUS_START
    start_idx = int(buffer_time / SAMPLING_INTERVAL)
    end_idx = int((stimulus_onset - buffer_time) / SAMPLING_INTERVAL)

    return [np.mean(freq[start_idx:end_idx]) for freq in mean_frequencies]


def calculate_f_infinity(mean_frequencies):
    """Return the late-stimulus (steady-state) frequency of every trace.

    The mean is taken over a 150 ms window that ends 50 ms (``buffer_time``)
    before stimulus offset.
    """
    buffer_time = 0.05
    stimulus_end = PRE_DURATION + STIMULUS_START + STIMULUS_DURATION
    first = int((stimulus_end - 0.15 - buffer_time) / SAMPLING_INTERVAL)
    last = int((stimulus_end - buffer_time) / SAMPLING_INTERVAL)

    return [np.mean(trace[first:last]) for trace in mean_frequencies]


def calculate_f_zero(mean_frequencies):
    """Return the onset-response frequency of every trace.

    The reference is the sample ``window`` (100 ms) before stimulus onset. The
    starting candidate is the mean of the 500 samples preceding that reference
    (25 ms at 20 kHz). Every later sample up to ``window`` after onset that
    deviates further from the reference replaces the candidate, so the first
    sample with the largest deviation wins.
    """
    window = 0.1
    onset = PRE_DURATION + STIMULUS_START
    first = int((onset - window) / SAMPLING_INTERVAL)
    last = int((onset + window) / SAMPLING_INTERVAL)

    peaks = []
    for trace in mean_frequencies:
        reference = trace[first]
        best = np.mean(trace[first - 500:first])
        best_distance = abs(best - reference)
        for idx in range(first + 1, last):
            distance = abs(trace[idx] - reference)
            if distance > best_distance:
                best = trace[idx]
                best_distance = distance
        peaks.append(best)
    return peaks


def plot_fi_curve(intensities, f_baseline, f_zero, f_infinity):
    """Plot baseline, onset and steady-state frequency against intensity and save it.

    A Boltzmann curve (``fill_boltzmann``) is fitted to ``f_zero`` and drawn on
    top. The fit needs at least three points and a computable start guess. If
    these are missing, or the fit fails, the data are still plotted and saved,
    just without the fit line. The figure is written to
    ``<savepath>/fi_curve.png``.

    Args:
        intensities: stimulus intensities, expected in ascending order
            (main() sorts them).
        f_baseline, f_zero, f_infinity: one frequency per intensity.
    """
    plt.plot(intensities, f_baseline, label="f_baseline")
    plt.plot(intensities, f_zero, 'o', label="f_zero")
    plt.plot(intensities, f_infinity, label="f_infinity")

    popt = None
    if len(intensities) >= 3 and intensities[-1] != intensities[0] and f_zero[-1] != 0:
        max_f0 = float(max(f_zero))
        mean_int = float(np.mean(intensities))
        # The logistic's slope at x_zero is f_max * k / 4, so k ~ 4 * slope / f_max;
        # the mean slope over the range and f_zero[-1] stand in for both.
        start_k = float(((f_zero[-1] - f_zero[0]) / (intensities[-1] - intensities[0])*4)/f_zero[-1])
        try:
            popt, _ = curve_fit(fill_boltzmann, intensities, f_zero,
                                p0=(max_f0, start_k, mean_int), maxfev=10000)
        except (RuntimeError, ValueError) as err:
            print("Boltzmann fit of f_zero failed in plot_fi_curve:", err)

    if popt is not None:
        print(popt)
        x_values_boltzmann_fit = np.linspace(min(intensities), max(intensities), 5000)
        plt.plot(x_values_boltzmann_fit, fill_boltzmann(x_values_boltzmann_fit, *popt), label='fit')

    plt.title("FI-Curve")
    plt.ylabel("Frequency in Hz")
    plt.xlabel("Intensity in mV")
    plt.legend()
    plt.savefig(get_savepath() + "fi_curve.png")
    plt.close()


def plot_frequency_curve(intensities, mean_frequencies):
    """Plot the mean frequency trace of every intensity in one figure and save it.

    Index 0 of a trace is drawn at time -PRE_DURATION. Stimulus onset and
    offset, taken from STIMULUS_START and STIMULUS_DURATION, are marked with
    vertical black lines. The figure is written to
    ``<savepath>/mean_frequency_curves.png``.
    """
    colors = ["red", "green", "blue", "violet", "orange", "grey"]

    for i, (intensity, freq) in enumerate(zip(intensities, mean_frequencies)):
        # Time axis built from the trace length so an empty or shorter trace
        # cannot break plt.plot; values equal np.arange(-PRE_DURATION, ..., SAMPLING_INTERVAL).
        time = -PRE_DURATION + np.arange(len(freq)) * SAMPLING_INTERVAL
        plt.plot(time, freq, color=colors[i % len(colors)], label=str(intensity))

    # Stimulus markers derived from the timing constants instead of hard-coded 0 and 0.4;
    # axvline spans the full y range whatever the frequencies are.
    plt.axvline(STIMULUS_START, color="black")
    plt.axvline(STIMULUS_START + STIMULUS_DURATION, color="black")
    plt.legend()
    plt.xlabel("Time in seconds")
    plt.ylabel("Frequency in Hz")
    plt.title("Frequency curve")

    plt.savefig(get_savepath() + "mean_frequency_curves.png")
    plt.close()


def exponential_function(x, a, b, c, d):
    """Exponential with amplitude ``a``, shift ``b``, rate ``c`` and offset ``d``.

    Evaluates ``a * exp(-c * (x - b)) + d``; works element-wise on numpy arrays.
    """
    return d + a * np.exp(c * (b - x))


def upper_boltzmann(x, f_max, k, x_zero):
    """Upper half of a Boltzmann curve: 0 up to ``x_zero``, then rising towards ``f_max``.

    ``2 / (1 + e^(-k (x - x_zero))) - 1`` runs from -1 to 1 and is 0 at ``x_zero``.
    Negative values are clipped to 0 before scaling by ``f_max``.
    """
    denominator = 1 + np.power(np.e, -k * (x - x_zero))
    rescaled = (2 / denominator) - 1
    return f_max * np.clip(rescaled, 0, None)


def fill_boltzmann(x, f_max, k, x_zero):
    """Logistic (Boltzmann) curve rising from 0 to ``f_max`` with slope factor ``k``.

    Half-maximal at ``x_zero``; works element-wise on numpy arrays.
    """
    exponent = -k * (x - x_zero)
    return f_max * (1 / (1 + np.power(np.e, exponent)))


SAVEPATH = ""


def get_savepath():
    global SAVEPATH
    return SAVEPATH


def set_savepath(new_path):
    global SAVEPATH
    SAVEPATH = new_path


# Run the analysis only when executed as a script, not when imported.
if __name__ == '__main__':
    main()