from models.FirerateModel import FirerateModel
from models.LIFACnoise import LifacNoiseModel
from stimuli.StepStimulus import StepStimulus
import helperFunctions as hf
import numpy as np
import matplotlib.pyplot as plt
from scipy.optimize import curve_fit
import functions as fu


def main():
    """Print the LIFAC -> fire-rate adaption relation for each sweep value.

    For every entry of the (currently single-element) sweep, a fresh
    ``LifacNoiseModel`` with ``delta_a = 0`` is built. A line is fitted to its
    stimulus/frequency response with ``find_fitting_line``. ``find_relation``
    then maps LIFAC adaption strengths to fire-rate adaption factors. The swept
    parameter is not applied to the model at the moment (see the disabled
    ``set_variable`` call).
    """
    parameter = "currently not active"
    sweep_values = [1]  # np.arange(5, 40, 1)
    for current_value in sweep_values:
        model = LifacNoiseModel({"delta_a": 0})
        # model.set_variable(parameter, current_value)
        strengths = np.arange(50, 60, 1)

        fitted_line = find_fitting_line(model, strengths)
        relation = find_relation(model, fitted_line, strengths, confirm=True)

        print(parameter, current_value)
        print(relation)


def find_fitting_line(lifac_model, stimulus_strengths):
    """Fit ``fu.line`` to the final frequency of a LIFAC model per stimulus strength.

    The model should have adaption disabled (``delta_a = 0``), so that only the
    base response is fitted. Each strength is applied as
    ``StepStimulus(0, 0.2, strength)`` and simulated for 0.2. The last entry of
    the frequency trace is taken as the response; a run without spikes, or with
    an empty frequency trace, counts as 0.

    Returns:
        The ``curve_fit`` parameter vector for ``fu.line``.
    """
    duration = 0.2
    frequencies = []
    for strength in stimulus_strengths:
        lifac_model.simulate_slow(StepStimulus(0, duration, strength), duration)
        spikes = lifac_model.get_spiketimes()

        final_freq = 0
        if len(spikes) != 0:
            _, freq_trace = hf.calculate_time_and_frequency_trace(spikes, lifac_model.get_sampling_interval())
            if len(freq_trace) != 0:
                final_freq = freq_trace[-1]
        frequencies.append(final_freq)

    line_params, _ = curve_fit(fu.line, stimulus_strengths, frequencies)
    print("line:", line_params)
    return line_params


def find_relation(lifac, line_vars, stimulus_strengths, parameter="", value=0, confirm=False):
    """Relate the LIFAC adaption strength ``delta_a`` to the fire-rate adaption factor.

    For each ``delta_a`` in ``np.arange(0, 0.31, 0.05)`` (with ``tau_a`` set to
    40), the LIFAC model is driven with ``StepStimulus(0, 0.4, stim)`` for every
    ``stim`` in ``stimulus_strengths``. The last value of the frequency trace is
    the adapted frequency ``f`` (0 if the trace is empty).

    Treating the fitted line ``line_vars`` as the base firing rate, the
    stimulus reduction needed to reach ``f`` is
    ``stim - fu.inverse_line(f, *line_vars)``. Dividing it by ``f`` gives the
    fire-rate model's adaption factor. A silent response (``f == 0``) yields a
    non-finite factor (inf/nan).

    The per-stimulus factors are plotted. Their median per ``delta_a`` is
    fitted with ``fu.line`` (non-finite medians are excluded from the fit). The
    figure is saved to ``figures/adaption_relation[_<parameter>_<value>].png``.

    Args:
        lifac: LIFAC model. Its ``delta_a`` and ``tau_a`` are overwritten and
            stay at the last sweep values after the call.
        line_vars: Parameters of the base line, as returned by ``find_fitting_line``.
        stimulus_strengths: Sequence of step stimulus strengths.
        parameter, value: Only used to build the figure file name.
        confirm: If True, check every finite factor in a FirerateModel.

    Returns:
        ``curve_fit`` parameters of ``fu.line`` for median factor vs. ``delta_a``.
    """
    # example values for base lifac (15.1.20) and stimulus 20-32:
    # boltzmann_vars = [2.00728705e+02, 1.09905953e-12, 1.03639686e-01, 2.55002788e+01]
    # line_vars = [5.10369405, -29.79774806]

    duration = 0.4

    lifac_adaption_strength_range = np.arange(0, 0.31, 0.05)
    firerate_adaption_variables = []
    for lifac_adaption_strength in lifac_adaption_strength_range:
        print(lifac_adaption_strength)
        lifac.set_variable("delta_a", lifac_adaption_strength)
        lifac.set_variable("tau_a", 40)

        firerate_adaption_strengths = []
        for stim in stimulus_strengths:
            stimulus = StepStimulus(0, duration, stim)
            lifac.simulate_slow(stimulus, duration)
            spiketimes = lifac.get_spiketimes()
            _, freq = hf.calculate_time_and_frequency_trace(spiketimes, lifac.get_sampling_interval())
            # last value of the frequency trace is the adapted frequency (0 if none)
            goal_adapted_freq = freq[-1] if len(freq) > 0 else 0

            # assume fitted linear firing rate as basis of the fire-rate model:
            stimulus_strength_after_adaption = fu.inverse_line(goal_adapted_freq, line_vars[0], line_vars[1])
            # needed adaption strength
            adaption_strength = stim - stimulus_strength_after_adaption
            # adaption variable in model. Divide in float64 so a silent response
            # gives inf/nan for every input type, instead of a ZeroDivisionError
            # when plain Python numbers are passed in.
            with np.errstate(divide="ignore", invalid="ignore"):
                firerate_adaption = np.float64(adaption_strength) / goal_adapted_freq

            # a non-finite factor cannot be simulated, so only confirm finite ones
            if confirm and np.isfinite(firerate_adaption):
                test_adaption_strength_in_firerate_model(line_vars, firerate_adaption, stimulus, goal_adapted_freq)

            firerate_adaption_strengths.append(firerate_adaption)

        firerate_adaption_variables.append(firerate_adaption_strengths)

    # per-stimulus values, shifted slightly along x so they stay distinguishable
    for strength, factors in zip(lifac_adaption_strength_range, firerate_adaption_variables):
        plt.plot([strength + p * 0.001 for p in range(len(factors))], factors)

    median_firerate_adaption = np.array([np.median(factors) for factors in firerate_adaption_variables])

    # fit once; curve_fit rejects inf/nan, so leave non-finite medians out
    finite = np.isfinite(median_firerate_adaption)
    popt, _ = curve_fit(fu.line, lifac_adaption_strength_range[finite], median_firerate_adaption[finite])

    plt.plot(lifac_adaption_strength_range, median_firerate_adaption, label="slope:" + str(round(popt[0], 5)))
    plt.title("Relation of adaption strength variables:\n Colored points values for different stimulus strengths")
    plt.xlabel("lifac adaption strength: delta_a")
    plt.ylabel("firerate adaption strength: alpha")
    plt.legend()
    if parameter != "":
        plt.savefig("figures/adaption_relation_" + parameter + "_" + str(value) + ".png")
    else:
        plt.savefig("figures/adaption_relation.png")
    plt.close()

    return popt


def test_adaption_strength_in_firerate_model(line_vars, adaption_strength, stimulus, expected_freq, rel_tolerance=0.00001):
    """Check that a FirerateModel with the given adaption factor reaches ``expected_freq``.

    Builds a FirerateModel with ``function_params=line_vars``,
    ``adaptation_factor=adaption_strength`` and ``a_tau=10``, then simulates
    ``stimulus`` for 0.2. The last frequency value is compared with
    ``expected_freq``. A message is printed when the deviation exceeds
    ``rel_tolerance * |expected_freq|`` in either direction.

    Returns:
        True if the final frequency is within tolerance, otherwise False.
    """
    params = {"function_params": line_vars, "adaptation_factor": adaption_strength, "a_tau": 10}
    model = FirerateModel(params)

    model.simulate(stimulus, 0.2)
    freq = model.get_frequency()[-1]
    diff = expected_freq - freq
    # compare magnitudes: a one-sided check would miss models that overshoot
    within_tolerance = abs(diff) <= rel_tolerance * abs(expected_freq)
    if not within_tolerance:
        print("expected freq:", expected_freq, "=?=", str(freq), ":", str(within_tolerance))
    return bool(within_tolerance)


def test_firerate_model(boltzmann_vars):
    """Plot the final frequency of a non-adapting FirerateModel against stimulus strength.

    ``boltzmann_vars`` are used as the model's ``function_params`` with
    ``adaptation_factor = 0``. Stimulus strengths 0, 0.5, ..., 49.5 are applied
    as ``StepStimulus(0, 0.5, strength)`` and simulated for 0.5, and the last
    frequency value of each run is recorded. For comparison,
    ``fu.line(x, 5.10369, -29.7977481)`` is overlaid as points for x = 20..31.
    """
    model = FirerateModel(params={"function_params": boltzmann_vars, "adaptation_factor": 0})

    duration = 0.5
    strengths = np.arange(0, 50, 0.5)
    final_freqs = []
    for strength in strengths:
        model.simulate(StepStimulus(0, duration, strength), duration)
        final_freqs.append(model.get_frequency()[-1])

    plt.plot(strengths, final_freqs)
    reference_x = np.arange(20, 32, 1)
    plt.plot(reference_x, [fu.line(x, 5.10369, -29.7977481) for x in reference_x], 'o')
    plt.show()


if __name__ == '__main__':
    # Run main() only when executed as a script, not when imported.
    main()