import numpy as np
import matplotlib.pyplot as plt
import helperFunctions as hF


def main():
    """Print the baseline firing rate of each deterministic neuron model.

    Every model is driven with an all-zero stimulus for
    ``adaption_time + time`` (in units of ``step``). Spikes occurring at or
    before ``adaption_time`` are discarded as a start-up transient; the
    remaining spike count is divided by ``time`` and printed.
    """
    time = 30
    adaption_time = 1
    step = 0.00005
    # round() before int(): the float quotient may land a hair below the
    # intended integer, and a bare int() would then silently drop a sample.
    stimulus = np.zeros(int(round((time + adaption_time) / step)))

    models = (
        ("PIF", pif_simulation),
        ("LIF", lif_simulation),
        ("LIFAC", lifac_simulation),
        ("LIFAC+ref", lifac_ref_simulation),
    )
    for label, simulate in models:
        # Only the spike times are needed; the membrane trace is discarded.
        _, spikes = simulate(stimulus, step)
        spikes = np.asarray(spikes)
        count = np.count_nonzero(spikes > adaption_time)
        print("Baseline freq {}: {:.2f}".format(label, count / time))


def pif_simulation(stimulus, step_size, *, v_0=0, v_base=0, threshold=1,
                   v_offset=0.15, mem_tau=0.015):
    """Simulate a perfect (non-leaky) integrate-and-fire neuron.

    The membrane follows ``dv/dt = (v_offset + stimulus) / mem_tau``, which is
    integrated with explicit Euler steps. Whenever ``v`` rises strictly above
    ``threshold``, it is reset to ``v_base`` and a spike is recorded.

    Args:
        stimulus: 1-D sequence of input values, one per time step.
        step_size: Integration time step; spike times use the same unit.
        v_0: Initial membrane value.
        v_base: Reset value applied after a spike.
        threshold: Spike threshold (strict ``>`` comparison).
        v_offset: Constant drive added to the stimulus.
        mem_tau: Membrane time constant.

    Returns:
        ``(v_1, spikes)``: ``v_1`` is a float array with one membrane value per
        stimulus sample, and ``spikes`` is a list of spike times
        ``i * step_size``. An empty stimulus yields an empty array and no
        spikes.
    """
    n_steps = len(stimulus)
    v_1 = np.zeros(n_steps)
    spikes = []
    if n_steps == 0:
        # Nothing to integrate; also avoids indexing v_1[0] on an empty array.
        return v_1, spikes
    v_1[0] = v_0

    for i in range(1, n_steps):
        # The drive uses the current sample stimulus[i] with the previous v.
        v_1[i] = v_1[i - 1] + ((v_offset + stimulus[i]) / mem_tau) * step_size

        if v_1[i] > threshold:
            v_1[i] = v_base
            spikes.append(i * step_size)

    return v_1, spikes


def lif_simulation(stimulus, step_size, *, v_0=0, v_base=0, threshold=1,
                   v_offset=1.001255, mem_tau=0.015):
    """Simulate a leaky integrate-and-fire neuron.

    The membrane follows ``dv/dt = (v_offset - v + stimulus) / mem_tau``,
    which is integrated with explicit Euler steps. Whenever ``v`` rises
    strictly above ``threshold``, it is reset to ``v_base`` and a spike is
    recorded.

    Args:
        stimulus: 1-D sequence of input values, one per time step.
        step_size: Integration time step; spike times use the same unit.
        v_0: Initial membrane value.
        v_base: Reset value applied after a spike.
        threshold: Spike threshold (strict ``>`` comparison).
        v_offset: Constant drive (resting target of the leak).
        mem_tau: Membrane time constant.

    Returns:
        ``(v_1, spikes)``: ``v_1`` is a float array with one membrane value per
        stimulus sample, and ``spikes`` is a list of spike times
        ``i * step_size``. An empty stimulus yields an empty array and no
        spikes.
    """
    # NOTE(review): v_offset's default sits just above threshold, presumably
    # tuned to a target baseline rate; confirm before changing it.
    n_steps = len(stimulus)
    v_1 = np.zeros(n_steps)
    spikes = []
    if n_steps == 0:
        # Nothing to integrate; also avoids indexing v_1[0] on an empty array.
        return v_1, spikes
    v_1[0] = v_0

    for i in range(1, n_steps):
        v_1[i] = v_1[i - 1] + ((v_offset - v_1[i - 1] + stimulus[i]) / mem_tau) * step_size

        if v_1[i] > threshold:
            v_1[i] = v_base
            spikes.append(i * step_size)

    return v_1, spikes


def lifac_simulation(stimulus, step_size, *, v_0=0, v_base=0, threshold=1,
                     v_offset=1.3445, mem_tau=0.015, adaption_tau=0.1,
                     adaption_step=0.05, adaption_0=0.5):
    """Simulate a leaky integrate-and-fire neuron with an adaptation current.

    Dynamics, integrated with explicit Euler steps:

    * ``dv/dt = (v_offset - v - a + stimulus) / mem_tau``
    * ``da/dt = -a / adaption_tau``

    Whenever ``v`` rises strictly above ``threshold``, it is reset to
    ``v_base``, a spike is recorded, and ``a`` jumps by
    ``adaption_step / adaption_tau``.

    Args:
        stimulus: 1-D sequence of input values, one per time step.
        step_size: Integration time step; spike times use the same unit.
        v_0: Initial membrane value.
        v_base: Reset value applied after a spike.
        threshold: Spike threshold (strict ``>`` comparison).
        v_offset: Constant drive.
        mem_tau: Membrane time constant.
        adaption_tau: Decay time constant of the adaptation variable.
        adaption_step: Per-spike adaptation increment (divided by
            ``adaption_tau`` before being added).
        adaption_0: Initial adaptation value.

    Returns:
        ``(v_1, spikes)``: ``v_1`` is a float array with one membrane value per
        stimulus sample, and ``spikes`` is a list of spike times
        ``i * step_size``. The adaptation trace is not returned. An empty
        stimulus yields an empty array and no spikes.
    """
    n_steps = len(stimulus)
    adaption = np.zeros(n_steps)
    v_1 = np.zeros(n_steps)
    spikes = []
    if n_steps == 0:
        # Nothing to integrate; also avoids indexing element 0 of empty arrays.
        return v_1, spikes
    adaption[0] = adaption_0
    v_1[0] = v_0

    for i in range(1, n_steps):
        # Both updates use the previous adaptation value adaption[i - 1].
        v_1[i] = v_1[i - 1] + ((v_offset - v_1[i - 1] - adaption[i - 1] + stimulus[i]) / mem_tau) * step_size
        adaption[i] = adaption[i - 1] + (-adaption[i - 1] / adaption_tau) * step_size
        if v_1[i] > threshold:
            v_1[i] = v_base
            spikes.append(i * step_size)
            adaption[i] += adaption_step / adaption_tau

    return v_1, spikes


def lifac_ref_simulation(stimulus, step_size, *, v_0=0, v_base=0, threshold=1,
                         v_offset=1.3445, mem_tau=0.015, adaption_tau=0.1,
                         adaption_step=0.05, ref_time=0.005, adaption_0=0.5):
    """Simulate an adapting LIF neuron with an absolute refractory period.

    Dynamics match the adapting LIF model:

    * ``dv/dt = (v_offset - v - a + stimulus) / mem_tau``
    * ``da/dt = -a / adaption_tau``
    * on a spike, ``v`` is reset to ``v_base`` and ``a`` jumps by
      ``adaption_step / adaption_tau``

    In addition, for ``ref_time`` after each spike the membrane is clamped to
    ``v_base``; the adaptation variable keeps decaying during that time.

    Args:
        stimulus: 1-D sequence of input values, one per time step.
        step_size: Integration time step; spike times use the same unit.
        v_0: Initial membrane value.
        v_base: Reset/clamp value.
        threshold: Spike threshold (strict ``>`` comparison).
        v_offset: Constant drive.
        mem_tau: Membrane time constant.
        adaption_tau: Decay time constant of the adaptation variable.
        adaption_step: Per-spike adaptation increment (divided by
            ``adaption_tau`` before being added).
        ref_time: Refractory duration, in the same unit as ``step_size``.
        adaption_0: Initial adaptation value.

    Returns:
        ``(v_1, spikes)``: ``v_1`` is a float array with one membrane value per
        stimulus sample, and ``spikes`` is a list of spike times
        ``i * step_size``. An empty stimulus yields an empty array and no
        spikes.
    """
    n_steps = len(stimulus)
    adaption = np.zeros(n_steps)
    v_1 = np.zeros(n_steps)
    spikes = []
    if n_steps == 0:
        # Nothing to integrate; also avoids indexing element 0 of empty arrays.
        return v_1, spikes
    adaption[0] = adaption_0
    v_1[0] = v_0

    for i in range(1, n_steps):
        if len(spikes) > 0 and i * step_size < spikes[-1] + ref_time:
            # Still refractory: hold the membrane at the reset value.
            v_1[i] = v_base
        else:
            v_1[i] = v_1[i - 1] + ((v_offset - v_1[i - 1] - adaption[i - 1] + stimulus[i]) / mem_tau) * step_size

        adaption[i] = adaption[i - 1] + (-adaption[i - 1] / adaption_tau) * step_size
        if v_1[i] > threshold:
            v_1[i] = v_base
            spikes.append(i * step_size)
            adaption[i] += adaption_step / adaption_tau

    return v_1, spikes


def lifac_ref_noise_simulation(stimulus, step_size, *, v_0=0, v_base=0,
                               threshold=1, v_offset=1.32, mem_tau=0.015,
                               adaption_tau=0.1, adaption_step=0.05,
                               ref_time=0.005, noise_strength=0.05,
                               adaption_0=1, rng=None):
    """Simulate a noisy, adapting LIF neuron with an absolute refractory period.

    Dynamics, integrated with explicit Euler steps:

    * ``dv/dt = (v_offset - v - a + stimulus + noise) / mem_tau``
    * ``da/dt = -a / adaption_tau``
    * on a spike, ``v`` is reset to ``v_base`` and ``a`` jumps by
      ``adaption_step / adaption_tau``
    * for ``ref_time`` after each spike, ``v`` is clamped to ``v_base``

    Args:
        stimulus: 1-D sequence of input values, one per time step.
        step_size: Integration time step; spike times use the same unit.
        v_0: Initial membrane value.
        v_base: Reset/clamp value.
        threshold: Spike threshold (strict ``>`` comparison).
        v_offset: Constant drive.
        mem_tau: Membrane time constant.
        adaption_tau: Decay time constant of the adaptation variable.
        adaption_step: Per-spike adaptation increment (divided by
            ``adaption_tau`` before being added).
        ref_time: Refractory duration, in the same unit as ``step_size``.
        noise_strength: Scale of the Gaussian noise term.
        adaption_0: Initial adaptation value.
        rng: Optional object with a ``normal()`` method (e.g.
            ``np.random.default_rng(seed)``) for reproducible runs. If None,
            the global ``np.random.normal`` is used.

    Returns:
        ``(v_1, spikes)``: ``v_1`` is a float array with one membrane value per
        stimulus sample, and ``spikes`` is a list of spike times
        ``i * step_size``. An empty stimulus yields an empty array and no
        spikes.
    """
    # NOTE(review): adaption_0 defaults to 1 and v_offset to 1.32 here, which
    # differ from the noise-free variants; presumably re-tuned for the noise.
    # Confirm this is intended.
    draw_normal = np.random.normal if rng is None else rng.normal

    n_steps = len(stimulus)
    adaption = np.zeros(n_steps)
    v_1 = np.zeros(n_steps)
    spikes = []
    if n_steps == 0:
        # Nothing to integrate; also avoids indexing element 0 of empty arrays.
        return v_1, spikes
    adaption[0] = adaption_0
    v_1[0] = v_0

    for i in range(1, n_steps):
        # Draw every step, even while refractory, so the random stream
        # advances exactly once per step.
        noise_value = draw_normal()
        # After the later "* step_size", this term contributes
        # noise_strength * noise_value * sqrt(step_size) / mem_tau, so the
        # stochastic increment scales with sqrt(dt).
        noise = noise_strength * noise_value / np.sqrt(step_size)

        if len(spikes) > 0 and i * step_size < spikes[-1] + ref_time:
            # Still refractory: hold the membrane at the reset value.
            v_1[i] = v_base
        else:
            v_1[i] = v_1[i - 1] + ((v_offset - v_1[i - 1] - adaption[i - 1] + stimulus[i] + noise) / mem_tau) * step_size

        adaption[i] = adaption[i - 1] + (-adaption[i - 1] / adaption_tau) * step_size
        if v_1[i] > threshold:
            v_1[i] = v_base
            spikes.append(i * step_size)
            adaption[i] += adaption_step / adaption_tau

    return v_1, spikes


if __name__ == "__main__":
    # Produce the baseline-frequency report only when run as a script.
    main()