from models.LIFACnoise import LifacNoiseModel
from stimuli.SinusoidalStepStimulus import SinusoidalStepStimulus

import matplotlib.pyplot as plt
import numpy as np


def main():
    """Script entry point: draw the LIFAC model example figure."""
    plot_model_example()


def plot_model_example():
    """Simulate the LIFAC noise model and plot a short excerpt of its traces.

    Builds a ``LifacNoiseModel`` from an empty parameter dict and drives it
    with a ``SinusoidalStepStimulus`` (350 Hz, contrast 0) for 0.5 s. It then
    plots the membrane voltage (top axis) and the adaption current (bottom
    axis) for the window 0.26-0.29 s. The x-axis is re-zeroed to the window
    start and shown in milliseconds.

    Opens the figure with ``plt.show()`` and closes it afterwards; returns None.
    """
    # NOTE(review): an empty dict presumably makes the model fall back to its
    # default parameters -- confirm against LifacNoiseModel.
    parameter = {}
    model = LifacNoiseModel(parameter)

    # SinusoidalStepStimulus(frequency, contrast, start_time=0, duration=np.inf, amplitude=1)
    frequency = 350
    contrast = 0
    start_time = 5
    duration = 0
    stimulus = SinusoidalStepStimulus(frequency, contrast, start_time, duration)

    time_start = 0
    time_duration = 0.5
    # The plotting below treats the sampling interval as seconds
    # (it is multiplied by 1000 and labelled "ms").
    time_step = model.get_sampling_interval()
    v1, _spikes = model.simulate_fast(stimulus, total_time_s=time_duration, time_start=time_start)
    adaption = model.get_adaption_trace()

    # Excerpt to display, in the same time base as time_start / time_duration.
    window_start = 0.26
    window_end = 0.29
    # round() rather than int(): a float quotient such as window / step can
    # evaluate just below the intended integer, and int() would truncate it,
    # silently dropping a sample. Offsets are relative to the simulation start.
    start_idx = int(round((window_start - time_start) / time_step))
    end_idx = int(round((window_end - time_start) / time_step))

    v1_part = v1[start_idx:end_idx]
    adaption_part = adaption[start_idx:end_idx]
    # Derive each x-axis from the actual slice length so it always matches the
    # plotted data, even if a trace ends before end_idx.
    v1_time_ms = np.arange(len(v1_part)) * time_step * 1000
    adaption_time_ms = np.arange(len(adaption_part)) * time_step * 1000

    fig, axes = plt.subplots(2, sharex=True)

    axes[0].plot(v1_time_ms, v1_part)
    axes[0].set_ylabel("Membrane voltage [mV]")

    axes[1].plot(adaption_time_ms, adaption_part)
    axes[1].set_ylabel("Adaption current [mV]")
    axes[1].set_xlabel("Time [ms]")
    # Show exactly the selected window (in ms) instead of a hard-coded limit.
    axes[1].set_xlim((0, (window_end - window_start) * 1000))

    plt.show()
    plt.close(fig)





# Only render the figure when run as a script, not when imported.
if __name__ == "__main__":
    main()