from CellData import CellData
from FiCurve import FICurve
from AdaptionCurrent import Adaption
import numpy as np
import matplotlib.pyplot as plt


class NeuronModel:
    """Threshold-and-reset neuron model with an adaption variable.

    The model keeps a dictionary of named parameters (``self.variables``)
    whose permitted names are listed in ``KEYS`` and whose defaults are the
    entries of ``VALUES`` at the same position. The membrane value and the
    adaption value are both advanced with explicit (forward) Euler steps of
    size ``step_size``.

    NOTE(review): ``mem_res`` and ``v_zero`` are accepted as parameters but
    are not read by any method visible in this class.
    """

    # Allowed parameter names and, position by position, their defaults.
    KEYS = ["mem_res", "mem_tau", "v_base", "v_zero", "threshold", "step_size"]
    VALUES = [100 * 1000000, 0.1 * 200, 0, 0, 10, 0.01]

    def __init__(self, cell_data: "CellData", variables: dict = None):
        """Build the model for one cell.

        Args:
            cell_data: Data object of the cell; it is passed on to ``FICurve``
                and ``Adaption`` and queried for the base frequency.
            variables: Optional parameter overrides. Every key must be in
                ``KEYS``. The dictionary is copied, so the caller's object is
                neither modified nor shared with the model.

        Raises:
            ValueError: If ``variables`` contains a key that is not in ``KEYS``.
        """
        self.cell_data = cell_data
        self.fi_curve = FICurve(cell_data)
        self.adaption = Adaption(cell_data, self.fi_curve)

        # Copy the overrides into a fresh dict instead of adopting the
        # caller's dict: filling in the defaults below would otherwise
        # silently add keys to the caller's object.
        self.variables = {}
        if variables is not None:
            self.set_variables(variables)
        self._add_standard_variables()

    def __call__(self, stimulus):
        """Not implemented yet; always raises ``NotImplementedError``."""
        raise NotImplementedError("Soon. sorry!")

    def _approximate_variables_from_data(self):
        """Return the constant input derived from the cell's base frequency."""
        # TODO don't return but save in class in some form! approximate/calculate other variables?
        return self._calculate_input_fro_base_frequency()

    def simulate(self, start_v, time_in_ms, stimulus, *, plot=True):
        """Run the model over ``stimulus`` and collect voltage trace and spikes.

        One step is taken for every point of
        ``np.arange(0, time_in_ms, step_size)``; ``stimulus`` is read at the
        position of that point, so it needs at least as many samples as there
        are time points.

        Args:
            start_v: Membrane value before the first step.
            time_in_ms: End of the simulated time axis (exclusive).
            stimulus: Indexable sequence with one value per time point.
            plot: When true (the default, matching the previous behavior) the
                adaption variable, its target value and the stimulus are shown
                in a matplotlib window after the run.

        Returns:
            Tuple ``(response, spikes)``: the membrane value after every step
            and the time points at which the threshold was exceeded.
        """
        step_size = self.variables["step_size"]
        threshold = self.variables["threshold"]
        v_base = self.variables["v_base"]
        time_points = np.arange(0, time_in_ms, step_size)

        response = []
        spikes = []
        adaption_values = []
        a_infties = []

        current_v = start_v
        current_a = 0
        base_input = self._calculate_input_fro_base_frequency()
        print("base input:", base_input)

        # Index the stimulus by position. Recomputing the index as
        # int(time_step / step_size) truncates a float quotient that can fall
        # just below the intended integer, which would reuse the previous
        # stimulus sample.
        for index, time_step in enumerate(time_points):
            stimulus_input = stimulus[index] - current_a

            # NOTE(review): the adaption value enters twice here - once via
            # stimulus_input and once as current_a * base_input. Kept as in
            # the original; confirm this is intended.
            new_v = self._calculate_next_step(
                current_v,
                current_a * base_input,
                base_input + base_input * stimulus_input,
            )
            new_a, a_infty = self._calculate_adaption_step(current_a, stimulus_input)

            if new_v > threshold:
                # Threshold exceeded: reset the membrane value, note the spike.
                new_v = v_base
                spikes.append(time_step)
            response.append(new_v)

            adaption_values.append(current_a)
            a_infties.append(a_infty)
            current_v, current_a = new_v, new_a

        if plot:
            self._plot_adaption(time_points, adaption_values, a_infties, stimulus)

        return response, spikes

    @staticmethod
    def _plot_adaption(time_points, adaption_values, a_infties, stimulus):
        """Show adaption value, its target value and the stimulus over time."""
        plt.title("Adaption variable")
        plt.plot(time_points, np.array(adaption_values), label="adaption")
        plt.plot(time_points, np.array(a_infties), label="a_inf")
        plt.plot(time_points, stimulus, label="stimulus")
        plt.legend()
        plt.xlabel("time in ms")
        plt.ylabel("value as contrast?")
        plt.show()
        plt.close()

    def _calculate_next_step(self, current_v, current_a, input_v):
        """Return the membrane value one Euler step after ``current_v``.

        The derivative used is
        ``(-current_v + v_base + input_v - current_a) / mem_tau``.
        """
        step_size = self.variables["step_size"]
        v_base = self.variables["v_base"]
        mem_tau = self.variables["mem_tau"]

        return current_v + (step_size * (- current_v + v_base + input_v - current_a)) / mem_tau

    def _calculate_adaption_step(self, current_a, stimulus_input):
        """Advance the adaption value by one Euler step.

        The target value ``a_infinity`` is ``stimulus_input`` minus the value
        the FI curve's f-zero inverse returns for the f-infinity frequency at
        ``stimulus_input``; the time constant is ``self.adaption.tau_real``.

        Returns:
            Tuple ``(new_a, a_infinity)``.
        """
        step_size = self.variables["step_size"]
        tau_a = self.adaption.tau_real
        f_infty_freq = self.fi_curve.get_f_infinity_frequency_at_stimulus_value(stimulus_input)
        a_infinity = stimulus_input - self.fi_curve.get_f_zero_inverse_at_frequency(f_infty_freq)
        return current_a + (step_size * (- current_a + a_infinity)) / tau_a, a_infinity

    def set_variable(self, key, value):
        """Set one model parameter.

        Raises:
            ValueError: If ``key`` is not in ``KEYS``.
        """
        if key not in self.KEYS:
            raise ValueError("Given key is unknown!\n"
                             "Please check spelling and refer to list NeuronModel.KEYS.")
        self.variables[key] = value

    def set_variables(self, variables: dict):
        """Set several model parameters at once.

        All keys are validated before anything is stored, so an invalid
        dictionary leaves the current parameters untouched.

        Raises:
            ValueError: If any key is not in ``KEYS``.
        """
        self._test_given_variables(variables)
        self.variables.update(variables)

    def _calculate_input_fro_base_frequency(self):
        """Return the constant input computed from the cell's base frequency.

        Evaluates ``-threshold / (exp(-1 / (f / 1000 * mem_tau)) - 1)`` with
        ``f = cell_data.get_base_frequency()``.

        NOTE(review): the division by 1000 presumably converts Hz to 1/ms -
        confirm against ``CellData``. A base frequency of zero is not handled
        here and ends in a division by zero.
        """
        return - self.variables["threshold"] / (
                    np.e ** (-1 / (self.cell_data.get_base_frequency()/1000 * self.variables["mem_tau"])) - 1)

    def _test_given_variables(self, variables: dict):
        """Raise ``ValueError`` if ``variables`` has a key that is not in ``KEYS``."""
        for name in variables:
            if name not in self.KEYS:
                raise ValueError("Unknown key in given model variables. \n"
                                 "Please check spelling and refer to list NeuronModel.KEYS.")

    def _add_standard_variables(self):
        """Fill in the default from ``VALUES`` for every parameter not yet set."""
        for name, default in zip(self.KEYS, self.VALUES):
            self.variables.setdefault(name, default)