class LIFModel:
    """Leaky integrate-and-fire (LIF) neuron integrated with forward Euler.

    The membrane voltage obeys

        mem_tau * dv/dt = (v_base - v) + mem_res * I(t)

    and is advanced with a fixed step ``step_size``. Whenever an updated
    voltage rises strictly above ``threshold`` it is reset to ``v_base`` and
    a spike is recorded for that step.

    All times (``mem_tau``, ``step_size``) are in milliseconds.
    """

    def __init__(self, mem_res, mem_tau, v_base, v_zero, input_current,
                 threshold, input_offset=0, step_size=0.01):
        """Store the model parameters and the input current trace.

        Args:
            mem_res: Membrane resistance; ``mem_res * current`` is added to
                the voltage, so units must make that product a voltage.
            mem_tau: Membrane time constant in ms. Since tau = C * R, the
                capacitance is implied by ``mem_tau`` and ``mem_res`` and is
                not stored separately.
            v_base: Potential the voltage relaxes towards; also the value the
                voltage is reset to after a spike.
            v_zero: Voltage at the first sample of the simulation.
            input_current: Sized sequence of input current samples, one per
                time step.
            threshold: A spike fires when the voltage exceeds this value.
            input_offset: Stored as-is. NOTE(review): no method of this class
                reads it; confirm its intended meaning.
            step_size: Integration time step in ms.
        """
        self.mem_res = mem_res
        self.mem_tau = mem_tau  # membrane time constant tau = mem_cap * mem_res
        self.v_base = v_base
        self.v_zero = v_zero
        self.threshold = threshold

        self.step_size = step_size
        self.input_current = input_current
        self.input_offset = input_offset

    def calculate_response(self):
        """Simulate the neuron over ``self.input_current``.

        Returns:
            A tuple ``(output_voltage, spikes)``:

            * ``output_voltage`` has one entry per input sample and starts
              with ``v_zero``.
            * ``spikes`` has one entry fewer. ``spikes[k]`` is True when the
              step producing ``output_voltage[k + 1]`` crossed the threshold;
              that voltage is then the reset value ``v_base``.

            The last input sample is not consumed: it would only drive a
            step past the end of the simulated window. Empty or
            single-sample input yields ``([v_zero], [])``.
        """
        from itertools import islice

        output_voltage = [self.v_zero]
        spikes = []
        voltage = self.v_zero
        n_steps = max(len(self.input_current) - 1, 0)
        # Sample k drives the step from output_voltage[k] to output_voltage[k + 1].
        for current in islice(self.input_current, n_steps):
            voltage = self._calculate_next_step(voltage, current)
            # bool() keeps plain Python bools even for numpy scalar voltages.
            fired = bool(voltage > self.threshold)
            if fired:
                voltage = self.v_base
            spikes.append(fired)
            output_voltage.append(voltage)

        return output_voltage, spikes

    def set_input_current(self, input_current, offset=0):
        """Replace the input current trace and the stored ``input_offset``."""
        self.input_current = input_current
        self.input_offset = offset

    def _calculate_next_step(self, current_v, input_i):
        """Return the voltage one forward-Euler step after ``current_v``.

        Computes v + dt * (v_base - v + R * I) / tau with dt = ``step_size``,
        R = ``mem_res`` and tau = ``mem_tau``. No threshold check or reset
        happens here.
        """
        return current_v + (self.step_size * (self.v_base - current_v + self.mem_res * input_i)) / self.mem_tau

    # Backward-compatible alias for the original method name. Names of the
    # form ``__name__`` are reserved for Python's special methods, so the
    # implementation lives under the conventional private name above.
    __calculate_next_step__ = _calculate_next_step