from stimuli.AbstractStimulus import AbstractStimulus
import numpy as np
from numba import jit, njit
from warnings import warn


class SinusAmplitudeModulationStimulus(AbstractStimulus):
    """Sinusoidal carrier whose amplitude is sinusoidally modulated inside a time window.

    The signal is ``amplitude * sin(2*pi*carrier_frequency*t)`` everywhere. While
    ``start_time <= t < start_time + duration`` it is additionally multiplied by the
    envelope ``1 + contrast * sin(2*pi*modulation_frequency*t)``.
    Times are in seconds (per the ``*_s`` method names).
    """

    def __init__(self, carrier_frequency, contrast, modulation_frequency, start_time=0, duration=np.inf, amplitude=1):
        """
        :param carrier_frequency: frequency of the carrier sine (cycles per unit of ``t``).
        :param contrast: modulation depth; the envelope is ``1 + contrast * sin(...)``.
        :param modulation_frequency: frequency of the modulating sine.
        :param start_time: time at which the amplitude modulation begins.
        :param duration: length of the modulation window; ``np.inf`` (default) means it never ends.
        :param amplitude: overall scale factor applied to the whole signal.
        """
        self.contrast = contrast
        self.modulation_frequency = modulation_frequency
        self.amplitude = amplitude
        self.carrier_frequency = carrier_frequency
        self.start_time = start_time
        self.duration = duration

    def value_at_time_in_s(self, time_point):
        """Return the stimulus value at a single (scalar) time point.

        The modulation window is half-open, ``[start_time, start_time + duration)``.
        This is consistent with the end-exclusive slice used when building arrays
        via :meth:`as_array`.
        """
        carrier = np.sin(2 * np.pi * self.carrier_frequency * time_point)

        # Outside the modulation window only the scaled carrier is present. The end is
        # exclusive (>=) so single-point evaluation agrees with the array representation.
        if time_point < self.start_time or time_point >= self.start_time + self.duration:
            return self.amplitude * carrier

        am = 1 + self.contrast * np.sin(2 * np.pi * self.modulation_frequency * time_point)
        return self.amplitude * am * carrier

    def get_stimulus_start_s(self):
        """Return the time at which the amplitude modulation starts."""
        return self.start_time

    def get_stimulus_duration_s(self):
        """Return the length of the amplitude-modulation window (may be ``np.inf``)."""
        return self.duration

    def get_amplitude(self):
        """Return the modulation contrast.

        NOTE(review): this returns ``contrast``, not ``self.amplitude``. Presumably the
        AbstractStimulus interface uses "amplitude" to mean modulation strength; confirm
        against callers before changing.
        """
        return self.contrast

    def as_array(self, time_start, total_time, step_size):
        """Return the stimulus sampled from ``time_start`` over ``total_time`` with spacing ``step_size``.

        Delegates to the module-level ``convert_to_array`` with this stimulus' parameters.
        """
        return convert_to_array(self.carrier_frequency, self.amplitude, self.modulation_frequency,
                                self.contrast, self.start_time, self.duration,
                                time_start, total_time, step_size)


# @jit(nopython=True)  # makes it slower?
def _grid_index(offset_s, step_size_s):
    """Return ``(index, on_grid)`` for the first grid sample at or after ``offset_s``.

    ``offset_s`` is measured from the first sample of a grid with spacing ``step_size_s``.
    Offsets within floating-point noise of a grid point snap to that point. For example,
    ``0.6 / 0.1 == 5.999...`` yields 6 instead of being truncated to 5. Genuinely off-grid
    offsets are rounded up (``ceil``), so the returned sample never precedes ``offset_s``.
    """
    raw_index = offset_s / step_size_s
    nearest = round(raw_index)
    if abs(raw_index - nearest) <= 1e-9 * max(1.0, abs(raw_index)):
        return int(nearest), True
    return int(np.ceil(raw_index)), False


def convert_to_array(carrier_freq, amplitude, modulation_freq, contrast, start_time, duration, time_start, total_time, step_size_s):
    """Sample a sinusoidally amplitude-modulated carrier on a regular time grid.

    The grid is ``np.arange(time_start, time_start + total_time, step_size_s)``. Every sample
    is ``amplitude * sin(2*pi*carrier_freq*t)``. Samples with
    ``start_time <= t < start_time + duration`` are additionally multiplied by
    ``1 + contrast * sin(2*pi*modulation_freq*t)``.

    :param carrier_freq: carrier frequency.
    :param amplitude: overall scale factor.
    :param modulation_freq: frequency of the amplitude envelope.
    :param contrast: modulation depth of the envelope.
    :param start_time: start of the modulation window.
    :param duration: length of the modulation window (``np.inf`` allowed).
    :param time_start: time of the first sample.
    :param total_time: length of the sampled window.
    :param step_size_s: sampling interval.
    :returns: 1-D numpy array of samples.

    A ``UserWarning`` is issued (and samples are still selected correctly) when a
    modulation boundary does not fall on the sampling grid.
    """
    full_time = np.arange(time_start, time_start + total_time, step_size_s)
    values = amplitude * np.sin(2 * np.pi * carrier_freq * full_time)

    window_end = time_start + total_time
    am_stop = start_time + duration

    # No overlap between the modulation interval and the sampled window: pure carrier.
    # (Compares against the window end; the old check wrongly used time_start + duration.)
    if start_time >= window_end or am_stop <= time_start:
        return values

    # Intersection of the modulation interval with the sampled window.
    am_start = max(start_time, time_start)
    am_end = min(am_stop, window_end)

    idx_start, start_on_grid = _grid_index(am_start - time_start, step_size_s)
    idx_end, end_on_grid = _grid_index(am_end - time_start, step_size_s)

    if not (start_on_grid and end_on_grid):
        warn("Didn't calculate integers when searching the start and end index. start: {} end: {}".format(
            (am_start - time_start) / step_size_s, (am_end - time_start) / step_size_s))

    # Half-open selection [idx_start, idx_end): samples with am_start <= t < am_end.
    am = 1 + contrast * np.sin(2 * np.pi * modulation_freq * full_time[idx_start:idx_end])
    values[idx_start:idx_end] *= am

    return values


    # # if the whole stimulus time has the amplitude modulation just built it at once;
    # if time_start >= start_time and start_time+duration < time_start+total_time:
    #     carrier = np.sin(2 * np.pi * carrier_freq * np.arange(start_time, total_time - start_time, step_size_s))
    #     modulation = 1 + contrast * np.sin(2 * np.pi * modulation_freq * np.arange(start_time, total_time - start_time, step_size_s))
    #     values = amplitude * carrier * modulation
    #     return values
    #
    # # if it is split into parts with and without amplitude modulation built it in parts:
    # values = np.array([])
    #
    # # there is some time before the modulation starts:
    # if time_start < start_time:
    #     carrier_before_am = np.sin(2 * np.pi * carrier_freq * np.arange(time_start, start_time, step_size_s))
    #     values = np.concatenate((values, amplitude * carrier_before_am))
    #
    # # there is at least a second part of the stimulus that contains the amplitude:
    # # time starts before the end of the am and ends after it was started
    # if time_start < start_time+duration and time_start+total_time > start_time:
    #     if duration is np.inf:
    #
    #         carrier_during_am = np.sin(
    #             2 * np.pi * carrier_freq * np.arange(start_time, time_start + total_time, step_size_s))
    #         am = 1 + contrast * np.sin(
    #             2 * np.pi * modulation_freq * np.arange(start_time, time_start + total_time, step_size_s))
    #     else:
    #         carrier_during_am = np.sin(
    #             2 * np.pi * carrier_freq * np.arange(start_time, start_time + duration, step_size_s))
    #         am = 1 + contrast * np.sin(
    #             2 * np.pi * modulation_freq * np.arange(start_time, start_time + duration, step_size_s))
    #     values = np.concatenate((values, amplitude * am * carrier_during_am))
    #
    # else:
    #     if contrast != 0:
    #         print("Given stimulus time parameters (start, total) result in no part of it containing the amplitude modulation!")
    #
    # if time_start+total_time > start_time+duration:
    #     carrier_after_am = np.sin(2 * np.pi * carrier_freq * np.arange(start_time + duration, time_start + total_time, step_size_s))
    #     values = np.concatenate((values, amplitude*carrier_after_am))
    #
    # return values