from stimuli.AbstractStimulus import AbstractStimulus
from models.AbstractModel import AbstractModel
import numpy as np
from my_util import functions as fu
from numba import jit
from warnings import warn
from collections import OrderedDict


class LifacNoiseModel(AbstractModel):
    """Integrate-and-fire neuron model with an adaption variable and noise.

    Per time step the model low-pass filters the rectified stimulus
    (``dend_tau``), integrates a membrane voltage towards
    ``v_base + v_offset + input - adaption + noise`` (``mem_tau``) and lets
    the adaption variable decay (``tau_a``).  When the voltage exceeds
    ``threshold`` a spike is recorded, the voltage is reset to ``v_base``
    and the adaption is increased by ``delta_a / tau_a``.

    The results of the most recent simulation are kept on the instance
    (``voltage_trace``, ``input_voltage``, ``adaption_trace``,
    ``spiketimes``, ``stimulus``).

    ``self.parameters`` is read throughout but never assigned in this
    class; presumably ``AbstractModel.__init__`` builds it from
    ``DEFAULT_VALUES`` and ``params`` - verify against the base class.
    """

    # Original note: "all times in milliseconds".
    # NOTE(review): step_size is used directly as the np.arange step against
    # total_time_s, which suggests the time parameters are in seconds - confirm.
    # possible mem_res: 100 * 1000000 exact value unknown in p-units
    DEFAULT_VALUES = OrderedDict([("mem_tau", 0.015),
                                  ("v_base", 0),
                                  ("v_zero", 0),
                                  ("threshold", 1),
                                  ("v_offset", -10),
                                  ("input_scaling", 60),
                                  ("delta_a", 0.08),
                                  ("tau_a", 0.1),
                                  ("a_zero", 2),
                                  ("noise_strength", 0.05),
                                  ("step_size", 0.00005),
                                  ("dend_tau", 0.001),
                                  ("refractory_period", 0.001)])

    def __init__(self, params: dict = None):
        """Create the model and initialise empty result containers.

        :param params: parameter values handed to ``AbstractModel.__init__``.
        """
        super().__init__(params)

        if self.parameters["step_size"] > 0.0001:
            warn("LifacNoiseModel: The step size is quite big simulation could fail.")
        self.voltage_trace = []
        self.input_voltage = []
        self.adaption_trace = []
        self.spiketimes = []
        self.stimulus = None

    def simulate_slow(self, stimulus: AbstractStimulus, total_time_s):
        """Simulate the model step by step in pure Python.

        The stimulus is queried once per time point via
        ``stimulus.value_at_time_in_s`` and rectified with ``fu.rectify``.
        Unlike ``simulate``, the input scaling is applied inside the
        dendritic filter step, so ``self.input_voltage`` holds the scaled
        filter output here (except for its first, unscaled element).

        :param stimulus: stimulus object providing ``value_at_time_in_s``.
        :param total_time_s: end of the simulated time axis (start is 0).
        :return: tuple ``(voltage_trace, spiketimes)``.
        """
        self.stimulus = stimulus
        time = np.arange(0, total_time_s, self.parameters["step_size"])
        output_voltage = np.zeros(len(time), dtype='float64')
        adaption = np.zeros(len(time), dtype='float64')
        input_voltage = np.zeros(len(time), dtype='float64')
        spiketimes = []

        current_v = self.parameters["v_zero"]
        current_a = self.parameters["a_zero"]
        input_voltage[0] = fu.rectify(stimulus.value_at_time_in_s(time[0]))
        output_voltage[0] = current_v
        adaption[0] = current_a

        for i in range(1, len(time), 1):
            time_point = time[i]
            # rectified input:
            stimulus_strength = self._calculate_input_voltage_step(input_voltage[i - 1],
                                                                   fu.rectify(stimulus.value_at_time_in_s(time_point)))

            v_next = self._calculate_voltage_step(current_v, stimulus_strength - current_a)
            a_next = self._calculate_adaption_step(current_a)

            # Clamp the voltage while the last spike is closer than the
            # refractory period (plus half a step of tolerance).
            refractory_window = self.parameters["refractory_period"] + self.parameters["step_size"]/2
            if len(spiketimes) > 0 and time[i] - spiketimes[-1] < refractory_window:
                v_next = self.parameters["v_base"]

            if v_next > self.parameters["threshold"]:
                v_next = self.parameters["v_base"]
                spiketimes.append(time_point)
                a_next += self.parameters["delta_a"] / self.parameters["tau_a"]

            output_voltage[i] = v_next
            adaption[i] = a_next
            input_voltage[i] = stimulus_strength

            current_v = v_next
            current_a = a_next

        self.voltage_trace = output_voltage
        self.adaption_trace = adaption
        self.spiketimes = spiketimes
        self.input_voltage = input_voltage

        return output_voltage, spiketimes

    def _calculate_voltage_step(self, current_v, input_v):
        """Return the membrane voltage after one Euler step.

        A fresh standard-normal sample, scaled by
        ``noise_strength / sqrt(step_size)``, is added to the drive.

        :param current_v: voltage at the previous time step.
        :param input_v: input term added to the drive (the caller passes
            filtered input minus adaption).
        """
        v_base = self.parameters["v_base"]
        step_size = self.parameters["step_size"]
        v_offset = self.parameters["v_offset"]
        mem_tau = self.parameters["mem_tau"]

        noise_strength = self.parameters["noise_strength"]
        noise_value = np.random.normal()
        noise = noise_strength * noise_value / np.sqrt(step_size)

        return current_v + step_size * ((v_base - current_v + v_offset + input_v + noise) / mem_tau)

    def _calculate_adaption_step(self, current_a):
        """Return the adaption value after one step of decay with ``tau_a``."""
        step_size = self.parameters["step_size"]
        return current_a + (step_size * (-current_a)) / self.parameters["tau_a"]

    def _calculate_input_voltage_step(self, current_i, rectified_input):
        """Return the filtered input after one Euler step.

        The filter relaxes towards ``rectified_input * input_scaling`` with
        time constant ``dend_tau``.
        """
        return current_i + (
                    (-current_i + rectified_input * self.parameters["input_scaling"]) / self.parameters["dend_tau"]) * \
               self.parameters["step_size"]

    def simulate(self, stimulus: AbstractStimulus, total_time_s, time_start=0):
        """Simulate the model with the numba-compiled kernels.

        The stimulus is sampled with
        ``stimulus.as_array(time_start, total_time_s, step_size)`` and
        rectified.  If ``dend_tau >= step_size`` the kernel with the
        dendritic filter (``simulate_fast``) is used, otherwise the
        rectified stimulus is used unfiltered
        (``simulate_fast_no_dend_tau``).  The results are stored on the
        instance.

        :param stimulus: stimulus object providing ``as_array``.
        :param total_time_s: end of the simulated time axis.
        :param time_start: start of the simulated time axis.
        :return: tuple ``(voltage_trace, spiketimes)``.
        """
        v_zero = self.parameters["v_zero"]
        a_zero = self.parameters["a_zero"]
        step_size = self.parameters["step_size"]
        threshold = self.parameters["threshold"]
        v_base = self.parameters["v_base"]
        delta_a = self.parameters["delta_a"]
        tau_a = self.parameters["tau_a"]
        v_offset = self.parameters["v_offset"]
        mem_tau = self.parameters["mem_tau"]
        noise_strength = self.parameters["noise_strength"]
        input_scaling = self.parameters["input_scaling"]
        dend_tau = self.parameters["dend_tau"]
        ref_period = self.parameters["refractory_period"]

        rectified_stimulus = rectify_stimulus_array(stimulus.as_array(time_start, total_time_s, step_size))
        # The kernels read these values by position; keep the order in sync
        # with the index-based unpacking in simulate_fast*.
        parameters = np.array(
            [v_zero, a_zero, step_size, threshold, v_base, delta_a, tau_a, v_offset, mem_tau, noise_strength,
             time_start, input_scaling, dend_tau, ref_period])
        if dend_tau >= step_size:
            voltage_trace, adaption, spiketimes, input_voltage = simulate_fast(rectified_stimulus, total_time_s, parameters)
        else:
            voltage_trace, adaption, spiketimes, input_voltage = simulate_fast_no_dend_tau(rectified_stimulus, total_time_s, parameters)

        self.stimulus = stimulus
        self.input_voltage = input_voltage
        self.voltage_trace = voltage_trace
        self.adaption_trace = adaption
        self.spiketimes = spiketimes

        return voltage_trace, spiketimes

    def min_stimulus_strength_to_spike(self):
        """Return ``threshold - v_base``."""
        return self.parameters["threshold"] - self.parameters["v_base"]

    def get_sampling_interval(self):
        """Return the simulation step size."""
        return self.parameters["step_size"]

    def get_frequency(self):
        """Not implemented for this model; always raises NotImplementedError."""
        # TODO also change simulates_frequency() if any calculation is added!
        raise NotImplementedError("No calculation implemented yet for the frequency.")

    def get_spiketimes(self):
        """Return the spike times of the most recent simulation."""
        return self.spiketimes

    def get_voltage_trace(self):
        """Return the voltage trace of the most recent simulation."""
        return self.voltage_trace

    def get_adaption_trace(self):
        """Return the adaption trace of the most recent simulation."""
        return self.adaption_trace

    def simulates_frequency(self) -> bool:
        """This model does not provide a frequency trace."""
        return False

    def simulates_spiketimes(self) -> bool:
        """This model provides spike times."""
        return True

    def simulates_voltage_trace(self) -> bool:
        """This model provides a voltage trace."""
        return True

    def get_recording_times(self):
        """Return ``[delay, stimulus_start, stimulus_duration, total_time]``.

        Start and duration are read from the stimulus of the most recent
        simulation.  The stored stimulus is no longer replaced here.  If no
        simulation has been run yet, a bare ``AbstractStimulus()`` is queried
        instead, which is what the previous implementation always did.
        ``total_time`` is the number of samples in the voltage trace times
        the step size.
        """
        stimulus = self.stimulus if self.stimulus is not None else AbstractStimulus()
        delay = 0
        start = stimulus.get_stimulus_start_s()
        duration = stimulus.get_stimulus_duration_s()
        # samples * seconds-per-sample gives the simulated duration
        total_time = len(self.voltage_trace) * self.parameters["step_size"]

        return [delay, start, duration, total_time]

    def get_model_copy(self):
        """Return a new LifacNoiseModel built from a copy of the parameters."""
        # Hand over a copy so the new model can never share its parameter
        # container with this one, however AbstractModel stores it.
        return LifacNoiseModel(self.parameters.copy())

    def get_eodf_scaled_parameters(self, factor):
        """Return a copy of the parameters with selected entries divided by ``factor``.

        The entries ``refractory_period``, ``tau_a``, ``mem_tau``,
        ``dend_tau`` and ``delta_a`` are divided; all others are copied
        unchanged.  The model's own parameters are not modified.
        """
        scaled_parameters = self.parameters.copy()
        time_param_keys = ["refractory_period", "tau_a", "mem_tau", "dend_tau", "delta_a"]

        for key in time_param_keys:
            scaled_parameters[key] = self.parameters[key] / factor

        return scaled_parameters

    def find_v_offset(self, goal_baseline_frequency, base_stimulus, threshold=2, border=50000):
        """Search for a ``v_offset`` that yields the requested baseline frequency.

        Works on a copy of the model.  Starting at -400 the offset is
        raised in steps of 100 until the measured frequency reaches the
        goal; the last step interval is then refined with
        ``binary_search_base_freq``.

        :param goal_baseline_frequency: target frequency.
        :param base_stimulus: stimulus used for every test simulation.
        :param threshold: accepted absolute deviation from the goal.
        :param border: returned as-is when the coarse search reaches it
            without achieving the goal frequency.
        :return: the found ``v_offset`` (or ``border``).
        """
        test_model = self.get_model_copy()
        simulation_length = 6

        v_search_step_size = 100

        current_v_offset = -400

        current_freq = test_v_offset(test_model, current_v_offset, base_stimulus, simulation_length)

        while current_freq < goal_baseline_frequency:
            if current_v_offset >= border:
                return border
            current_v_offset += v_search_step_size
            current_freq = test_v_offset(test_model, current_v_offset, base_stimulus, simulation_length)

        lower_bound = current_v_offset - v_search_step_size
        upper_bound = current_v_offset

        return binary_search_base_freq(test_model, base_stimulus, goal_baseline_frequency, simulation_length,
                                       lower_bound, upper_bound, threshold)


def binary_search_base_freq(model: LifacNoiseModel, base_stimulus, goal_frequency, simulation_length, lower_bound,
                            upper_bound, threshold):
    """Bisect ``v_offset`` between two bounds until the goal frequency is met.

    Each iteration simulates the model at the midpoint of the current
    interval (via ``test_v_offset``) and halves the interval towards the
    goal.

    :param model: model whose ``v_offset`` is varied (it is modified).
    :param base_stimulus: stimulus used for every test simulation.
    :param goal_frequency: target frequency.
    :param simulation_length: length handed to ``test_v_offset``.
    :param lower_bound: offset expected to give a frequency below the goal.
    :param upper_bound: offset expected to give a frequency above the goal.
    :param threshold: accepted absolute deviation; must be greater than 0.
    :return: the midpoint offset whose frequency is within ``threshold``
        of the goal, or the last midpoint if the bounds converged first
        (a warning is issued in that case).
    :raises ValueError: if ``threshold <= 0`` or if the measured frequency
        compares neither below, above nor close to the goal (e.g. NaN).
    """
    if threshold <= 0:
        raise ValueError("binary_search_base_freq() - LifacNoiseModel: threshold is not allowed to be negative!")

    unordered_report = 'lower bound: {:.1f}, middle: {:.1f}, upper_bound: {:.1f}, frequency: {:.1f} vs goal: {:.1f} '
    converged_report = "v_offset search stopped. bounds converged! freq: {:.2f}, bounds: {:.0f}"

    while True:
        half_width = (upper_bound - lower_bound) / 2
        middle = upper_bound - half_width
        frequency = test_v_offset(model, middle, base_stimulus, simulation_length)

        # Close enough to the goal: done.
        if abs(frequency - goal_frequency) < threshold:
            return middle

        if frequency < goal_frequency:
            lower_bound = middle
        elif frequency > goal_frequency:
            upper_bound = middle
        else:
            # Neither smaller, larger nor within threshold.
            print(unordered_report.format(lower_bound, middle, upper_bound, frequency, goal_frequency))
            raise ValueError("binary_search_base_freq() - LifacNoiseModel: Goal frequency might be nan?")

        # Interval collapsed without reaching the goal: give up with the
        # last midpoint.
        if abs(upper_bound - lower_bound) < 0.0001:
            print(converged_report.format(frequency, lower_bound))
            warn("Search was stopped. Upper and lower bounds converged without finding a value closer than threshold!")
            return middle


def test_v_offset(model: LifacNoiseModel, v_offset, base_stimulus, simulation_length):
    """Simulate ``model`` with the given ``v_offset`` and return its spike frequency.

    Only spikes later than the first third of the simulation are counted;
    the count is divided by the remaining two thirds of
    ``simulation_length``.  The model's ``v_offset`` stays changed after
    the call.

    :return: the frequency, or 0 if a ZeroDivisionError occurred during
        the simulation or the division.
    """
    model.set_variable("v_offset", v_offset)
    try:
        _, spiketimes = model.simulate(base_stimulus, simulation_length)
        cutoff = simulation_length / 3
        late_spike_count = sum(1 for spike in spiketimes if spike > cutoff)
        return late_spike_count / (2/3 * simulation_length)
    except ZeroDivisionError:
        print("divide by zero!")
        return 0


@jit(nopython=True)
def rectify_stimulus_array(stimulus_array: np.ndarray):
    """Return a new array in which every value that is not greater than 0 is replaced by 0."""
    clipped = [sample if sample > 0 else 0 for sample in stimulus_array]
    return np.array(clipped)


@jit(nopython=True)
def simulate_fast(rectified_stimulus_array, total_time_s, parameters: np.ndarray):
    """Numba kernel: simulate the model including the dendritic low-pass filter.

    :param rectified_stimulus_array: rectified stimulus, indexed once per
        time step (assumed to be at least as long as the time axis -
        TODO confirm against the caller).
    :param total_time_s: end of the simulated time axis.
    :param parameters: values read by position: v_zero, a_zero, step_size,
        threshold, v_base, delta_a, tau_a, v_offset, mem_tau,
        noise_strength, time_start, input_scaling, dend_tau, ref_period.
    :return: tuple ``(output_voltage, adaption, spiketimes, input_voltage)``
        where ``input_voltage`` is the filtered, not yet scaled stimulus.
    """
    v_zero = parameters[0]
    a_zero = parameters[1]
    step_size = parameters[2]
    threshold = parameters[3]
    v_base = parameters[4]
    delta_a = parameters[5]
    tau_a = parameters[6]
    v_offset = parameters[7]
    mem_tau = parameters[8]
    noise_strength = parameters[9]
    time_start = parameters[10]
    input_scaling = parameters[11]
    dend_tau = parameters[12]
    ref_period = parameters[13]

    time = np.arange(time_start, total_time_s, step_size)
    n_steps = len(time)
    output_voltage = np.zeros(n_steps)
    adaption = np.zeros(n_steps)
    input_voltage = np.zeros(n_steps)
    spiketimes = []

    # initial conditions
    output_voltage[0] = v_zero
    adaption[0] = a_zero
    input_voltage[0] = rectified_stimulus_array[0]

    # loop invariants
    sqrt_step = np.sqrt(step_size)
    refractory_window = ref_period + step_size/2
    adaption_jump = delta_a / tau_a

    for i in range(1, n_steps):
        noise = noise_strength * np.random.normal() / sqrt_step

        prev_input = input_voltage[i - 1]
        prev_v = output_voltage[i - 1]
        prev_a = adaption[i - 1]

        # dendritic low-pass filter of the rectified stimulus
        filtered = prev_input + ((-prev_input + rectified_stimulus_array[i]) / dend_tau) * step_size

        # membrane voltage, driven by the scaled input minus the adaption
        v_new = prev_v + ((v_base - prev_v + v_offset + (filtered * input_scaling) - prev_a + noise)
                          / mem_tau) * step_size

        # decay of the adaption
        a_new = prev_a + ((-prev_a) / tau_a) * step_size

        # hold the voltage at v_base while the last spike is too recent
        if len(spiketimes) > 0:
            if time[i] - spiketimes[-1] < refractory_window:
                v_new = v_base

        # threshold crossing: record spike, reset voltage, raise adaption
        if v_new > threshold:
            v_new = v_base
            spiketimes.append((i * step_size) + time_start)
            a_new += adaption_jump

        input_voltage[i] = filtered
        output_voltage[i] = v_new
        adaption[i] = a_new

    return output_voltage, adaption, spiketimes, input_voltage


@jit(nopython=True)
def simulate_fast_no_dend_tau(rectified_stimulus_array, total_time_s, parameters: np.ndarray):
    """Numba kernel: simulate the model without the dendritic low-pass filter.

    The rectified stimulus is used directly as input voltage.

    :param rectified_stimulus_array: rectified stimulus, indexed once per
        time step (assumed to be at least as long as the time axis -
        TODO confirm against the caller).
    :param total_time_s: end of the simulated time axis.
    :param parameters: values read by position: v_zero, a_zero, step_size,
        threshold, v_base, delta_a, tau_a, v_offset, mem_tau,
        noise_strength, time_start, input_scaling, dend_tau (unused here),
        ref_period.
    :return: tuple ``(output_voltage, adaption, spiketimes, input_voltage)``
        where ``input_voltage`` is the very array passed in as
        ``rectified_stimulus_array``.
    """
    v_zero = parameters[0]
    a_zero = parameters[1]
    step_size = parameters[2]
    threshold = parameters[3]
    v_base = parameters[4]
    delta_a = parameters[5]
    tau_a = parameters[6]
    v_offset = parameters[7]
    mem_tau = parameters[8]
    noise_strength = parameters[9]
    time_start = parameters[10]
    input_scaling = parameters[11]
    ref_period = parameters[13]

    time = np.arange(time_start, total_time_s, step_size)
    n_steps = len(time)
    output_voltage = np.zeros(n_steps)
    adaption = np.zeros(n_steps)
    input_voltage = rectified_stimulus_array
    spiketimes = []

    # initial conditions
    output_voltage[0] = v_zero
    adaption[0] = a_zero

    # loop invariants
    sqrt_step = np.sqrt(step_size)
    refractory_window = ref_period + step_size/2
    adaption_jump = delta_a / tau_a

    for i in range(1, n_steps):
        noise = noise_strength * np.random.normal() / sqrt_step

        prev_v = output_voltage[i - 1]
        prev_a = adaption[i - 1]

        # membrane voltage, driven by the scaled input minus the adaption
        v_new = prev_v + ((v_base - prev_v + v_offset + (input_voltage[i] * input_scaling) - prev_a + noise)
                          / mem_tau) * step_size

        # decay of the adaption
        a_new = prev_a + ((-prev_a) / tau_a) * step_size

        # hold the voltage at v_base while the last spike is too recent
        if len(spiketimes) > 0:
            if time[i] - spiketimes[-1] < refractory_window:
                v_new = v_base

        # threshold crossing: record spike, reset voltage, raise adaption
        if v_new > threshold:
            v_new = v_base
            spiketimes.append((i * step_size) + time_start)
            a_new += adaption_jump

        output_voltage[i] = v_new
        adaption[i] = a_new

    return output_voltage, adaption, spiketimes, input_voltage