from models.LIFACnoise import LifacNoiseModel
from stimuli.SinusoidalStepStimulus import SinusoidalStepStimulus
import helperFunctions as hF
import Figure_constants as consts
import matplotlib.pyplot as plt
import numpy as np
import models.smallModels as sM


def main():
    """Script entry point: regenerate the currently selected figure(s)."""
    model_adaption_example()
    # Further figures produced by this module; uncomment to regenerate them:
    # stimulus_development()
    # model_comparison()


def isi_development():
    """Placeholder for an ISI development figure; not implemented yet.

    Currently it only looks up ``consts.model_cell_1`` and returns None.
    """
    # Cell parameters intended for this figure once it is implemented.
    model_params = consts.model_cell_1  # noqa: F841 -- unused until implemented


def model_comparison():
    """Plot the responses of several simple neuron models to one step stimulus.

    Row 0 shows the piecewise-constant stimulus (2 up to sample 8000, 0 up to
    sample 20000, 1 afterwards; with ``step = 1e-5`` s that is 0-0.08 s,
    0.08-0.2 s and 0.2-0.5 s). Each following row shows one model from
    ``models.smallModels``:
      * left column: voltage trace with a spike raster drawn above it,
      * right column: the frequency trace returned by
        ``hF.calculate_time_and_frequency_trace``.

    The figure is saved to ``consts.SAVE_FOLDER + "model_comparison.pdf"``.
    """
    step = 0.00001
    duration = 0.5
    time = np.arange(0, duration, step)

    # Piecewise-constant stimulus on the same sample grid as ``time``.
    stimulus = np.zeros(len(time))
    stimulus[0:8000] = 2
    stimulus[20000:] = 1

    # (row label, simulation function) in top-to-bottom plotting order.
    models = (
        ("PIF", sM.pif_simulation),
        ("LIF", sM.lif_simulation),
        ("LIFAC", sM.lifac_simulation),
        ("LIFAC + ref", sM.lifac_ref_simulation),
        # sM.lifac_ref_noise_simulation could be added here; the figure grid
        # grows automatically because its row count is derived from this tuple.
    )

    fig, axes = plt.subplots(len(models) + 1, 2, sharex=True, sharey="col", figsize=consts.FIG_SIZE_LARGE)

    axes[1, 0].set_title("Voltage")
    axes[1, 1].set_title("Frequency")

    axes[0, 0].plot(time[:len(stimulus)], stimulus)
    axes[0, 0].set_ylabel("Stimulus")
    # The top-right cell has no counterpart for the stimulus row; hide it.
    axes[0, 1].set_frame_on(False)
    axes[0, 1].set_axis_off()

    for row, (label, simulate) in enumerate(models, start=1):
        _plot_model_row(axes[row], label, simulate(stimulus, step), time, step)

    axes[-1, 0].set_xlabel("Time [s]")
    axes[-1, 1].set_xlabel("Time [s]")

    plt.tight_layout()
    plt.savefig(consts.SAVE_FOLDER + "model_comparison.pdf")
    plt.close()


def _plot_model_row(row_axes, label, simulation_result, time, step):
    """Draw one model's voltage+spikes (left axis) and frequency trace (right axis).

    ``simulation_result`` is the ``(voltage, spike_times)`` pair returned by
    the ``sM.*_simulation`` functions; ``time`` is truncated to the voltage
    length so that both arrays passed to ``plot`` match.
    """
    voltage, spikes = simulation_result
    row_axes[0].plot(time[:len(voltage)], voltage)
    row_axes[0].eventplot(spikes, lineoffsets=1.2, linelengths=0.2, colors="black")
    freq_time, freq = hF.calculate_time_and_frequency_trace(spikes, step)
    row_axes[1].plot(freq_time, freq)
    row_axes[0].set_ylabel(label)


def stimulus_development():
    """Plot the processing stages of a sinusoidal step stimulus.

    The figure has three stacked panels with a shared x-axis:
      1. the raw stimulus from ``SinusoidalStepStimulus.as_array``,
      2. the output of ``hF.rectify_stimulus_array``,
      3. the rectified signal after ``dendritic_lowpass`` with tau = 0.001.

    The figure is saved to ``consts.SAVE_FOLDER + "stimulus_development.pdf"``.
    """
    time_start = -0.020
    time_duration = 0.080
    step_size = 0.000005

    # presumably (frequency, contrast, start_time, duration) -- confirm against SinusoidalStepStimulus
    stimulus = SinusoidalStepStimulus(745, 0.2, 0.1, 0.1)
    stim_array = stimulus.as_array(time_start, time_duration, step_size)

    rectified = hF.rectify_stimulus_array(stim_array)
    filtered = dendritic_lowpass(rectified, 0.001, step_size)

    # Derive the time axis from the sample count: np.arange with a float step
    # can yield one element more or fewer than the stimulus array, which would
    # make ax.plot raise on mismatched lengths.
    time = time_start + np.arange(len(stim_array)) * step_size

    fig, axes = plt.subplots(3, 1, figsize=(6, 6), sharex="col")
    panels = (
        (stim_array, "stimulus"),
        (rectified, "rectified stimulus"),
        (filtered, "rectified with dendritic filter"),
    )
    for ax, (trace, title) in zip(axes, panels):
        ax.plot(time, trace)
        ax.set_title(title)
        ax.set_ylim((-1.15, 1.15))
        ax.set_ylabel("Amplitude [mV]")

    axes[2].set_xlabel("Time [s]")
    # sharex="col" applies this limit to all three panels.
    axes[0].set_xlim((0, 0.05))

    plt.tight_layout()
    plt.savefig(consts.SAVE_FOLDER + "stimulus_development.pdf")
    plt.close()


def dendritic_lowpass(stimulus, dend_tau, step_size):
    """Filter a signal with a first-order low-pass (forward Euler integration).

    Integrates ``dend_tau * df/dt = -f + s`` sample by sample. The output is
    initialized to the first input sample, so a constant input passes through
    unchanged without a start-up transient.

    Args:
        stimulus: 1-D sequence of input samples.
        dend_tau: filter time constant, in the same unit as ``step_size``.
        step_size: sampling interval of ``stimulus``.

    Returns:
        A float ``np.ndarray`` with the same length as ``stimulus``. Empty input
        yields an empty array.

    Note:
        Forward Euler is only stable for ``step_size < 2 * dend_tau``, and it
        is free of overshoot only for ``step_size <= dend_tau``.
    """
    filtered = np.zeros(len(stimulus))
    if len(filtered) == 0:
        # Guard: indexing stimulus[0] below would raise IndexError.
        return filtered

    filtered[0] = stimulus[0]
    for i in range(1, len(stimulus)):
        previous = filtered[i - 1]
        filtered[i] = previous + ((stimulus[i] - previous) / dend_tau) * step_size

    return filtered


def model_adaption_example():
    """Plot membrane voltage and adaption trace of a ``LifacNoiseModel``.

    Simulates ``consts.model_cell_2`` for 0.5 s, starting at t = 0, with a
    ``SinusoidalStepStimulus`` of 350 Hz, contrast 0 and step start_time 5.
    Presumably the step therefore never occurs inside the simulated window
    (confirm against SinusoidalStepStimulus).

    A 30 ms excerpt (0.26-0.29 s) is plotted:
      * top axis: spike raster and membrane voltage,
      * bottom axis: the sign-flipped adaption trace.

    The figure is saved to ``consts.SAVE_FOLDER + "adaptionExample.pdf"``.
    """
    # TODO find a good example model
    parameter = consts.model_cell_2
    model = LifacNoiseModel(parameter)
    # frequency, contrast, start_time=0, duration=np.inf, amplitude=1)
    frequency = 350
    contrast = 0
    start_time = 5
    duration = 0
    stimulus = SinusoidalStepStimulus(frequency, contrast, start_time, duration)

    time_start = 0
    time_duration = 0.5
    time_step = model.get_sampling_interval()
    v1, spikes = model.simulate(stimulus, total_time_s=time_duration, time_start=time_start)
    adaption = np.asarray(model.get_adaption_trace())
    # Build the time axis from the trace length so it always aligns with v1.
    # np.arange with a float step can be off by one sample.
    time = time_start + np.arange(len(v1)) * time_step

    start = 0.26
    end = 0.29
    # round() rather than bare int(): a float quotient can land just below an
    # integer, and int() would truncate it to the previous sample.
    start_idx = int(round((start - time_start) / time_step))
    end_idx = int(round((end - time_start) / time_step))
    time_part = time[start_idx:end_idx]

    fig, axes = plt.subplots(2, sharex=True, gridspec_kw={'height_ratios': [1, 1]})

    # Spikes are drawn as short ticks at y=1.2, on the same axis as the voltage.
    axes[0].eventplot([s for s in spikes if start < s < end], lineoffsets=1.2, linelengths=0.2, colors="black")
    axes[0].plot(time_part, v1[start_idx:end_idx])
    axes[0].set_ylabel("Membrane voltage [mV]")

    # The adaption trace is plotted with its sign flipped.
    axes[1].plot(time_part, -adaption[start_idx:end_idx])
    axes[1].set_ylabel("Adaption current [mV]")
    # The time values (and the xlim below) are in seconds, not milliseconds.
    axes[1].set_xlabel("Time [s]")
    axes[1].set_xlim((start, end))

    plt.savefig(consts.SAVE_FOLDER + "adaptionExample.pdf")
    plt.close()


# Only generate figures when executed as a script, not when imported.
if __name__ == "__main__":
    main()