from stimuli.AbstractStimulus import AbstractStimulus
from models.AbstractModel import AbstractModel
import numpy as np
from functions import line


class FirerateModel(AbstractModel):
    """Firing-rate model with a rate function ``line`` and first-order adaptation.

    At every integration step the output frequency is
    ``line(s - a, *function_params)``, where ``s`` is the stimulus value and
    ``a`` is the current adaptation. The adaptation is advanced with an
    explicit Euler step of ``da/dt = (-a + adaptation_factor * f) / a_tau``.

    Parameters (see ``DEFAULT_VALUES``):
        function_params: the two extra arguments passed to ``functions.line``
            (presumably slope and intercept -- confirm against ``line``).
        f_zero: initial frequency. The sentinel ``-np.inf`` means "derive the
            initial frequency from the stimulus value at time 0".
        adaptation_factor: coupling ``alpha`` of the frequency into the
            adaptation variable.
        a_zero: initial adaptation value.
        a_tau: adaptation time constant, in the same unit as ``step_size``.
        step_size: integration step; also the sampling interval of the traces.
    """

    DEFAULT_VALUES = {"function_params": [25, 1],
                      "f_zero": -np.inf,
                      "adaptation_factor": 0.05,
                      "a_zero": 0,
                      "a_tau": 0.02,
                      "step_size": 0.0005}

    def __init__(self, params: dict = None):
        super().__init__(params)

        # Results of the most recent simulate() call.
        self.frequency_trace = []
        self.adaption_trace = []
        self.stimulus = None

    def simulate(self, stimulus: AbstractStimulus, total_time_s):
        """Integrate the model on the grid ``np.arange(0, total_time_s, step_size)``.

        The state is recorded at each grid point before it is advanced by one
        step. Both traces are also stored on ``self.frequency_trace`` and
        ``self.adaption_trace``.

        Args:
            stimulus: object providing ``value_at_time_in_ms(t)``.
            total_time_s: exclusive upper bound of the time grid.

        Returns:
            list with one frequency value per grid point.
        """
        self.stimulus = stimulus

        f_zero = self.parameters["f_zero"]
        # ``-np.inf`` is the sentinel for "start from the stimulus-driven rate".
        # It must be compared by value: evaluating ``-np.inf`` creates a new
        # float object each time, so an identity test (``is``) never matches.
        if f_zero == -np.inf:
            # NOTE(review): unlike the loop below, this initial rate does not
            # subtract the initial adaptation ``a_zero`` -- confirm this is
            # intended when a_zero != 0.
            current_freq = self.frequency_step(stimulus.value_at_time_in_ms(0))
        else:
            current_freq = f_zero

        current_adaptation = self.parameters["a_zero"]

        freq_trace = []
        a_trace = []

        for time_point in np.arange(0, total_time_s, self.parameters["step_size"]):
            # Record the state at ``time_point`` before advancing one step.
            freq_trace.append(current_freq)
            a_trace.append(current_adaptation)

            # NOTE(review): ``time_point`` lives on the same grid as
            # ``total_time_s`` (seconds, by its name) but is passed to
            # ``value_at_time_in_ms`` -- confirm the unit the stimulus expects.
            current_stimulus = stimulus.value_at_time_in_ms(time_point) - current_adaptation
            current_freq = self.frequency_step(current_stimulus)
            current_adaptation = self.adaptation_step(current_adaptation, current_freq)

        self.frequency_trace = freq_trace
        self.adaption_trace = a_trace

        return freq_trace

    def adaptation_step(self, current_a, current_f):
        """Return the adaptation after one explicit Euler step of size ``step_size``.

        Integrates ``da/dt = (-a + adaptation_factor * f) / a_tau``, starting
        from ``current_a`` and driven by the frequency ``current_f``.
        """
        alpha = self.parameters["adaptation_factor"]
        tau = self.parameters["a_tau"]

        derivative = (-current_a + alpha * current_f) / tau
        return current_a + derivative * self.parameters["step_size"]

    def frequency_step(self, current_stimulus):
        """Map an (adaptation-corrected) stimulus value to a frequency via ``line``."""
        params = self.parameters["function_params"]
        return line(current_stimulus, params[0], params[1])

    def simulates_voltage_trace(self) -> bool:
        """This model does not produce a voltage trace."""
        return False

    def simulates_spiketimes(self) -> bool:
        """This model does not produce spike times."""
        return False

    def simulates_frequency(self) -> bool:
        """This model produces a frequency trace."""
        return True

    def get_voltage_trace(self):
        """Unsupported for this model.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError("This model type doesn't simulate a voltage trace!")

    def get_spiketimes(self):
        """Unsupported for this model.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError("This model type doesn't simulate spiketimes!")

    def get_frequency(self):
        """Return the frequency trace from the last simulate() call (empty before)."""
        return self.frequency_trace

    def min_stimulus_strength_to_spike(self):
        """Not implemented for this model; returns ``None``."""
        return None

    def get_sampling_interval(self):
        """Return the integration step, which is also the trace sampling interval."""
        return self.parameters["step_size"]