import unittest
import numpy as np
import helperFunctions as hF
import matplotlib.pyplot as plt


class HelperFunctionsTester(unittest.TestCase):
    """Unit tests for the frequency helpers in ``helperFunctions`` (``hF``).

    Several tests build random spike trains with
    ``generate_jittered_spiketimes``, which draws from NumPy's global random
    state. ``setUp`` seeds that state so every run sees the same spike trains
    and a failure can be reproduced exactly. Each frequency / noise-level
    combination runs in its own ``subTest`` so one failing combination does
    not hide the others.

    The module-level ``test_distribution`` function is a manual plotting
    diagnostic and is intentionally not part of this suite.
    """

    noise_levels = [0, 0.05, 0.1, 0.2]
    frequencies = [0, 1, 5, 30, 100, 500, 750, 1000]
    # Seed for NumPy's global RNG; fixed so the jittered spike trains (and any
    # resulting failure) are identical on every run.
    random_seed = 42

    def setUp(self):
        np.random.seed(self.random_seed)

    def test_calculate_eod_frequency(self):
        """The frequency of a pure sine wave is recovered to two decimals."""
        start = 0
        end = 5
        step = 0.1 / 1000
        freqs = [0, 1, 10, 500, 700, 1000]
        for freq in freqs:
            with self.subTest(freq=freq):
                time = np.arange(start, end, step)
                eod = np.sin(freq * (2 * np.pi) * time)
                self.assertEqual(freq, round(hF.calculate_eod_frequency(time, eod), 2))

    def test__vector_strength__is_1(self):
        """Identical relative spike times in every period give vector strength 1."""
        length = 2000
        rel_spike_times = np.full(length, 0.3)
        eod_durations = np.full(length, 0.14)

        self.assertEqual(1, round(hF.__vector_strength__(rel_spike_times, eod_durations), 2))

    def test__vector_strength__is_0(self):
        """Spike times spread uniformly over one period give vector strength 0."""
        length = 2000
        period = 0.14
        # linspace (not arange with a float step) guarantees exactly `length`
        # samples, so the array always matches `eod_durations` element-wise.
        rel_spike_times = np.linspace(0, period, length, endpoint=False)
        eod_durations = np.full(length, period)

        self.assertEqual(0, round(hF.__vector_strength__(rel_spike_times, eod_durations), 5))

    def test_mean_freq_of_spiketimes_after_time_x(self):
        """The mean frequency of a jittered train stays within a noise-dependent tolerance."""
        simulation_time = 8
        for freq in self.frequencies:
            for n in self.noise_levels:
                with self.subTest(freq=freq, noise=n):
                    spikes = generate_jittered_spiketimes(freq, n, end=simulation_time)
                    sim_freq = hF.mean_freq_of_spiketimes_after_time_x(
                        spikes, 0.00005, simulation_time / 4, time_in_ms=False)

                    # Allowed deviation grows with the noise level and with
                    # sqrt(freq); it is 0 for noise-free trains.
                    max_diff = round(n * (10 + 0.7 * np.sqrt(freq)), 2)
                    # sim_freq is rounded to an integer before comparing, which
                    # gives noise-free trains up to 0.5 of slack.
                    self.assertLessEqual(
                        abs(freq - round(sim_freq)), max_diff,
                        msg="expected freq: {:.2f} vs calculated: {:.2f}. max diff was {:.2f}".format(
                            freq, sim_freq, max_diff))

    def test_calculate_isi_frequency(self):
        """The ISI-frequency trace has one sample per rounded sampling step between spikes."""
        simulation_time = 1
        sampling_interval = 0.00005

        for freq in self.frequencies:
            for n in self.noise_levels:
                with self.subTest(freq=freq, noise=n):
                    spikes = generate_jittered_spiketimes(freq, n, end=simulation_time)
                    sim_freq = hF.calculate_isi_frequency(spikes, sampling_interval, time_in_ms=False)

                    # Expected total length: each inter-spike interval spans
                    # round(isi / sampling_interval) sampling steps.
                    step_counts = np.around(np.diff(spikes) / sampling_interval)
                    expected_length = np.sum(step_counts)

                    self.assertEqual(expected_length, len(sim_freq))

def generate_jittered_spiketimes(frequency, noise_level, start=0, end=5):
    """Build a spike train at ``frequency`` whose inter-spike intervals are jittered.

    Args:
        frequency: target firing rate; the mean inter-spike interval (ISI) is
            ``1 / frequency``.
        noise_level: standard deviation of each ISI as a fraction of the mean
            ISI.
        start: time of the first spike.
        end: upper time limit of the train.

    Returns:
        * ``[]`` when ``frequency`` is 0;
        * ``np.arange(start, end, 1 / frequency)`` (``end`` excluded) when
          ``noise_level`` is 0;
        * otherwise a list beginning at ``start`` whose ISIs are drawn from
          ``np.random.normal(mean_isi, noise_level * mean_isi)``. Non-positive
          draws are discarded and redrawn, and the train stops at the first
          spike that would land after ``end`` (a spike exactly at ``end`` is
          kept).

    Discarding non-positive draws cuts off the lower tail of the ISI
    distribution. If the number of discarded draws exceeds 1% of the number
    of spikes produced, a warning is printed.
    """
    if frequency == 0:
        return []

    mean_isi = 1 / frequency
    if noise_level == 0:
        # Noise-free case: a perfectly regular grid.
        return np.arange(start, end, mean_isi)

    isi_sd = noise_level * mean_isi
    train = [start]
    discarded = 0
    while True:
        isi = np.random.normal(mean_isi, isi_sd)
        if isi <= 0:
            # A spike cannot precede its predecessor; draw again.
            discarded += 1
            continue
        candidate = train[-1] + isi
        if candidate > end:
            break
        train.append(candidate)

    if discarded > 0.01 * len(train):
        print("!!! Danger of lowering actual simulated frequency")
    return train


def test_distribution():
    """Plot how far the estimated spike frequency scatters around the true one.

    Despite the ``test_`` prefix this function makes no assertions; it is a
    manual diagnostic. For every frequency / noise-level combination it
    generates ``repetitions`` jittered spike trains and estimates their
    frequency with ``hF.mean_freq_of_spiketimes_after_time_x``, passing
    ``simulation_time / 4`` as the start time. It prints the mean deviation
    and the maximum absolute deviation from the true frequency. It then shows
    one figure per frequency, with one histogram of the deviations per noise
    level.
    """
    simulation_time = 5
    freqs = [5, 30, 100, 500, 1000]
    noise_level = [0.05, 0.1, 0.2, 0.3]
    repetitions = 1000
    n_bins = 100
    for freq in freqs:
        diffs_per_noise = []
        for n in noise_level:
            print("#### - freq:", freq, "noise level:", n)
            diffs = []
            for _ in range(repetitions):
                spikes = generate_jittered_spiketimes(freq, n, end=simulation_time)
                sim_freq = hF.mean_freq_of_spiketimes_after_time_x(
                    spikes, 0.0002, simulation_time / 4, time_in_ms=False)
                diffs.append(sim_freq - freq)
            diffs_per_noise.append(diffs)

        fig, axs = plt.subplots(1, len(noise_level), figsize=(3.5 * len(noise_level), 4), sharex='all')
        # One figure per frequency; label it so the windows can be told apart.
        fig.suptitle("Frequency: {}".format(freq))

        for ax, noise, diffs in zip(axs, noise_level, diffs_per_noise):
            max_diff = np.max(np.abs(diffs))
            print("Freq: ", freq, "noise: {:.2f}".format(noise), "mean: {:.2f}".format(np.mean(diffs)),
                  "max_diff: {:.4f}".format(max_diff))
            if max_diff > 0:
                # linspace includes both endpoints, so the most extreme
                # deviation (which defines max_diff) lands inside the last bin
                # instead of being dropped as with arange.
                bins = np.linspace(-max_diff, max_diff, n_bins + 1)
            else:
                # All deviations are zero: a symmetric range would have zero
                # width, so let matplotlib choose the range.
                bins = n_bins
            ax.hist(diffs, bins=bins)
            ax.set_title('Noise level: {:.2f}'.format(noise))

    plt.show()
    # Close every per-frequency figure, not just the most recent one.
    plt.close('all')