from stimuli.AbstractStimulus import AbstractStimulus
from models.AbstractModel import AbstractModel
import numpy as np
import helperFunctions as hf


class LifacNoiseModel(AbstractModel):
    """Leaky integrate-and-fire neuron with an adaptation current and noise.

    The membrane voltage and the adaptation variable are integrated with a
    forward-Euler scheme using the fixed step ``step_size``. Whenever the
    voltage exceeds ``threshold`` it is reset to ``v_base``, the spike time is
    recorded and the adaptation variable is incremented.

    Units: parameters that are times (``mem_tau``, ``tau_a``, ``step_size``)
    are in milliseconds. ``simulate`` takes its duration in seconds and
    reports spike times in seconds.
    """
    # all times in milliseconds
    # possible mem_res: 100 * 1000000
    DEFAULT_VALUES = {"mem_tau": 20,
                      "v_base": 0,
                      "v_zero": -1,
                      "threshold": 1,
                      "step_size": 0.01,
                      "delta_a": 0.4,
                      "tau_a": 40,
                      "a_zero": 30,
                      "v_offset": 50,
                      "input_scaling": 1,
                      "noise_strength": 3}

    # membrane time constant tau = mem_cap*mem_res
    def __init__(self, params: dict = None):
        """Create the model; ``params`` is forwarded to ``AbstractModel.__init__``."""
        super().__init__(params)

        # Results of the most recent call to simulate().
        self.voltage_trace = []
        self.adaption_trace = []
        self.spiketimes = []
        # Stimulus used in the most recent simulation (None before the first run).
        self.stimulus = None

    def simulate(self, stimulus: AbstractStimulus, total_time_s):
        """Simulate the model driven by ``stimulus`` for ``total_time_s`` seconds.

        :param stimulus: queried via ``value_at_time_in_ms`` at every time step;
            its value is rectified (``hf.rectify``) before driving the membrane.
        :param total_time_s: simulated duration in seconds.
        :return: ``(voltage_trace, spiketimes)`` - the voltage after every
            integration step and the spike times in seconds. Both, plus the
            adaptation trace, are also stored on the instance.
        """
        self.stimulus = stimulus
        output_voltage = []
        adaption = []
        spiketimes = []
        current_v = self.parameters["v_zero"]
        current_a = self.parameters["a_zero"]

        # Parameters that do not change during the simulation.
        step_size = self.parameters["step_size"]
        threshold = self.parameters["threshold"]
        v_base = self.parameters["v_base"]
        # NOTE(review): delta_a is divided by tau_a converted to seconds while the
        # rest of the model works in milliseconds - confirm the intended units.
        adaption_increment = self.parameters["delta_a"] / (self.parameters["tau_a"] / 1000)

        for time_point in np.arange(0, total_time_s * 1000, step_size):
            # rectified input:
            stimulus_strength = hf.rectify(stimulus.value_at_time_in_ms(time_point))

            # The adaptation variable is subtracted from the input drive.
            v_next = self._calculate_voltage_step(current_v, stimulus_strength - current_a)
            a_next = self._calculate_adaption_step(current_a)

            if v_next > threshold:
                # Spike: reset the voltage, record the time (ms -> s), strengthen adaptation.
                v_next = v_base
                spiketimes.append(time_point / 1000)
                a_next += adaption_increment

            output_voltage.append(v_next)
            adaption.append(a_next)

            current_v = v_next
            current_a = a_next

        self.voltage_trace = output_voltage
        self.adaption_trace = adaption
        self.spiketimes = spiketimes

        return output_voltage, spiketimes

    def _calculate_voltage_step(self, current_v, input_v):
        """Return the voltage after one Euler step.

        Integrates ``mem_tau * dV/dt = v_base - V + v_offset + input_v + noise``.
        The Gaussian noise sample is scaled by ``1/sqrt(step_size)``, so after
        multiplication by ``step_size`` its contribution is proportional to
        ``noise_strength * sqrt(step_size)``.
        """
        v_base = self.parameters["v_base"]
        step_size = self.parameters["step_size"]
        v_offset = self.parameters["v_offset"]
        mem_tau = self.parameters["mem_tau"]

        noise_strength = self.parameters["noise_strength"]
        noise_value = np.random.normal()
        noise = noise_strength * noise_value / np.sqrt(step_size)

        return current_v + step_size * ((v_base - current_v + v_offset + input_v + noise) / mem_tau)

    def _calculate_adaption_step(self, current_a):
        """Return the adaptation value after one Euler step of exponential decay with time constant ``tau_a``."""
        step_size = self.parameters["step_size"]
        return current_a + (step_size * (-current_a)) / self.parameters["tau_a"]

    def min_stimulus_strength_to_spike(self):
        """Return ``threshold - v_base``.

        NOTE(review): this ignores ``v_offset`` and the adaptation variable,
        both of which enter the voltage equation - confirm this is intended.
        """
        return self.parameters["threshold"] - self.parameters["v_base"]

    def get_sampling_interval(self):
        """Return the integration step size (milliseconds)."""
        return self.parameters["step_size"]

    def get_frequency(self):
        """Not implemented for this model.

        :raises NotImplementedError: always.
        """
        # TODO also change simulates_frequency() if any calculation is added!
        raise NotImplementedError("No calculation implemented yet for the frequency.")

    def get_spiketimes(self):
        """Return the spike times (seconds) from the last simulation."""
        return self.spiketimes

    def get_voltage_trace(self):
        """Return the voltage after every step of the last simulation."""
        return self.voltage_trace

    def get_adaption_trace(self):
        """Return the adaptation value after every step of the last simulation."""
        return self.adaption_trace

    def simulates_frequency(self) -> bool:
        return False

    def simulates_spiketimes(self) -> bool:
        return True

    def simulates_voltage_trace(self) -> bool:
        return True

    def get_recording_times(self):
        """Return ``[delay, stimulus_start, stimulus_duration, total_time]``.

        Start and duration come from the stimulus's ``get_stimulus_start_s`` /
        ``get_stimulus_duration_s`` getters. ``total_time`` is the length of the
        last simulation in seconds. Both the stimulus and the voltage trace come
        from the most recent ``simulate`` call.

        :raises RuntimeError: if ``simulate`` has not been called yet.
        """
        if self.stimulus is None:
            raise RuntimeError("get_recording_times() requires simulate() to be called first.")
        delay = 0
        start = self.stimulus.get_stimulus_start_s()
        duration = self.stimulus.get_stimulus_duration_s()
        # number of integration steps * step size (ms) -> seconds
        total_time = len(self.voltage_trace) * self.parameters["step_size"] / 1000

        return [delay, start, duration, total_time]