import numpy as np
from warnings import warn
from thunderfish.eventdetection import threshold_crossing_times, threshold_crossings, detect_peaks
from scipy.optimize import curve_fit
import functions as fu
from numba import jit
import matplotlib.pyplot as plt
import time


def plot_errors(list_errors, save_path=None):
    """
    Plot a log-scaled histogram of each of the ten error measures in a 2 x 5 grid.

    :param list_errors: sequence of error vectors; column k of ``np.array(list_errors)`` is
        histogrammed under the k-th title below (ten columns are used)
    :param save_path: None to show the figure interactively, otherwise the figure is saved to
        ``save_path + "error_distribution.png"`` (plain string concatenation, so the prefix
        must already end with a path separator if it denotes a directory)
    """
    titles = ("error_vs", "error_sc", "error_cv", "rms_isi_hist", "error_bursty",
              "error_f_inf", "error_f_inf_s", "error_f_zero", "error_f_zero_s_straight", "error_f0_curve")
    error_matrix = np.array(list_errors)

    fig, axes = plt.subplots(2, 5, figsize=(10, 8))

    for panel, title in enumerate(titles):
        row, col = divmod(panel, 5)
        ax = axes[row, col]
        ax.hist(error_matrix[:, panel])
        ax.set_title(title)
        ax.set_yscale('log')

    if save_path is not None:
        plt.savefig(save_path + "error_distribution.png")
    else:
        plt.show()
    plt.close()


def fit_clipped_line(x, y):
    """
    Least-squares fit of ``fu.clipped_line`` to the points (x, y).

    :return: the optimal parameters from ``scipy.optimize.curve_fit``; the covariance
        estimate is discarded
    """
    optimal_params, _covariance = curve_fit(fu.clipped_line, x, y)
    return optimal_params


def fit_boltzmann(x, y):
    """
    Fit ``fu.full_boltzmann`` to the points (x, y).

    Start values are derived from the data: p0 = (max(y), 0.1, slope estimate, mean(x)).
    The first two parameters are bounded to be non-negative.

    :param x: the x values (e.g. stimulus intensities)
    :param y: the y values (e.g. frequencies)
    :return: the fitted parameters, or [0, 0, 0, 0] when all y are zero, when all x are
        identical (the slope start value is undefined), or when the fit does not converge
    """
    max_f0 = float(max(y))
    if max_f0 == 0:
        return [0, 0, 0, 0]

    min_f0 = 0.1  # float(min(self.f_zeros))
    mean_int = float(np.mean(x))

    total_increase = max_f0 - min_f0
    total_change_int = max(x) - min(x)
    if total_change_int == 0:
        # Without any spread in x the start slope would divide by zero, and the sigmoid
        # is not identifiable anyway: report it like a failed fit.
        print("Error in fit boltzmann: all x values are identical, cannot fit.")
        print("x_values:", x)
        print("y_values:", y)
        return [0, 0, 0, 0]
    # Slope of a straight line through the data range, scaled for the sigmoid's steepness.
    start_k = float((total_increase / total_change_int * 4) / max_f0)

    try:
        popt, pcov = curve_fit(fu.full_boltzmann, x, y,
                               p0=(max_f0, min_f0, start_k, mean_int),
                               maxfev=10000, bounds=([0, 0, -np.inf, -np.inf], [np.inf, np.inf, np.inf, np.inf]))
    except RuntimeError as e:
        # curve_fit raises RuntimeError when it does not converge within maxfev evaluations.
        print("Error in fit boltzmann: ", str(e))
        print("x_values:", x)
        print("y_values:", y)
        return [0, 0, 0, 0]
    return popt


@jit(nopython=True)
def rectify_stimulus_array(stimulus_array: np.ndarray):
    """Return a new array in which every non-positive sample of the stimulus is set to 0."""
    rectified = []
    for sample in stimulus_array:
        rectified.append(sample if sample > 0 else 0)
    return np.array(rectified)


def merge_similar_intensities(intensities, spiketimes, trans_amplitudes):
    """
    Merge stimulus intensities that lie closer together than ~2/3 of their mean spacing.

    ``intensities``, ``spiketimes`` and ``trans_amplitudes`` are parallel lists: entry k of
    each belongs to the same stimulus. All three are sorted together by increasing intensity,
    then every entry absorbs the following entries within the margin (see
    merge_intensities_similar_to_index).

    :return: (intensities, spiketimes, trans_amplitudes) as lists sorted by intensity
    :raises ValueError: if the three lists differ in length
    :raises RuntimeError: (from merge_intensities_similar_to_index) if entries to be merged
        have different trans_amplitudes
    """
    if not len(intensities) == len(spiketimes) == len(trans_amplitudes):
        raise ValueError("intensities, spiketimes and trans_amplitudes must have the same length.")
    if len(intensities) < 2:
        # Nothing to merge; also avoids np.mean of an empty diff (nan + RuntimeWarning).
        return intensities, spiketimes, trans_amplitudes

    diffs = np.diff(sorted(intensities))
    margin = np.mean(diffs) * 0.6666

    # Sort ALL three lists together. Previously only intensities and spiketimes were
    # re-sorted, so trans_amplitudes stopped matching its intensity after the first pass.
    ordered = sorted(zip(intensities, spiketimes, trans_amplitudes), key=lambda entry: entry[0])
    intensities = [entry[0] for entry in ordered]
    spiketimes = [entry[1] for entry in ordered]
    trans_amplitudes = [entry[2] for entry in ordered]

    # Merging only deletes later entries, so the lists stay sorted and each remaining
    # index has to be visited exactly once.
    i = 0
    while i < len(intensities):
        intensities, spiketimes, trans_amplitudes = merge_intensities_similar_to_index(
            intensities, spiketimes, trans_amplitudes, i, margin)
        i += 1

    return intensities, spiketimes, trans_amplitudes


def merge_intensities_similar_to_index(intensities, spiketimes, trans_amplitudes, index, margin):
    """
    Merge every later entry whose intensity lies within ``margin`` of ``intensities[index]``.

    Only indices after ``index`` are considered. The spike-time lists of the merged entries
    are appended to ``spiketimes[index]`` (extended in place, from the last merged index to
    the first). The merged entries are then deleted from all three lists. The lists are
    modified in place and also returned.

    :raises RuntimeError: if any entry to be merged has a trans_amplitude different from the
        surviving entry at ``index``; the lists are left unchanged in that case
    """
    reference_intensity = intensities[index]
    indices_to_merge = [i for i in range(index + 1, len(intensities))
                        if np.abs(intensities[i] - reference_intensity) < margin]

    if not indices_to_merge:
        return intensities, spiketimes, trans_amplitudes

    # Validate before mutating. The entry at `index` survives the merge, so every merged
    # entry must share ITS amplitude (previously only the merged ones were compared).
    reference_amplitude = trans_amplitudes[index]
    if not all(trans_amplitudes[k] == reference_amplitude for k in indices_to_merge):
        raise RuntimeError("Trans_amplitudes not the same....")

    # Delete from the back so the remaining indices in indices_to_merge stay valid.
    for idx in reversed(indices_to_merge):
        spiketimes[index].extend(spiketimes[idx])
        del spiketimes[idx]
        del intensities[idx]
        del trans_amplitudes[idx]

    return intensities, spiketimes, trans_amplitudes


def all_calculate_mean_isi_frequency_traces(spiketimes, sampling_interval, stimulus_start=0, time_in_ms=False):
    """
    Compute, for every trial, the mean ISI frequency trace over all of its runs.

    Expects spiketimes to be a 3dim list: trial -> run -> individual spike times:
     [[[trial1-run1-spike1, trial1-run1-spike2, ...],[trial1-run2-spike1, ...]],[[trial2-run1-spike1, ...], [..]]]
    Runs whose frequency trace starts after ``stimulus_start`` are skipped (with a message).

    :param spiketimes: time points of action potentials, nested as described above
    :param sampling_interval: the sampling interval used / will also be used for the frequency trace
    :param stimulus_start: the time point at which the actual stimulus starts
    :param time_in_ms: whether the time is in ms or seconds
    :return: (times, mean_frequencies): one time trace and one mean frequency trace per trial
    """
    times = []
    mean_frequencies = []

    for trial_runs in spiketimes:
        run_times = []
        run_freqs = []
        for run_spikes in trial_runs:
            run_time, run_freq = calculate_time_and_frequency_trace(run_spikes, sampling_interval, time_in_ms)
            if run_time[0] > stimulus_start:
                print("Trial not used as its frequency trace started after the stimulus start!")
                continue
            run_freqs.append(run_freq)
            run_times.append(run_time)

        trial_time, trial_mean = calculate_mean_of_frequency_traces(run_times, run_freqs, sampling_interval)
        times.append(trial_time)
        mean_frequencies.append(trial_mean)

    return times, mean_frequencies


def calculate_isi_frequency_trace(spiketimes, sampling_interval, time_in_ms=False):
    """
    Calculates the frequency over time according to the inter spike intervals.

    Every ISI contributes round(isi / sampling_interval) samples with the value 1 / isi (in Hz),
    so the trace starts at the first spike and ends at the last one. Zero ISIs (duplicate spike
    times) are skipped with a warning.

    :param spiketimes: sorted time points spikes were measured array_like
    :param sampling_interval: the sampling interval in which the frequency should be given back
    :param time_in_ms: whether the time is in ms or in s for BOTH the spiketimes and the sampling interval
    :return: an np.array with the isi frequency starting at the time of first spike and ending at
        the time of the last spike; an empty list if fewer than two spikes are given
    :raises ValueError: if the spiketimes are not sorted (negative ISI), or if the sampling interval
        is bigger than the smallest non-zero ISI
    """
    if len(spiketimes) <= 1:
        return []

    isis = np.diff(spiketimes)

    # Check ordering first: a negative ISI used to be reported by the sampling-interval check
    # below, which made this more specific error unreachable.
    if np.any(isis < 0):
        raise ValueError("There was a negative interspike interval, the spiketimes need to be sorted")

    # Zero ISIs carry no duration and cannot be inverted; skip them instead of letting them
    # trip the resolution check below (which made this branch unreachable before).
    zero_isis = isis == 0
    if np.any(zero_isis):
        warn("An ISI was zero in FiCurve:__calculate_mean_isi_frequency__()")
        print("ISI was zero:", spiketimes)
        isis = isis[~zero_isis]
        if len(isis) == 0:
            return np.array([])

    if sampling_interval > round(min(isis), 7):
        raise ValueError("The sampling interval is bigger than the some isis! cannot accurately compute the trace.\n"
                         "Sampling interval {:.5f}, smallest isi: {:.5f}".format(sampling_interval, min(isis)))

    if time_in_ms:
        isis = isis / 1000
        sampling_interval = sampling_interval / 1000

    # Vectorised form of the former per-ISI np.concatenate loop (which was quadratic).
    # np.round rounds halves to even, like the built-in round() used previously.
    samples_per_isi = np.round(isis * (1 / sampling_interval)).astype(int)
    return np.repeat(1 / isis, samples_per_isi)


def gaussian_kernel(sigma, dt):
    """
    Return a Gaussian probability-density kernel sampled with spacing ``dt``.

    The kernel spans [-4 sigma, 4 sigma] and is symmetric around a centre sample at x = 0,
    so it has an odd number of samples. Convolving with it (``mode='same'``) therefore does
    not shift the signal. The previous ``np.arange(-4 sigma, 4 sigma, dt)`` dropped the +4 sigma
    end and produced an asymmetric, even-length kernel.
    The values have unit area, i.e. ``sum(kernel) * dt`` is close to 1.

    :param sigma: standard deviation, in the same units as dt
    :param dt: sample spacing
    :return: np.array with the kernel values
    """
    half_width = int(round(4. * sigma / dt))
    x = np.arange(-half_width, half_width + 1) * dt
    y = np.exp(-0.5 * (x / sigma) ** 2) / np.sqrt(2. * np.pi) / sigma
    return y


def calculate_gauss_convolve_freq(spiketimes, duration, sampling_interval, gauss_sigma):
    """
    Estimate the firing rate by convolving the spike train with a Gaussian kernel.

    The spike train is binned at ``sampling_interval`` over [0, duration). Spikes whose bin
    falls outside that range are ignored. Previously a spike at t >= duration raised
    IndexError, and a negative spike time silently wrapped around to the end of the trace.
    Several spikes in one bin are counted individually instead of collapsing to 1.

    :param spiketimes: spike times, in the same units as duration and sampling_interval
    :param duration: length of the trace
    :param sampling_interval: bin width / sampling interval of the returned rate
    :param gauss_sigma: standard deviation of the Gaussian kernel
    :return: rate trace (spikes per time unit), same length as the binned spike train
        (unless the kernel is longer than the trace, in which case np.convolve returns the
        kernel length)
    """
    n_samples = int(np.rint(duration / sampling_interval))
    spike_counts = np.zeros(n_samples)

    spike_indices = np.rint(np.asarray(spiketimes, dtype=float) / sampling_interval).astype(int)
    in_range = (spike_indices >= 0) & (spike_indices < n_samples)
    # np.add.at accumulates repeated indices (plain fancy-index assignment would not).
    np.add.at(spike_counts, spike_indices[in_range], 1)

    kernel = gaussian_kernel(gauss_sigma, sampling_interval)
    return np.convolve(spike_counts, kernel, mode='same')


def calculate_time_and_frequency_trace(spiketimes, sampling_interval, time_in_ms=False):
    """
    Build the ISI frequency trace of a spike train together with its time axis.

    :param spiketimes: sorted spike times
    :param sampling_interval: sampling interval of the returned traces (same unit as spiketimes)
    :param time_in_ms: whether spiketimes and sampling_interval are in ms
    :return: (time, frequency) of equal length; ([0], [0]) if fewer than two spikes are given
    """
    if len(spiketimes) < 2:
        # Not enough spikes for a single ISI: return a one-sample dummy trace.
        return [0], [0]
        # raise ValueError("Cannot compute a time and frequency vector with fewer than 2 spikes")

    frequency = calculate_isi_frequency_trace(spiketimes, sampling_interval, time_in_ms)
    time_trace = np.arange(spiketimes[0], spiketimes[-1], sampling_interval)

    # Rounding every ISI to whole samples can make the traces differ by a few samples in
    # EITHER direction; previously only a longer time trace was cut, so a longer frequency
    # trace stayed mismatched. Cut both to their common length.
    common_length = min(len(time_trace), len(frequency))
    return time_trace[:common_length], frequency[:common_length]


def calculate_mean_of_frequency_traces(trial_time_traces, trial_frequency_traces, sampling_interval):
    """
    Calculate the sample-wise mean of several frequency traces that start and end at
    different times.

    Only the span covered by all traces (latest start to earliest end) is averaged.

    :param trial_time_traces: list of time arrays, one per trace, sampled at sampling_interval
    :param trial_frequency_traces: list of frequency arrays matching trial_time_traces
    :param sampling_interval: the common sampling interval of all traces
    :return: (time, mean_freq): the common time axis (np.array) and the mean frequency at each
        of its points (list); both always have the same length (empty if the traces do not
        overlap)
    """
    latest_start = max(t[0] for t in trial_time_traces)
    earliest_end = min(t[-1] for t in trial_time_traces)

    # Clamp at 0: with non-overlapping traces a negative length would turn the slices below
    # into "up to the n-th last element" slices.
    length = max(int(round((earliest_end - latest_start) / sampling_interval)), 0)
    shortened_time = (np.arange(0, length) * sampling_interval) + latest_start

    shortened_freqs = []
    for time_trace, freq_trace in zip(trial_time_traces, trial_frequency_traces):
        start_idx = int(round((latest_start - time_trace[0]) / sampling_interval))
        # Take exactly `length` samples from every trace instead of rounding a separate end
        # index per trace, which could shorten the mean relative to shortened_time.
        shortened_freqs.append(freq_trace[start_idx:start_idx + length])

    mean_freq = [sum(e) / len(e) for e in zip(*shortened_freqs)]

    # zip stops at the shortest trace; keep the time axis aligned with the mean.
    return shortened_time[:len(mean_freq)], mean_freq


def mean_freq_of_spiketimes_after_time_x(spiketimes, time_x, time_in_ms=False):
    """
    Calculate the mean ISI frequency of the spikes occurring strictly after ``time_x``.

    :param spiketimes: spike times
    :param time_x: only spikes with a time greater than this are used
    :param time_in_ms: whether the spike times are in ms (passed on to calculate_mean_isi_freq)
    :return: the mean frequency from calculate_mean_isi_freq, or 0 if fewer than two spikes
        exist overall or after time_x
    """
    spike_array = np.array(spiketimes)
    if len(spike_array) < 2:
        return 0

    later_spikes = spike_array[spike_array > time_x]
    if len(later_spikes) < 2:
        return 0

    return calculate_mean_isi_freq(later_spikes, time_in_ms)


def calculate_mean_isi_freq(spiketimes, time_in_ms=False):
    """
    Calculate the ISI-weighted mean of the instantaneous frequencies 1 / ISI.

    Weighting each 1 / ISI by the length of its ISI reduces algebraically to
    (number of ISIs) / (time from first to last spike). That is computed directly here,
    which gives the same value but no longer produces inf/nan when two spikes share a
    time stamp (zero ISI).

    :param spiketimes: spike times (assumed sorted)
    :param time_in_ms: whether the spike times are in ms; the result is then still per second
    :return: mean frequency, or 0 if fewer than two spikes are given or they span zero time
    """
    if len(spiketimes) < 2:
        return 0

    spiketimes = np.asarray(spiketimes)
    total_duration = spiketimes[-1] - spiketimes[0]
    if time_in_ms:
        total_duration = total_duration / 1000
    if total_duration == 0:
        return 0

    return (len(spiketimes) - 1) / total_duration


# @jit(nopython=True)  # only faster at around 30 000 calls
def calculate_coefficient_of_variation(spiketimes: np.ndarray) -> float:
    """
    Coefficient of variation of the inter-spike intervals: std(ISI) / mean(ISI).

    :param spiketimes: spike times
    :return: the CV, or 0 when at most two spike times are given
    """
    if len(spiketimes) > 2:
        intervals = np.diff(spiketimes)
        return np.std(intervals) / np.mean(intervals)
    return 0


# @jit(nopython=True)  # maybe faster with more than ~60 000 calls
def calculate_serial_correlation(spiketimes: np.ndarray, max_lag: int) -> np.ndarray:
    """
    Serial correlation coefficients of the inter-spike intervals for lags 1..max_lag.

    :param spiketimes: spike times
    :param max_lag: largest lag (in number of ISIs)
    :return: array of length max_lag; entry k-1 is the Pearson correlation between ISI_i and
        ISI_{i+k}. If there are fewer than 20 spikes, or too few ISIs to form at least two
        pairs at the largest lag, a warning is issued and zeros are returned.
    """
    isi = np.diff(spiketimes)
    # Every lag needs at least two (ISI_i, ISI_{i+lag}) pairs for np.corrcoef to be defined.
    # The former check (len(spiketimes) < max_lag + 1) let the largest lags run on empty or
    # single-element slices, which returned nan.
    if len(isi) < max_lag + 2 or len(spiketimes) < 20:
        warn("Cannot compute serial correlation with list shorter than max lag...")
        return np.zeros(max_lag)
        # raise ValueError("Given list to short, with given max_lag")

    cor = np.zeros(max_lag)
    for lag in range(1, max_lag + 1):
        cor[lag - 1] = np.corrcoef(isi[:-lag], isi[lag:])[0][1]

    return cor


def calculate_eod_frequency(eod, sampling_interval):
    """
    Estimate the EOD frequency as the inverse of the mean interval between detected peaks.

    Peaks are found with ``detect_peaks`` using one standard deviation of the trace as the
    threshold. With fewer than two peaks the mean interval is nan (numpy warns), so the
    result is nan as well.

    :param eod: the EOD trace
    :param sampling_interval: time between samples of ``eod``
    :return: frequency in 1 / (time unit of sampling_interval)
    """
    # TODO for few samples very volatile measure!
    peak_indices, _ = detect_peaks(eod, np.std(eod) * 1)
    peak_times = [index * sampling_interval for index in peak_indices]
    mean_period = np.mean(np.diff(peak_times))
    return 1 / mean_period


def calculate_vector_strength_from_spiketimes(time, eod, spiketimes, sampling_interval):
    """
    Vector strength of the spikes' phase locking to the EOD.

    :param time: time axis of ``eod``, sampled at sampling_interval
    :param eod: the EOD trace
    :param spiketimes: spike times on the same time axis as ``time`` (e.g. as returned by
        detect_spiketimes, which looks the times up in ``time``)
    :param sampling_interval: sampling interval of ``time``/``eod``
    :return: the vector strength from __vector_strength__ (-1 if eods_around_spikes yields no
        spikes)
    """
    # A spike's sample index is its offset from the start of the trace. The former
    # `+ time[0]` gave correct indices only when the trace started at t = 0.
    spiketime_indices = np.array(np.around((np.array(spiketimes) - time[0]) / sampling_interval), dtype=int)
    rel_spikes, eod_durs = eods_around_spikes(time, eod, spiketime_indices)

    return __vector_strength__(rel_spikes, eod_durs)


def detect_spike_indices_automatic_split(v1, threshold, min_length=5000, split_step=1000, break_threshold=0.25):
    """
    Detect spike (peak) indices in v1, splitting the trace where its amplitude range changes.

    The trace is split into segments (see _split_by_amplitude_change). Peaks are then detected
    in every segment separately, with a threshold of ``threshold`` times that segment's
    standard deviation.

    :param v1: the voltage trace
    :param threshold: peak detection threshold in units of the segment's standard deviation
    :param min_length: number of samples used to establish a segment's reference max/min
    :param split_step: number of samples examined per step when looking for a change
    :param break_threshold: relative change of the window max or min that starts a new segment
    :return: list of peak indices into v1
    """
    splits = _split_by_amplitude_change(v1, min_length, split_step, break_threshold)

    all_peaks = []
    for first_index, last_index in splits:
        segment = v1[first_index:last_index]
        peaks, _ = detect_peaks(segment, np.std(segment) * threshold)
        all_peaks.extend(peaks + first_index)

    return all_peaks


def _split_by_amplitude_change(v1, min_length, split_step, break_threshold):
    """Return (start, end) index pairs partitioning v1 into segments of similar max/min amplitude."""
    if len(v1) < min_length:
        return [(0, len(v1))]

    splits = []
    split_start = 0
    last_max = max(v1[0:min_length])
    last_min = min(v1[0:min_length])
    idx = min_length

    while idx < len(v1):
        if idx + split_step > len(v1):
            splits.append((split_start, len(v1)))
            break

        window = v1[idx:idx + split_step]
        max_similar = abs((max(window) - last_max) / last_max) < break_threshold
        min_similar = abs((min(window) - last_min) / last_min) < break_threshold

        if max_similar and min_similar:
            idx += split_step
            continue

        # New segment: cut at the minimum within +-20 samples of idx (presumably so the cut
        # does not run through a spike) and re-establish the reference max/min there.
        end_offset = np.argmin(v1[idx - 20:idx + 21]) - 20
        splits.append((split_start, idx + end_offset))
        split_start = idx + end_offset
        last_max = max(v1[split_start:split_start + min_length])
        last_min = min(v1[split_start:split_start + min_length])
        idx = split_start + min_length

    # The loop can end without closing the last segment (e.g. len(v1) == min_length, or idx
    # stepping exactly onto len(v1)); `splits[-1]` then used to raise IndexError.
    if not splits or splits[-1][1] != len(v1):
        splits.append((split_start, len(v1)))

    return splits


def detect_spiketimes(time, v1, threshold=2.0, min_length=5000, split_step=1000):
    """
    Detect spikes in v1 and return their times, looked up in ``time``.

    See detect_spike_indices_automatic_split for the meaning of threshold, min_length and
    split_step.

    :return: list of spike times (elements of ``time``)
    """
    peak_indices = detect_spike_indices_automatic_split(
        v1, threshold=threshold, min_length=min_length, split_step=split_step)
    return [time[peak_index] for peak_index in peak_indices]


def eods_around_spikes(time, eod, spiketime_idices):
    """
    For every spike, find the EOD cycle it falls into and its time within that cycle.

    EOD cycles are delimited by upward zero crossings: indices i where the sign of eod changes
    between samples i and i + 1 while eod increases. Spikes before the first crossing or at/after
    the last one do not lie inside a complete cycle and are skipped.

    :param time: time axis of eod
    :param eod: the EOD trace
    :param spiketime_idices: sample indices (into time/eod) of the spikes
    :return: (relative_spike_times, eod_durations) as np.arrays: for each used spike, its time
        since the start of its cycle and the duration of that cycle
    """
    time = np.asarray(time)
    eod = np.asarray(eod)
    spike_indices = np.asarray(spiketime_idices, dtype=int)

    sign_changes = np.sign(eod[:-1]) != np.sign(eod[1:])
    eod_trace_increasing = eod[:-1] < eod[1:]
    crossings = np.where(sign_changes & eod_trace_increasing)[0]

    if len(crossings) < 2:
        # No complete cycle (previously an IndexError when there was no crossing at all).
        return np.array([]), np.array([])

    # Keep only spikes inside a complete cycle. The former test
    # `crossings[0] > spike > crossings[-1]` could never be true, so spikes outside all cycles
    # paired crossings[-1] with crossings[0] and produced negative durations.
    inside = (spike_indices >= crossings[0]) & (spike_indices < crossings[-1])
    spike_indices = spike_indices[inside]

    # Position of the first crossing strictly after each spike (== argmax(crossings > spike)).
    end_positions = np.searchsorted(crossings, spike_indices, side='right')
    end_idx = crossings[end_positions]
    start_idx = crossings[end_positions - 1]

    eod_durations = time[end_idx] - time[start_idx]
    relative_spike_times = time[spike_indices] - time[start_idx]
    return relative_spike_times, eod_durations


# def search_eod_start_and_end_times(time, eod, index):
#     # TODO might break if a spike is in the cut off first or last eod!
#
#     # search start_time:
#     previous = index
#     working_idx = index-1
#     while True:
#         if eod[working_idx] < 0 < eod[previous]:
#             first_value = eod[working_idx]
#             second_value = eod[previous]
#
#             dif = second_value - first_value
#             part = np.abs(first_value/dif)
#
#             time_dif = np.abs(time[previous] - time[working_idx])
#             start_time = time[working_idx] + time_dif*part
#
#             break
#
#         previous = working_idx
#         working_idx -= 1
#
#     # search end_time
#     previous = index
#     working_idx = index + 1
#     while True:
#         if eod[previous] < 0 < eod[working_idx]:
#             first_value = eod[previous]
#             second_value = eod[working_idx]
#
#             dif = second_value - first_value
#             part = np.abs(first_value / dif)
#
#             time_dif = np.abs(time[previous] - time[working_idx])
#             end_time = time[working_idx] + time_dif * part
#
#             break
#
#         previous = working_idx
#         working_idx += 1
#
#     return start_time, end_time


def __vector_strength__(relative_spike_times: np.ndarray, eod_durations: np.ndarray):
    # adapted from Ramona

    n = len(relative_spike_times)
    if n == 0:
        return -1

    phase_times = (relative_spike_times / eod_durations) * 2 * np.pi
    vs = np.sqrt((1 / n * np.sum(np.cos(phase_times))) ** 2 + (1 / n * np.sum(np.sin(phase_times))) ** 2)

    return vs


def detect_f_zero_in_frequency_trace(time, frequency, stimulus_start, sampling_interval, peak_buffer_percent=0.05, buffer=0.025):
    """
    Detect the onset response f_0 in a frequency trace.

    The baseline range is taken from the trace before the stimulus, leaving out ``buffer`` at
    both ends. Within the window from time_window_detect_f_zero, the extreme (min or max) that
    deviates most from the baseline mean becomes f_0. If that extreme stays inside the baseline
    range, widened by peak_buffer_percent of the range, there is no clear onset response. In
    that case f_0 is the mean over the first half of the window.

    :param time: time axis of ``frequency``; indices are computed relative to time[0]
    :param frequency: the frequency trace (list or np.array)
    :param stimulus_start: time of the stimulus onset
    :param sampling_interval: sampling interval of the trace
    :return: (f_zero, f_zero_idx): f_zero_idx is a 1-tuple with the index of the extreme, or
        (start, end) of the averaged window. NOTE: returns a single 0 (not a tuple) if there
        is not enough trace before the stimulus.
    :raises ValueError: if the detection window starts at a negative index
    :raises AssertionError: if the detected f_zero is implausibly high
    """
    if time[0] + 2*buffer > stimulus_start:
        print("F_zero detection: Not enough frequency trace before start of the stimulus.")
        return 0

    # Indices are relative to the start of the trace. The start index used to be
    # int(time[0] + buffer / sampling_interval), which added a time value to an index.
    before_start_idx = int(buffer / sampling_interval)
    before_end_idx = int((stimulus_start - time[0] - buffer) / sampling_interval)
    freq_before = frequency[before_start_idx:before_end_idx]

    min_before = min(freq_before)
    max_before = max(freq_before)
    mean_before = np.mean(freq_before)

    # time where the f-zero is searched in
    start_idx, end_idx = time_window_detect_f_zero(time[0], stimulus_start, sampling_interval, buffer)

    if start_idx < 0:
        raise ValueError("Time window to detect f_zero starts in an negative index!")

    # argmin/argmax return the first occurrence, like list.index() did, but also work for
    # np.array traces (which have no .index method).
    window = np.asarray(frequency[start_idx:end_idx])
    min_idx = int(np.argmin(window))
    max_idx = int(np.argmax(window))
    min_during_start_of_stim = window[min_idx]
    max_during_start_of_stim = window[max_idx]

    if abs(mean_before-min_during_start_of_stim) > abs(max_during_start_of_stim-mean_before):
        f_zero = min_during_start_of_stim
        f_zero_idx = (min_idx + start_idx,)
    else:
        f_zero = max_during_start_of_stim
        f_zero_idx = (max_idx + start_idx,)

    # The extreme is within the (slightly widened) baseline range: no clear onset peak,
    # so use the mean over the first half of the detection window instead.
    peak_buffer = (max_before - min_before) * peak_buffer_percent
    if min_before - peak_buffer <= f_zero <= max_before + peak_buffer:
        end_idx = start_idx + int((end_idx-start_idx)/2)
        f_zero = np.mean(frequency[start_idx:end_idx])
        f_zero_idx = (start_idx, end_idx)

    # Sanity checks against obviously broken traces.
    max_frequency = int(1/sampling_interval)
    int_f_zero = int(f_zero)
    if int_f_zero > max_frequency:
        raise AssertionError("Detection of f-zero went very wrong! frequency above 1/sampling_interval.")
    if int_f_zero > max(frequency):
        raise AssertionError("detected f_zero bigger than the highest peak in the frequency trace...")
    return f_zero, f_zero_idx


def time_window_detect_f_zero(time_start, stimulus_start, sampling_interval, buffer=0.025):
    """
    Index window, relative to the trace start, in which f_0 is searched.

    The window runs from half a buffer before to one buffer after the stimulus onset.

    :return: (start_idx, end_idx)
    """
    onset = stimulus_start - time_start
    return (int((onset - 0.5 * buffer) / sampling_interval),
            int((onset + buffer) / sampling_interval))


def detect_f_infinity_in_freq_trace(time, frequency, stimulus_start, stimulus_duration, sampling_interval, length=0.1, buffer=0.025):
    """
    Steady-state response f_inf: the mean frequency in the window given by
    time_window_detect_f_infinity (indices relative to time[0]).

    :return: (f_inf, (start_idx, end_idx))
    """
    window = time_window_detect_f_infinity(time[0], stimulus_start, stimulus_duration,
                                           sampling_interval, length, buffer)
    first, last = window
    return np.mean(frequency[first:last]), window


def time_window_detect_f_infinity(time_start, stimulus_start, stimulus_duration, sampling_interval, length=0.1, buffer=0.025):
    """
    Index window, relative to the trace start, used for the steady-state frequency: ``length``
    long and ending ``buffer`` before the end of the stimulus.

    :return: (start_idx, end_idx)
    """
    offset_of_end = stimulus_start + stimulus_duration - time_start
    return (int((offset_of_end - length - buffer) / sampling_interval),
            int((offset_of_end - buffer) / sampling_interval))


def detect_f_baseline_in_freq_trace(time, frequency, stimulus_start, sampling_interval, buffer=0.025):
    """
    Baseline frequency: the mean of the frequency trace in the pre-stimulus window given by
    time_window_detect_f_baseline (indices relative to time[0]).

    :return: (f_baseline, (start_idx, end_idx))
    """
    window = time_window_detect_f_baseline(time[0], stimulus_start, sampling_interval, buffer)
    first, last = window
    return np.mean(frequency[first:last]), window


def time_window_detect_f_baseline(time_start, stimulus_start, sampling_interval, buffer=0.025):
    """
    Index window, relative to the trace start, used for the baseline frequency: from
    ``buffer`` after the trace start to ``buffer`` before the stimulus onset.

    Warns when the time before the stimulus is shorter than 0.1 (in the unit of the arguments).

    :return: (start_idx, end_idx)
    """
    pre_stimulus = stimulus_start - time_start

    if pre_stimulus < 0.1:
        warn("FICurve:__calculate_f_baseline__(): Quite short delay at the start.")

    first = int(buffer / sampling_interval)
    last = int((pre_stimulus - buffer) / sampling_interval)
    return first, last