from stimuli.AbstractStimulus import AbstractStimulus
import numpy as np
from numba import jit, njit


class SinusoidalStepStimulus(AbstractStimulus):
    """Sinusoidal carrier whose amplitude is scaled by ``1 + contrast`` inside a window.

    Outside the half-open window ``[start_time, start_time + duration)`` the value is
    ``amplitude * sin(2 * pi * frequency * t)``; inside the window that value is
    additionally multiplied by ``1 + contrast``.
    """

    def __init__(self, frequency, contrast, start_time=0, duration=np.inf, amplitude=1):
        """Store the stimulus parameters.

        Args:
            frequency: carrier frequency, in cycles per second (times are in seconds).
            contrast: relative amplitude change inside the window. It is stored as the
                multiplier ``1 + contrast`` in ``self.contrast``.
            start_time: start of the window, in seconds.
            duration: length of the window, in seconds; the default ``np.inf`` means
                the step never ends.
            amplitude: scale applied to the carrier at all times.
        """
        # Stored as a multiplier: a contrast of 0 yields 1, i.e. no change in the window.
        self.contrast = 1 + contrast
        self.amplitude = amplitude
        self.frequency = frequency
        self.start_time = start_time
        self.duration = duration

    def value_at_time_in_s(self, time_point):
        """Return the stimulus value at the scalar time ``time_point`` (seconds)."""
        carrier = np.sin(2 * np.pi * self.frequency * time_point)
        # Half-open window [start, start + duration): the same convention as the
        # start-inclusive / end-exclusive index slice used by convert_to_array, so a
        # sample exactly at start_time is modulated in both the scalar and array paths.
        if self.start_time <= time_point < self.start_time + self.duration:
            return self.amplitude * carrier * self.contrast
        return self.amplitude * carrier

    def get_stimulus_start_s(self):
        """Return the window start time, in seconds."""
        return self.start_time

    def get_stimulus_duration_s(self):
        """Return the window duration, in seconds (may be ``np.inf``)."""
        return self.duration

    def get_amplitude(self):
        """Return the in-window multiplier ``1 + contrast``.

        NOTE(review): despite the name this returns ``self.contrast``, not
        ``self.amplitude``; confirm this is what callers of the interface expect.
        """
        return self.contrast

    def as_array(self, time_start, total_time, step_size):
        """Sample the stimulus on ``np.arange(time_start, time_start + total_time, step_size)``."""
        return convert_to_array(
            self.frequency,
            self.amplitude,
            self.contrast,
            self.start_time,
            self.duration,
            time_start,
            total_time,
            step_size,
        )


# NOTE: numba's @jit(nopython=True) was tried on convert_to_array below; the original
# author noted that it may make it slower, so it is left disabled.
def _first_index_at_or_after(offset, step_size_s):
    """Return the index of the first grid sample at or after ``offset`` from the grid start.

    Grid points are at ``k * step_size_s`` for integer ``k >= 0``. A position within
    floating-point noise of a grid point is treated as lying exactly on it; for example,
    ``0.3 / 0.1 == 2.9999999999999996`` must map to index 3, not 2.
    """
    position = offset / step_size_s
    nearest = int(round(position))
    if np.isclose(position, nearest, rtol=1e-9, atol=1e-9):
        return nearest
    # Strictly between two grid points: the first sample at or after it is the ceiling.
    return int(np.ceil(position))


def convert_to_array(frequency, amplitude, contrast, start_time, duration, time_start, total_time, step_size_s):
    """Sample a sinusoid that is multiplied by ``contrast`` inside a time window.

    The sample times are ``np.arange(time_start, time_start + total_time, step_size_s)``.
    Every sample is ``amplitude * sin(2 * pi * frequency * t)``. Samples whose time lies
    in the half-open window ``[start_time, start_time + duration)`` are additionally
    multiplied by ``contrast``.

    Args:
        frequency: carrier frequency, in cycles per unit of time.
        amplitude: scale applied to every sample.
        contrast: multiplier applied inside the window (1 means "no change").
        start_time: start of the window.
        duration: length of the window; may be ``np.inf``.
        time_start: time of the first sample.
        total_time: length of the sampled interval.
        step_size_s: spacing between consecutive samples.

    Returns:
        1-D numpy array of samples, one per entry of the time grid.
    """
    full_time = np.arange(time_start, time_start + total_time, step_size_s)
    values = np.sin(2 * np.pi * frequency * full_time) * amplitude

    window_end = start_time + duration
    range_end = time_start + total_time
    if start_time > range_end or window_end < time_start:
        # The window does not overlap the sampled range. Only warn when the multiplier
        # would actually have changed something (1 is the neutral multiplier).
        if contrast != 1:
            print("SinusoidalStepStimulus: converted to array in a range outside of step!")
        return values

    # Clip the window to the sampled range, then convert both ends to grid indices.
    # The end index is exclusive, so samples at times < am_end are modulated.
    am_start = max(start_time, time_start)
    am_end = min(window_end, range_end)
    idx_start = _first_index_at_or_after(am_start - time_start, step_size_s)
    idx_end = _first_index_at_or_after(am_end - time_start, step_size_s)

    values[idx_start:idx_end] = values[idx_start:idx_end] * contrast
    return values