from stimuli.AbstractStimulus import AbstractStimulus
from models.AbstractModel import AbstractModel
import numpy as np
import functions as fu
from numba import jit
import helperFunctions as hF
from stimuli.SinusoidalStepStimulus import SinusoidalStepStimulus
from scipy.optimize import curve_fit
from warnings import warn
import matplotlib.pyplot as plt

class LifacNoiseModel(AbstractModel):
    """
    Leaky integrate-and-fire neuron with an adaption current and Gaussian noise.

    Voltage and adaption are integrated with fixed Euler steps of length ``step_size``.
    Whenever the voltage exceeds ``threshold`` a spike is recorded, the voltage is reset to
    ``v_base`` and the adaption is increased by ``delta_a / tau_a``.

    After a simulation the instance holds the stimulus, the voltage trace, the adaption trace
    and the spike times of that run.
    """
    # All times are in seconds: step_size is the spacing of the simulation time axis,
    # which runs up to total_time_s.
    # possible mem_res: 100 * 1000000 exact value unknown in p-units
    DEFAULT_VALUES = {"mem_tau": 0.015,         # time constant of the voltage equation
                      "v_base": 0,              # leak target and reset value after a spike
                      "v_zero": 0,              # initial voltage
                      "threshold": 1,           # a spike is emitted when the voltage exceeds this
                      "v_offset": -10,          # constant added to the input drive
                      "input_scaling": 60,      # factor applied to the rectified stimulus
                      "delta_a": 0.08,          # adaption grows by delta_a / tau_a per spike
                      "tau_a": 0.1,             # decay time constant of the adaption
                      "a_zero": 2,              # initial adaption
                      "noise_strength": 0.05,   # scale of the Gaussian noise term
                      "step_size": 0.00005,     # Euler integration step
                      "dend_tau": 0.001}        # forwarded to the compiled simulation routine

    def __init__(self, params: dict = None):
        """
        :param params: parameter values, forwarded unchanged to AbstractModel.__init__
        """
        super().__init__(params)

        if self.parameters["step_size"] > 0.0001:
            warn("LifacNoiseModel: The step size is quite big simulation could fail.")
        self.voltage_trace = []
        self.adaption_trace = []
        self.spiketimes = []
        self.stimulus = None

    def simulate(self, stimulus: AbstractStimulus, total_time_s):
        """
        Simulates the model step by step in pure Python, starting at time 0.

        The stimulus is rectified and scaled with ``input_scaling`` before it drives the voltage.
        The stimulus and the resulting traces are stored on the instance.

        :param stimulus: stimulus that is sampled with value_at_time_in_s() at every time step
        :param total_time_s: end of the simulated time axis (exclusive)
        :return: voltage trace, list of spike times
        """
        self.stimulus = stimulus
        time = np.arange(0, total_time_s, self.parameters["step_size"])
        output_voltage = np.zeros(len(time), dtype='float64')
        adaption = np.zeros(len(time), dtype='float64')
        spiketimes = []

        current_v = self.parameters["v_zero"]
        current_a = self.parameters["a_zero"]
        output_voltage[0] = current_v
        adaption[0] = current_a

        for i in range(1, len(time), 1):
            time_point = time[i]
            # rectified input:
            stimulus_strength = fu.rectify(stimulus.value_at_time_in_s(time_point)) * self.parameters["input_scaling"]

            v_next = self._calculate_voltage_step(current_v, stimulus_strength - current_a)
            a_next = self._calculate_adaption_step(current_a)

            if v_next > self.parameters["threshold"]:
                # spike: reset the voltage and strengthen the adaption
                v_next = self.parameters["v_base"]
                spiketimes.append(time_point)
                a_next += self.parameters["delta_a"] / self.parameters["tau_a"]

            output_voltage[i] = v_next
            adaption[i] = a_next

            current_v = v_next
            current_a = a_next

        self.voltage_trace = output_voltage
        self.adaption_trace = adaption
        self.spiketimes = spiketimes

        return output_voltage, spiketimes

    def _calculate_voltage_step(self, current_v, input_v):
        """
        Performs one Euler step of the voltage equation including a fresh Gaussian noise sample.

        :param current_v: voltage at the current time step
        :param input_v: input drive of this step (scaled stimulus minus adaption)
        :return: voltage at the next time step (before any threshold handling)
        """
        v_base = self.parameters["v_base"]
        step_size = self.parameters["step_size"]
        v_offset = self.parameters["v_offset"]
        mem_tau = self.parameters["mem_tau"]

        noise_strength = self.parameters["noise_strength"]
        noise_value = np.random.normal()
        # dividing by sqrt(step_size) keeps the noise contribution per step proportional to sqrt(step_size)
        noise = noise_strength * noise_value / np.sqrt(step_size)

        return current_v + step_size * ((v_base - current_v + v_offset + input_v + noise) / mem_tau)

    def _calculate_adaption_step(self, current_a):
        """
        Performs one Euler step of the exponential decay of the adaption.

        :param current_a: adaption at the current time step
        :return: adaption at the next time step (before any spike increment)
        """
        step_size = self.parameters["step_size"]
        return current_a + (step_size * (-current_a)) / self.parameters["tau_a"]

    def simulate_fast(self, stimulus: AbstractStimulus, total_time_s, time_start=0):
        """
        Simulates the model with the compiled module-level simulate_fast() routine.

        The stimulus is sampled once as an array, rectified and handed to the compiled routine
        together with the model parameters. The stimulus and the resulting traces are stored
        on the instance.

        :param stimulus: stimulus that provides as_array(time_start, total_time_s, step_size)
        :param total_time_s: forwarded to stimulus.as_array() and to the compiled routine
        :param time_start: forwarded to stimulus.as_array() and to the compiled routine
        :return: voltage trace, list of spike times
        """
        v_zero = self.parameters["v_zero"]
        a_zero = self.parameters["a_zero"]
        step_size = self.parameters["step_size"]
        threshold = self.parameters["threshold"]
        v_base = self.parameters["v_base"]
        delta_a = self.parameters["delta_a"]
        tau_a = self.parameters["tau_a"]
        v_offset = self.parameters["v_offset"]
        mem_tau = self.parameters["mem_tau"]
        noise_strength = self.parameters["noise_strength"]
        input_scaling = self.parameters["input_scaling"]
        dend_tau = self.parameters["dend_tau"]

        rectified_stimulus = rectify_stimulus_array(stimulus.as_array(time_start, total_time_s, step_size))
        # the order of the entries has to match the unpacking in the compiled routine
        parameters = np.array([v_zero, a_zero, step_size, threshold, v_base, delta_a, tau_a, v_offset,
                               mem_tau, noise_strength, time_start, input_scaling, dend_tau])

        voltage_trace, adaption, spiketimes = simulate_fast(rectified_stimulus, total_time_s, parameters)

        self.stimulus = stimulus
        self.voltage_trace = voltage_trace
        self.adaption_trace = adaption
        self.spiketimes = spiketimes

        return voltage_trace, spiketimes

    def min_stimulus_strength_to_spike(self):
        """
        :return: distance between the spike threshold and the base voltage
        """
        return self.parameters["threshold"] - self.parameters["v_base"]

    def get_sampling_interval(self):
        """
        :return: the integration step size of the model
        """
        return self.parameters["step_size"]

    def get_frequency(self):
        """
        Not supported by this model.

        :raises NotImplementedError: always
        """
        # TODO also change simulates_frequency() if any calculation is added!
        raise NotImplementedError("No calculation implemented yet for the frequency.")

    def get_spiketimes(self):
        """
        :return: spike times of the last simulation
        """
        return self.spiketimes

    def get_voltage_trace(self):
        """
        :return: voltage trace of the last simulation
        """
        return self.voltage_trace

    def get_adaption_trace(self):
        """
        :return: adaption trace of the last simulation
        """
        return self.adaption_trace

    def simulates_frequency(self) -> bool:
        return False

    def simulates_spiketimes(self) -> bool:
        return True

    def simulates_voltage_trace(self) -> bool:
        return True

    def get_recording_times(self):
        """
        Describes the timing of the last simulation.

        :return: [delay, stimulus_start, stimulus_duration, total_time]
        :raises ValueError: if no simulation was run yet, so no stimulus is known
        """
        if self.stimulus is None:
            raise ValueError("LifacNoiseModel: get_recording_times() needs a stimulus, run a simulation first.")

        delay = 0
        start = self.stimulus.get_stimulus_start_s()
        duration = self.stimulus.get_stimulus_duration_s()
        # number of samples times the spacing between samples gives the simulated duration
        total_time = len(self.voltage_trace) * self.parameters["step_size"]

        return [delay, start, duration, total_time]

    def get_model_copy(self):
        """
        :return: a new LifacNoiseModel built from the current parameters
        """
        # hand over a shallow copy so the new model can never share the parameter dict with this one
        return LifacNoiseModel(dict(self.parameters))

    def calculate_baseline_markers(self, stimulus_freq, max_lag=1, simulation_time=30):
        """
        calculates the baseline markers baseline frequency, vector strength and serial correlation
        based on simulation_time simulated seconds with a Sinusoidal stimulus of the given frequency
        and a contrast of 0

        If the baseline frequency is below 1 the vector strength is returned as 0 and the serial
        correlation as a list of max_lag zeros.

        :param stimulus_freq: frequency of the sinusoidal stimulus
        :param max_lag: maximal lag handed to the serial correlation calculation
        :param simulation_time: length of the simulation; the mean frequency is taken after 5 seconds
        :return: baseline_freq, vs, sc
        """
        base_stimulus = SinusoidalStepStimulus(stimulus_freq, 0)
        _, spiketimes = self.simulate_fast(base_stimulus, simulation_time)
        time_x = 5
        baseline_freq = hF.mean_freq_of_spiketimes_after_time_x(spiketimes, time_x)

        if baseline_freq < 1:
            return baseline_freq, 0, [0]*max_lag

        sampling_interval = self.get_sampling_interval()
        # time and stimulus have to cover the same span as the simulation above
        time_trace = np.arange(0, simulation_time, sampling_interval)
        stimulus_array = base_stimulus.as_array(0, simulation_time, sampling_interval)

        vector_strength = hF.calculate_vector_strength_from_spiketimes(time_trace, stimulus_array, spiketimes, sampling_interval)
        serial_correlation = hF.calculate_serial_correlation(np.array(spiketimes), max_lag)

        return baseline_freq, vector_strength, serial_correlation

    def calculate_fi_markers(self, contrasts, stimulus_freq):
        """
        calculates the fi markers f_infinity, f_infinity_slope for given contrasts
        based on 1.6 simulated seconds for each contrast (stimulus of 1 second starting at 0.3 seconds)

        :param contrasts: contrasts of the step stimuli
        :param stimulus_freq: frequency of the sinusoidal stimulus
        :return: f_inf_values_list, f_inf_slope
        """
        stimulus_start = 0.3
        stimulus_duration = 1
        sampling_interval = self.get_sampling_interval()
        f_infinities = []
        for contrast in contrasts:
            stimulus = SinusoidalStepStimulus(stimulus_freq, contrast, stimulus_start, stimulus_duration)
            _, spiketimes = self.simulate_fast(stimulus, stimulus_start*2+stimulus_duration)
            time, freq = hF.calculate_time_and_frequency_trace(spiketimes, sampling_interval)
            f_inf = hF.detect_f_infinity_in_freq_trace(time, freq, stimulus_start, stimulus_duration, sampling_interval)
            f_infinities.append(f_inf)

        popt = hF.fit_clipped_line(contrasts, f_infinities)

        # the first fit parameter is used as the slope
        f_infinities_slope = popt[0]

        return f_infinities, f_infinities_slope

    def calculate_fi_curve(self, contrasts, stimulus_freq):
        """
        calculates the baseline, onset (f_zero) and steady state (f_infinity) frequencies
        for step stimuli with the given contrasts

        Stimulus start and stimulus duration are both five times the larger one of the
        time constants tau_a and mem_tau, the full simulation lasts three times as long.

        :param contrasts: contrasts of the step stimuli
        :param stimulus_freq: frequency of the sinusoidal stimulus
        :return: f_baselines, f_zeros, f_infinities (one value per contrast each)
        """
        max_time_constant = max([self.parameters["tau_a"], self.parameters["mem_tau"]])
        factor_to_equilibrium = 5
        stim_duration = max_time_constant * factor_to_equilibrium
        stim_start = max_time_constant * factor_to_equilibrium
        total_simulation_time = max_time_constant * factor_to_equilibrium * 3

        sampling_interval = self.get_sampling_interval()
        f_infinities = []
        f_zeros = []
        f_baselines = []
        for c in contrasts:
            stimulus = SinusoidalStepStimulus(stimulus_freq, c, stim_start, stim_duration)
            _, spiketimes = self.simulate_fast(stimulus, total_simulation_time)

            time, frequency = hF.calculate_time_and_frequency_trace(spiketimes, sampling_interval)

            f_inf = hF.detect_f_infinity_in_freq_trace(time, frequency, stim_start, stim_duration, sampling_interval)
            f_infinities.append(f_inf)

            f_zero = hF.detect_f_zero_in_frequency_trace(time, frequency, stim_start, sampling_interval)
            f_zeros.append(f_zero)

            f_baseline = hF.detect_f_baseline_in_freq_trace(time, frequency, stim_start, sampling_interval)
            f_baselines.append(f_baseline)

        return f_baselines, f_zeros, f_infinities

    def find_v_offset(self, goal_baseline_frequency, base_stimulus, threshold=2, border=50000):
        """
        searches the v_offset with which a copy of this model fires with the goal frequency

        Starting at -400 the v_offset is increased in steps of 100 until the frequency reaches the
        goal, then a binary search between the last two tested values refines the result.

        :param goal_baseline_frequency: frequency the model should reach
        :param base_stimulus: stimulus used for the test simulations
        :param threshold: accepted absolute difference between reached and goal frequency
        :param border: largest v_offset that is tested in the coarse search
        :return: the found v_offset or border if the goal was not reached up to border
        """
        test_model = self.get_model_copy()
        simulation_length = 5

        v_search_step_size = 100

        current_v_offset = -400

        current_freq = test_v_offset(test_model, current_v_offset, base_stimulus, simulation_length)

        while current_freq < goal_baseline_frequency:
            if current_v_offset >= border:
                return border
            current_v_offset += v_search_step_size
            current_freq = test_v_offset(test_model, current_v_offset, base_stimulus, simulation_length)

        # the goal frequency is crossed somewhere within the last coarse step
        lower_bound = current_v_offset - v_search_step_size
        upper_bound = current_v_offset

        return binary_search_base_freq(test_model, base_stimulus, goal_baseline_frequency, simulation_length, lower_bound, upper_bound, threshold)


def binary_search_base_freq(model: LifacNoiseModel, base_stimulus, goal_frequency, simulation_length, lower_bound, upper_bound, threshold):
    """
    binary search for the v_offset between lower_bound and upper_bound with which the model
    fires with the goal frequency

    The v_offset of the given model is changed by every test simulation.

    :param model: model that is used for the test simulations
    :param base_stimulus: stimulus used for the test simulations
    :param goal_frequency: frequency the model should reach
    :param simulation_length: length of every test simulation
    :param lower_bound: v_offset that is expected to result in a frequency below the goal
    :param upper_bound: v_offset that is expected to result in a frequency above the goal
    :param threshold: accepted absolute difference to the goal frequency, has to be greater than zero
    :return: the v_offset that reached the goal, or the last tested v_offset if the search
             interval got smaller than 0.0001 (a warning is issued in that case)
    :raises ValueError: if threshold is not greater than zero or a frequency cannot be compared
    """
    if threshold <= 0:
        raise ValueError("binary_search_base_freq() - LifacNoiseModel: threshold must be greater than zero!")

    while True:
        middle = upper_bound - (upper_bound - lower_bound) / 2
        frequency = test_v_offset(model, middle, base_stimulus, simulation_length)

        if abs(frequency - goal_frequency) < threshold:
            return middle

        if frequency < goal_frequency:
            lower_bound = middle
        elif frequency > goal_frequency:
            upper_bound = middle
        else:
            # equal values were accepted above, so no comparison held, e.g. because of a nan
            raise ValueError(
                "binary_search_base_freq() - LifacNoiseModel: Goal frequency might be nan? "
                "lower bound: {:.1f}, middle: {:.1f}, upper_bound: {:.1f}, frequency: {:.1f} vs goal: {:.1f}".format(
                    lower_bound, middle, upper_bound, frequency, goal_frequency))

        if abs(upper_bound - lower_bound) < 0.0001:
            warn("Search was stopped no value was found!")
            return middle


def test_v_offset(model: LifacNoiseModel, v_offset, base_stimulus, simulation_length):
    """
    sets the given v_offset in the model and measures the resulting mean firing frequency

    The model is simulated for simulation_length and the mean frequency is taken from the
    spikes after the first third of the simulation. The v_offset stays set in the model.

    :param model: model to test, its v_offset is overwritten
    :param v_offset: value to test
    :param base_stimulus: stimulus used for the simulation
    :param simulation_length: length of the simulation
    :return: the mean frequency or 0 if a ZeroDivisionError occurred
    """
    model.set_variable("v_offset", v_offset)

    try:
        _, spikes = model.simulate_fast(base_stimulus, simulation_length)
        return hF.mean_freq_of_spiketimes_after_time_x(spikes, simulation_length / 3)
    except ZeroDivisionError:
        print("divide by zero!")
        return 0


@jit(nopython=True)
def rectify_stimulus_array(stimulus_array: np.ndarray):
    """
    replaces every value of the stimulus that is not greater than zero with zero

    :param stimulus_array: stimulus values, iterated element by element
    :return: new array with the rectified values
    """
    rectified_values = [value if value > 0 else 0 for value in stimulus_array]
    return np.array(rectified_values)


@jit(nopython=True)
def simulate_fast(rectified_stimulus_array, total_time_s, parameters: np.ndarray):
    """
    compiled Euler integration of the voltage and adaption equations with Gaussian noise

    A spike is recorded whenever the voltage exceeds the threshold; the voltage is then reset to
    v_base and the adaption is increased by delta_a / tau_a. Spike times are given on the
    simulated time axis, i.e. they include time_start.

    :param rectified_stimulus_array: rectified stimulus with one value per time step, needs at
           least as many values as the time axis has steps
    :param total_time_s: end of the simulated time axis (exclusive)
    :param parameters: array with the values in the order v_zero, a_zero, step_size, threshold, v_base,
           delta_a, tau_a, v_offset, mem_tau, noise_strength, time_start, input_scaling, dend_tau
    :return: voltage trace, adaption trace, list of spike times
    :raises ValueError: if the stimulus array is shorter than the time axis
    """
    v_zero = parameters[0]
    a_zero = parameters[1]
    step_size = parameters[2]
    threshold = parameters[3]
    v_base = parameters[4]
    delta_a = parameters[5]
    tau_a = parameters[6]
    v_offset = parameters[7]
    mem_tau = parameters[8]
    noise_strength = parameters[9]
    time_start = parameters[10]
    input_scaling = parameters[11]
    # parameters[12] (dend_tau) belongs to the parameter layout but is not used here:
    # the voltage equation below is driven directly by the scaled stimulus.

    time = np.arange(time_start, total_time_s, step_size)
    length = len(time)
    output_voltage = np.zeros(length)
    adaption = np.zeros(length)
    spiketimes = []

    if length == 0:
        # nothing to simulate, the initial values cannot even be stored
        return output_voltage, adaption, spiketimes

    if len(rectified_stimulus_array) < length:
        # indexing is not bounds checked in compiled code, so fail loudly instead
        raise ValueError("simulate_fast(): the stimulus array is shorter than the simulated time axis.")

    output_voltage[0] = v_zero
    adaption[0] = a_zero

    sqrt_step_size = np.sqrt(step_size)

    for i in range(1, length, 1):

        noise_value = np.random.normal()
        noise = noise_strength * noise_value / sqrt_step_size

        output_voltage[i] = output_voltage[i-1] + ((v_base - output_voltage[i-1] + v_offset + (rectified_stimulus_array[i] * input_scaling) - adaption[i-1] + noise) / mem_tau) * step_size
        adaption[i] = adaption[i-1] + ((-adaption[i-1]) / tau_a) * step_size

        if output_voltage[i] > threshold:
            output_voltage[i] = v_base
            # time[i] equals i*step_size for time_start == 0 and stays correct for other start times
            spiketimes.append(time[i])
            adaption[i] += delta_a / tau_a

    return output_voltage, adaption, spiketimes