import numpy as np
import matplotlib.pyplot as plt
import helperFunctions as hF
import time

def main(run_benchmark: bool = False):
    """Plot the upward zero crossings of a synthetic EOD; optionally benchmark ISI simulation.

    By default this does what the original script effectively did: plot a 600 Hz sine
    "EOD" with its upward zero crossings marked, then print the crossing indices and stop.
    The ISI simulation benchmark was previously unreachable (an early ``quit()``); it can
    now be run explicitly with ``run_benchmark=True``.

    :param run_benchmark: if True, also run the phase-locked ISI simulation benchmark
    """
    sign_changes = _plot_upward_zero_crossings()
    print(sign_changes)
    if run_benchmark:
        _benchmark_phase_locked_isis()


def _plot_upward_zero_crossings(freq=600, t_start=-1, t_stop=30, dt=0.0001):
    """Plot a sine at ``freq`` Hz, mark its upward zero crossings, return their sample indices."""
    # Named `t`, not `time`: a local `time` would shadow the imported time module.
    t = np.arange(t_start, t_stop, dt)
    eod = np.sin(2 * np.pi * freq * t)

    sign_flip = np.sign(eod[:-1]) != np.sign(eod[1:])
    rising = eod[:-1] < eod[1:]
    # Index i marks a sign change between samples i and i+1 on a rising slope.
    sign_changes = np.where(sign_flip & rising)[0]

    plt.plot(t, eod)
    plt.plot(t[sign_changes], np.zeros(len(sign_changes)), 'o')
    plt.show()
    return sign_changes


def _simulate_phase_locked_isis(freq, cv=0.7, size=100000, acceptance_exponent=0.05):
    """Draw Gaussian ISIs and keep each one with a probability that depends on its phase.

    ISIs are drawn from N(1/freq, (cv/freq)^2). Each ISI's phase within one mean ISI is
    compared to 0.5 with ``abs_phase_diff``. The ISI is kept with probability
    ``1 - diff**acceptance_exponent``, so ISIs whose phase is close to 0.5 are favoured.
    Batches are drawn until at least ``size`` ISIs are kept.

    :return: (clean_isis, last_raw_isis) -- the ISIs re-derived from the sorted positive
             spike times, and the last raw batch that was drawn
    """
    if size <= 0:
        raise ValueError("size must be positive")
    mean_isi = 1 / freq
    final_isis = np.array([])
    while len(final_isis) < size:
        isis = np.random.normal(mean_isi, mean_isi * cv, size)
        isi_phase = (isis % mean_isi) / mean_isi
        diff = np.asarray(abs_phase_diff(isi_phase, 0.5))
        chance = np.random.random(size)
        # Vectorized accept/reject; replaces a per-element Python loop.
        keep = 1 - diff ** acceptance_exponent > chance
        final_isis = np.concatenate((final_isis, isis[keep]))

    # Negative ISIs can occur (Gaussian draws), so drop non-positive times and re-sort.
    spikes = np.cumsum(final_isis)
    spikes = np.sort(spikes[spikes > 0])
    return np.diff(spikes), isis


def _benchmark_phase_locked_isis(freqs=(700, 50, 100, 500, 1000), reps=1000,
                                 show_histograms=False):
    """Time ``reps`` runs of ``_simulate_phase_locked_isis`` per frequency and print the cost.

    :param show_histograms: if True, after timing each frequency show histograms of the
                            clean and raw ISIs from its last repetition
    """
    for freq in freqs:
        clean_isis = raw_isis = None
        start = time.perf_counter()
        for _ in range(reps):
            clean_isis, raw_isis = _simulate_phase_locked_isis(freq)
        end = time.perf_counter()

        # NOTE(review): message text kept from the original. Each repetition simulates
        # `size` ISIs, so the simulated duration depends on freq -- confirm what "10s" means.
        print("It took {:.2f} s to simulate 10s of spikes at {} Hz".format(end - start, freq))

        if show_histograms and clean_isis is not None:
            bins = np.arange(-0.01, 0.01, 0.0001)
            plt.hist(clean_isis, alpha=0.5, bins=bins)
            plt.hist(raw_isis, alpha=0.5, bins=bins)
            plt.show()


def abs_phase_diff(rel_phases: list, ref_phase: float):
    """
    Circular distance between relative phases and a reference phase.

    Phases live on the unit circle [0, 1), so 0.95 and 0.05 are 0.1 apart, not 0.9.
    Every returned value is therefore in [0, 0.5].

    :param rel_phases: relative phases as a list (or array) of values between 0 and 1
    :param ref_phase: reference phase to which the difference is calculated (between 0 and 1)
    :return: list of absolute (wrapped) differences, each in [0, 0.5]
    """
    # |x - ref| mod 1 puts the raw difference in [0, 1). The shorter way round the
    # circle is then min(d, 1 - d). The previous min(x-ref, x-ref+1) was always x-ref,
    # so it never actually wrapped.
    d = np.abs(np.asarray(rel_phases, dtype=float) - ref_phase) % 1.0
    return list(np.minimum(d, 1.0 - d))


if __name__ == "__main__":
    # Script entry point: runs only when this file is executed directly, not on import.
    main()