from stimuli.AbstractStimulus import AbstractStimulus
import numpy as np


class SinusAmplitudeModulationStimulus(AbstractStimulus):
    """Sinusoidal carrier whose envelope is modulated by a second sinusoid.

    Inside the window ``[start_time, start_time + duration]`` the signal is::

        amplitude * (1 + contrast * sin(2*pi*modulation_frequency*t))
                  * sin(2*pi*carrier_frequency*t)

    and outside that window it is ``0``. Both sinusoids are evaluated at the
    absolute time ``t``, not at the time elapsed since ``start_time``.
    """

    def __init__(self, carrier_frequency, contrast, modulation_frequency,
                 amplitude=1, start_time=0, duration=np.inf):
        # Carrier and envelope parameters.
        self.carrier_frequency = carrier_frequency
        self.modulation_frequency = modulation_frequency
        self.contrast = contrast
        self.amplitude = amplitude
        # Activity window; the default duration of np.inf keeps the stimulus on forever.
        self.start_time = start_time
        self.duration = duration

    def value_at_time_in_s(self, time_point):
        """Return the stimulus value at ``time_point`` (in seconds).

        Returns ``0`` when ``time_point`` lies strictly before ``start_time`` or
        strictly after ``start_time + duration``; the window edges are inclusive.
        """
        # Keep the "before OR after" form: a NaN time fails both comparisons
        # and therefore falls through to the formula, as it always has.
        too_early = time_point < self.start_time
        if too_early or time_point > self.start_time + self.duration:
            return 0

        omega_modulation = 2 * np.pi * self.modulation_frequency
        omega_carrier = 2 * np.pi * self.carrier_frequency

        envelope = 1 + self.contrast * np.sin(omega_modulation * time_point)
        carrier_wave = np.sin(omega_carrier * time_point)

        scaled_envelope = self.amplitude * envelope
        return scaled_envelope * carrier_wave

    def get_stimulus_start_ms(self):
        """Return the configured ``start_time``.

        NOTE(review): ``value_at_time_in_s`` compares this value against a time
        in seconds, although the method name says milliseconds - confirm the
        unit expected by callers.
        """
        onset = self.start_time
        return onset

    def get_stimulus_duration_ms(self):
        """Return the configured ``duration`` (may be ``np.inf``).

        NOTE(review): the same seconds-vs-milliseconds question as
        ``get_stimulus_start_ms`` applies here.
        """
        length = self.duration
        return length

    def get_amplitude(self):
        """Return the modulation ``contrast``, not the ``amplitude`` attribute.

        Presumably the contrast is treated as the stimulus amplitude elsewhere in
        the project - verify against callers before changing this.
        """
        modulation_depth = self.contrast
        return modulation_depth