import unittest
from stimuli.StepStimulus import StepStimulus
import numpy as np


class TestStepStimulus(unittest.TestCase):
    """Unit tests for ``StepStimulus``.

    The tests construct stimuli as ``StepStimulus(start, duration, value[, base_value])``
    with times in seconds, and expect ``value`` on the closed interval
    ``[start, start + duration]`` and ``base_value`` outside it.
    """

    def test_time_getters(self):
        """Start/duration/end getters agree in seconds and milliseconds."""
        starts = [2, -5, 0, 10]
        durations = [2, 1000, 0.5]
        value = 10

        for start in starts:
            for duration in durations:
                # subTest names the failing (start, duration) pair and keeps
                # checking the remaining combinations instead of stopping.
                with self.subTest(start=start, duration=duration):
                    stimulus = StepStimulus(start, duration, value)
                    end = start + duration

                    self.assertEqual(start, stimulus.get_stimulus_start_s(), "reported start (s) was wrong")
                    self.assertEqual(duration, stimulus.get_stimulus_duration_s(), "reported duration (s) was wrong")
                    self.assertEqual(end, stimulus.get_stimulus_end_s(), "reported end (s) was wrong")

                    # The chosen inputs make every *1000 product exactly
                    # representable, so exact equality is appropriate here.
                    self.assertEqual(start * 1000, stimulus.get_stimulus_start_ms(), "reported start (ms) was wrong")
                    self.assertEqual(duration * 1000, stimulus.get_stimulus_duration_ms(), "reported duration (ms) was wrong")
                    self.assertEqual(end * 1000, stimulus.get_stimulus_end_ms(), "reported end (ms) was wrong")

    def test_duration_must_be_positive(self):
        """Constructing a stimulus with a negative duration raises ValueError."""
        with self.assertRaises(ValueError):
            StepStimulus(1, -1, 3)

    def test_value_at(self):
        """Point evaluation yields ``value`` inside [start, end] and ``base_value`` outside."""
        start = 0
        duration = 2
        value = 5
        base_value = -1
        stimulus = StepStimulus(start, duration, value, base_value)
        end = start + duration

        # Sample from start-1 to end+1 (inclusive) every 0.1 s. The grid is
        # built from integer tenths so each sample is the nearest double to
        # its nominal time. np.arange with a 0.1 step accumulates rounding
        # error, so its samples "at" start/end land slightly off the boundary
        # and the edges are never actually checked.
        tenths = np.arange(-10, 10 * (duration + 1) + 1)
        for t in start + tenths / 10:
            with self.subTest(t=t):
                expected = base_value if t < start or t > end else value
                self.assertEqual(stimulus.value_at_time_in_s(t), expected)
                self.assertEqual(stimulus.value_at_time_in_ms(t * 1000), expected)

    def test_amplitude(self):
        """Amplitude is the stimulus value minus the base value."""
        stim_values = [-10, -5, 0, 15, 20]
        base_values = [-10, -5, -2, 0, 1, 15]

        for s_value in stim_values:
            for b_value in base_values:
                with self.subTest(value=s_value, base_value=b_value):
                    stimulus = StepStimulus(0, 1, s_value, b_value)

                    self.assertEqual(stimulus.get_amplitude(), s_value - b_value)


if __name__ == "__main__":
    # Run the test cases in this module when executed directly as a script.
    unittest.main()