import pyrelacs.DataLoader as dl
import numpy as np
import matplotlib.pyplot as plt
import os
import introduction.old_helper_functions as ohf


SAVEPATH = ""
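# SAVEPATH is the output directory for figures; main() points it at a per-dataset
# folder via set_savepath() before any plots are written.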


def get_savepath():
    global SAVEPATH
    return SAVEPATH


def set_savepath(new_path):
    global SAVEPATH
    SAVEPATH = new_path


def main():
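    """
    Analyze the baseline activity of every dataset in data/: load the spike times from
    basespikes1.dat, plot an ISI histogram, and compute the vector strength of the
    spikes' phase locking to the EOD, using the recorded EOD and v1 traces.
    """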
    for folder in ohf.get_subfolder_paths("data/"):
        filepath = folder + "/basespikes1.dat"
        set_savepath("figures/" + folder.split('/')[1] + "/")

        print("Folder:", folder)

        if not os.path.exists(get_savepath()):
            os.makedirs(get_savepath())

        spiketimes = []

        ran = False
        for metadata, key, data in dl.iload(filepath):
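            # dl.iload yields one (metadata, key, data) block at a time from basespikes1.dat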
            ran = True
            spikes = data[:, 0]  # first column holds the spike times in ms
            spiketimes.append(spikes)  # save for calculation of vector strength
            metadata = metadata[0]
            #print(metadata)
            # print('firing frequency1:', metadata['firing frequency1'])
            # print(mean_firing_rate(spikes))

            # print('Coefficient of Variation (CV):', metadata['CV1'])
            # print(calculate_coefficient_of_variation(spikes))

        if not ran:
            print("------------ no data blocks loaded from", filepath)

        isi_histogram(spiketimes)

        times, eods = ohf.get_traces(folder, 2, 'BaselineActivity')  # trace 2: EOD waveform
        times, v1s = ohf.get_traces(folder, 1, 'BaselineActivity')   # trace 1: v1 voltage

        vs = calculate_vector_strength(times, eods, spiketimes, v1s)

        # print("Calculated vector strength:", vs)


def mean_firing_rate(spiketimes):
    # mean firing rate in Hz: spike count divided by the time of the last spike (spike times are in ms)
    return len(spiketimes) / spiketimes[-1] * 1000


def calculate_coefficient_of_variation(spiketimes):
    # coefficient of variation (CV): standard deviation of the ISIs divided by the mean ISI
    isi = np.diff(spiketimes)
    std = np.std(isi)
    mean = np.mean(isi)

    return std/mean


def isi_histogram(spiketimes):
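    """Pool the inter-spike intervals of all trials and save them as one histogram."""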
    # ISI histogram; keep the bin size below 1 ms (0.1 ms here)

    isi = []
    for spike_list in spiketimes:
        isi.extend(np.diff(spike_list))
    maximum = max(isi)
    bins = np.arange(0, maximum*1.01, 0.1)

    plt.title('Phase locking of ISI without stimulus')
    plt.xlabel('ISI in ms')
    plt.ylabel('Count')
    plt.hist(isi, bins=bins)
    plt.savefig(get_savepath() + 'phase_locking_without_stimulus.png')
    plt.close()


def calculate_vector_strength(times, eods, spiketimes, v1s):
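    """
    Compute the vector strength of the spikes' phase locking to the EOD for every
    recording (printed and plotted per recording) and return the vector strength
    pooled over all recordings.
    """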
    # vector strength (use the EOD frequency from the header / metadata); expected VS > 0.8
    # dl.iload_traces(repro='BaselineActivity')

    relative_spike_times = []
    eod_durations = []

    if len(times) == 0:
        print("-----LENGTH OF TIMES = 0")

    for recording in range(len(times)):

        rel_spikes, eod_durs = eods_around_spikes(times[recording], eods[recording], spiketimes[recording])
        relative_spike_times.extend(rel_spikes)
        eod_durations.extend(eod_durs)

        vs = __vector_strength__(np.array(rel_spikes), eod_durs)
        phases = calculate_phases(rel_spikes, eod_durs)
        plot_polar(phases, "test_phase_locking_" + str(recording) + "_with_vs_" + str(round(vs, 3)) + ".png")

        print("VS of recording", recording, ":", vs)

        plot_phaselocking_testfigures(times[recording], eods[recording], spiketimes[recording], v1s[recording])

    return __vector_strength__(np.array(relative_spike_times), eod_durations)


def eods_around_spikes(time, eod, spiketimes):
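    """
    For every spike, locate the EOD cycle it falls into and return the spike time
    relative to that cycle's start (in s) together with the cycle's duration (in s).
    """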
    eod_durations = []
    relative_spike_times = []

    for spike in spiketimes:
        index = spike * 20  # spike time is in ms, sampling rate is 20 kHz -> sample index = spike/1000*20000 = spike*20

        if index != np.round(index):
            print("INDEX NOT AN INTEGER in eods_around_spikes! index:", index)
            continue
        index = int(index)

        start_time, end_time = search_eod_start_and_end_times(time, eod, index)

        eod_durations.append(end_time-start_time)
        relative_spike_times.append(spike/1000 - start_time)

    return relative_spike_times, eod_durations


def search_eod_start_and_end_times(time, eod, index):
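    """
    Starting at the given sample index, walk backwards and forwards to the nearest
    negative-to-positive zero crossings of the EOD and return their linearly
    interpolated times: the start and the end of the EOD cycle containing the spike.
    """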
    # TODO might break if a spike is in the cut off first or last eod!

    # search backwards for the negative-to-positive zero crossing that starts the EOD cycle
    previous = index
    working_idx = index-1
    while True:
        if eod[working_idx] < 0 < eod[previous]:
            first_value = eod[working_idx]
            second_value = eod[previous]

            dif = second_value - first_value
            part = np.abs(first_value/dif)

            time_dif = np.abs(time[previous] - time[working_idx])
            start_time = time[working_idx] + time_dif*part

            break

        previous = working_idx
        working_idx -= 1

    # search forwards for the next negative-to-positive zero crossing that ends the EOD cycle
    previous = index
    working_idx = index + 1
    while True:
        if eod[previous] < 0 < eod[working_idx]:
            first_value = eod[previous]
            second_value = eod[working_idx]

            dif = second_value - first_value
            part = np.abs(first_value / dif)

            time_dif = np.abs(time[previous] - time[working_idx])
            end_time = time[previous] + time_dif * part  # interpolate from the sample before the crossing

            break

        previous = working_idx
        working_idx += 1

    return start_time, end_time


def search_closest_index(array, value, start=0, end=-1):
    # binary search for the index of the array entry closest to the given value
    # expects a sorted array; start has to be smaller than end

    if end == -1:
        end = len(array)-1

    while True:
        if end-start <= 1:
            return end if np.abs(array[end]-value) < np.abs(array[start]-value) else start

        middle = int(np.floor((end-start)/2)+start)
        if array[middle] == value:
            return middle
        elif array[middle] > value:
            end = middle
            continue
        else:
            start = middle
            continue


def __vector_strength__(relative_spike_times, eod_durations):
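    """
    Vector strength: vs = sqrt((1/n * sum(cos(phi)))**2 + (1/n * sum(sin(phi)))**2)
    with phi = 2*pi * relative spike time / EOD period; 1 means perfect phase locking,
    0 means no locking.
    """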
    # adapted from Ramona

    n = len(relative_spike_times)
    if n == 0:
        return 0

    phase_times = (relative_spike_times / eod_durations) * 2 * np.pi
    vs = np.sqrt((1 / n * sum(np.cos(phase_times))) ** 2 + (1 / n * sum(np.sin(phase_times))) ** 2)

    return vs


def calculate_phases(relative_spike_times, eod_durations):
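    """Convert each relative spike time into a phase angle: 2*pi * time / EOD period."""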
    phase_times = np.zeros(len(relative_spike_times))

    for i in range(len(relative_spike_times)):
        phase_times[i] = (relative_spike_times[i] / eod_durations[i]) * 2 * np.pi

    return phase_times


def plot_polar(phases, name=""):
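    """Histogram of spike phases on a polar axis; shown interactively or saved under the given name."""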
    fig = plt.figure()
    ax = fig.add_subplot(111, polar=True)
    # r = np.arange(0, 1, 0.001)
    # theta = 2 * 2 * np.pi * r
    # line, = ax.plot(theta, r, color='#ee8d18', lw=3)
    bins = np.arange(0, np.pi*2, 0.05)
    ax.hist(phases, bins=bins)
    if name == "":
        plt.show()
    else:
        plt.savefig(get_savepath() + name)
        plt.close()


def plot_phaselocking_testfigures(time, eod, spiketimes, v1):
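    """
    Diagnostic figure: the first seconds of the v1 trace together with the spike times
    and the interpolated start/end times of the EOD cycle around each spike.
    """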
    eod_start_times = []
    eod_end_times = []

    for spike in spiketimes:
        index = spike * 20  # spike time is in ms, sampling rate is 20 kHz -> sample index = spike/1000*20000 = spike*20

        if index != np.round(index):
            print("INDEX NOT AN INTEGER in plot_phaselocking_testfigures! index:", index)
            continue
        index = int(index)

        start_time, end_time = search_eod_start_and_end_times(time, eod, index)

        eod_start_times.append(start_time)
        eod_end_times.append(end_time)

    # show only the first cutoff_in_sec seconds of the trace (sampled at 20 kHz)
    cutoff_in_sec = 2
    sampling = 20000
    max_idx = cutoff_in_sec*sampling
    spikes_part = [x/1000 for x in spiketimes if x/1000 < cutoff_in_sec]  # spike times in ms -> s
    count_spikes = len(spikes_part)
    print("spikes within the first", cutoff_in_sec, "s:", count_spikes)

    x_axis = time[0:max_idx]
    plt.plot(spikes_part, np.ones(len(spikes_part))*-20, 'o')
    plt.plot(x_axis, v1[0:max_idx])
    plt.plot(eod_start_times[: count_spikes], np.zeros(count_spikes), 'o')
    plt.plot(eod_end_times[: count_spikes], np.zeros(count_spikes), 'o')

    plt.show()
    plt.close()


if __name__ == '__main__':
    main()