from stimuli.SinusAmplitudeModulation import SinusAmplitudeModulationStimulus
import unittest
import numpy as np
import matplotlib.pyplot as plt


class SinusoidalStimulusTester(unittest.TestCase):
    """Consistency tests for SinusAmplitudeModulationStimulus.

    Each test sweeps one parameter over a list of values while keeping the
    others fixed. For every value it checks that sampling the stimulus as a
    whole array agrees with evaluating it point by point (see
    ``array_and_time_points_equal``).
    """

    # Parameter values swept by the individual consistency tests below.
    base_frequencies = [0, 10, 100, 1000]
    contrasts = [0, 0.5, 1, 1.5]
    modulation_frequencies = [0, 5, 10, 100]
    step_sizes = [1, 0.5, 0.00005]
    time_starts = [0, 2, -2]
    durations = [2]  # NOTE(review): not referenced by any test in this class

    def test_consistency_base_frequency(self):
        contrast = 0.1
        mod_freq = 5
        time_start = -1
        duration = 10
        step_size = 0.00005
        for base_freq in self.base_frequencies:
            # subTest keeps going after a failing value so every value is reported.
            with self.subTest(base_freq=base_freq):
                stimulus = SinusAmplitudeModulationStimulus(base_freq, contrast, mod_freq, 0, 8)
                self.assertTrue(array_and_time_points_equal(stimulus, time_start, duration, step_size),
                                msg="Stimulus values inconsistent with base freq: {:.2f}".format(base_freq))

    def test_consistency_contrast(self):
        base_freq = 700
        mod_freq = 5
        time_start = -1
        duration = 10
        step_size = 0.00005
        for contrast in self.contrasts:
            with self.subTest(contrast=contrast):
                stimulus = SinusAmplitudeModulationStimulus(base_freq, contrast, mod_freq, 0, 8)
                self.assertTrue(array_and_time_points_equal(stimulus, time_start, duration, step_size),
                                msg="Stimulus values inconsistent with contrast: {:.2f}".format(contrast))

    def test_consistency_modulation_frequency(self):
        contrast = 0.1
        base_freq = 700
        time_start = -1
        duration = 10
        step_size = 0.00005
        for mod_freq in self.modulation_frequencies:
            with self.subTest(mod_freq=mod_freq):
                # NOTE(review): the other tests pass (0, 8) as the last two
                # constructor arguments; confirm (0, 1) is intended here.
                stimulus = SinusAmplitudeModulationStimulus(base_freq, contrast, mod_freq, 0, 1)
                self.assertTrue(array_and_time_points_equal(stimulus, time_start, duration, step_size),
                                msg="Stimulus values inconsistent with mod freq: {:.2f}".format(mod_freq))

    def test_consistency_step_size(self):
        contrast = 0.1
        base_freq = 700
        time_start = -1
        duration = 10
        mod_freq = 10
        for step_size in self.step_sizes:
            with self.subTest(step_size=step_size):
                stimulus = SinusAmplitudeModulationStimulus(base_freq, contrast, mod_freq, 0, 8)
                # Scale the step size (seconds) to milliseconds before formatting.
                # The original multiplied the formatted *string* by 1000 instead.
                self.assertTrue(array_and_time_points_equal(stimulus, time_start, duration, step_size),
                                msg="Stimulus values inconsistent with step_size: {:.3f}ms".format(step_size * 1000))

    def test_consistency_time_start(self):
        contrast = 0.1
        base_freq = 700
        mod_freq = 10
        duration = 10
        step_size = 0.00005
        for time_start in self.time_starts:
            with self.subTest(time_start=time_start):
                stimulus = SinusAmplitudeModulationStimulus(base_freq, contrast, mod_freq, 0, 8)
                self.assertTrue(array_and_time_points_equal(stimulus, time_start, duration, step_size),
                                msg="Stimulus values inconsistent when the time starts at: {:.2f}s".format(time_start))


def array_and_time_points_equal(stimulus, start, duration, step_size, *, plot_on_mismatch=True):
    """Return whether array-based and point-wise stimulus evaluation agree.

    Samples ``stimulus.as_array(start, duration, step_size)``. Each sample is
    compared with ``stimulus.value_at_time_in_s(t)`` for the matching ``t`` from
    ``np.arange(start, start + duration, step_size)``. Both sides are rounded
    to 15 decimal places before the exact comparison.

    Args:
        stimulus: object providing ``as_array`` and ``value_at_time_in_s``.
        start: first time point (presumably seconds, per ``value_at_time_in_s``).
        duration: length of the sampled interval.
        step_size: spacing between consecutive time points.
        plot_on_mismatch: if True (the default), the first mismatch opens a
            diagnostic plot of both signals and their difference. Pass False
            to run headless.

    Returns:
        True if every sample matches, False at the first mismatch.
    """
    precision = 15
    array = np.around(stimulus.as_array(start, duration, step_size), precision)
    time = np.arange(start, start + duration, step_size)
    # NOTE(review): assumes as_array yields at least len(time) samples; a
    # shorter array raises IndexError here. Confirm against as_array.
    for i, time_point in enumerate(time):
        value = np.round(stimulus.value_at_time_in_s(time_point), precision)
        if array[i] != value:
            if plot_on_mismatch:
                _plot_mismatch(stimulus, time, array, precision)
            return False
    return True


def _plot_mismatch(stimulus, time, array, precision):
    """Plot array vs. point-wise stimulus values and their difference, then close the figure."""
    stim_per_point = np.around(
        np.array([stimulus.value_at_time_in_s(t) for t in time]), precision)
    fig, axes = plt.subplots(2, 1, sharex="all")
    axes[0].plot(time, array, label="array")
    axes[0].plot(time, stim_per_point, label="individual")
    axes[0].set_title("stimulus values")
    axes[0].legend()

    axes[1].plot(time, stim_per_point - array)
    axes[1].set_title("difference")
    plt.show()
    # Release the figure; under non-interactive backends show() returns
    # immediately and figures would otherwise accumulate across failures.
    plt.close(fig)