import numpy as np
import matplotlib.pyplot as plt
from thunderfish.tracker_v2 import include_progress_bar
from IPython import embed
import scipy.stats as scp
from tqdm import tqdm

def estimate_error(a_error, f_error, t_error, a_error_distribution, f_error_distribution,
                   min_f_weight=0.4, max_f_weight=0.9, t_of_max_f_weight=5.):
    """
    Cost function estimating the error between two fish signals at two different times using realative frequency
    difference and relative signal amplitude difference on n electrodes (relative because the values are compared to a
    given distribution of these values). With increasing time difference between the signals the impact of frequency
    error increases and the influence of amplitude error decreases due to potential changes caused by fish movement.

    Parameters
    ----------
    a_error: float
        MSE of amplitude difference of two electric signals recorded with n electodes.
    f_error: float
        absolute frequency difference between two electric signals.
    t_error: float
        temporal difference between two measured signals in s.
    a_error_distribution: list
        distribution of possible MSE of the amplitudes between random data points in the dataset.
    f_error_distribution: list
        distribution of possible frequency differences between random data points in the dataset.
    min_f_weight: float
        minimum proportion of the frequency impact to the error value.
    max_f_weight: float
        maximum proportion of the frequency impact to the error value.
    t_of_max_f_weight: float
        error value between two electric signals at two time points.

    Returns
    -------
    float
        error value between two electric signals at two time points
    """
    def boltzmann(t, alpha=0.25, beta=0.0, x0=4, dx=0.85):
        """
        Calculates a Boltzmann function.

        Parameters
        ----------
        t: array
            time vector.
        alpha: float
            max value of the Boltzmann function.
        beta: float
            min value of the Boltzmann function.
        x0: float
            time where the turning point of the Boltzmann function occurs.
        dx: float
            slope of the Boltzmann function.

        Returns
        -------
        array
            Boltzmann function of the given time array based on the other parameters given.
        """
        boltz = (alpha - beta) / (1. + np.exp(-(t - x0) / dx)) + beta
        return boltz
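
    # boltzmann(t) rises smoothly from ~0 at t = 0 through alpha/2 at t = x0 = 4 s
    # towards alpha = 0.25 for large t, so longer time gaps add a growing penalty.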

    a_error_distribution = np.array(a_error_distribution)
    f_error_distribution = np.array(f_error_distribution)

    # the frequency weight grows linearly with the time gap until t_of_max_f_weight;
    # the amplitude weight shrinks accordingly
    if t_error >= t_of_max_f_weight:
        f_weight = max_f_weight
    else:
        f_weight = (max_f_weight - min_f_weight) / t_of_max_f_weight * t_error + min_f_weight
    a_weight = 1. - f_weight

    # relative errors: the fraction of the reference distribution below the observed error
    a_e = a_weight * len(a_error_distribution[a_error_distribution < a_error]) / len(a_error_distribution)
    f_e = f_weight * len(f_error_distribution[f_error_distribution < f_error]) / len(f_error_distribution)
    t_e = boltzmann(t_error)
    # t_e = 0.5 * (1. * t_error / max_t_error) ** (1. / 3)  # when weight is 0.1 I end up in an endless loop somewhere

    return a_e + f_e + t_e
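

def _example_estimate_error():
    """Minimal usage sketch of estimate_error (hypothetical values; the
    distributions are synthetic stand-ins, not real recording data)."""
    a_dist = np.abs(np.random.randn(1000))       # hypothetical amplitude error distribution
    f_dist = np.abs(np.random.randn(1000)) * .5  # hypothetical frequency error distribution
    # smaller return values indicate a more plausible connection between two detections
    return estimate_error(a_error=.8, f_error=.3, t_error=1.5,
                          a_error_distribution=a_dist, f_error_distribution=f_dist)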


def freq_tracking_v3(fund_v, idx_v, sign_v, times, freq_tolerance, n_channels, return_tmp_idenities=False,
                     ioi_fti=False, a_error_distribution=False, f_error_distribution=False, freq_lims=(400, 1200)):
    """
    Sorting algorithm which sorts fundamental EOD frequnecies detected in consecutive powespectra of single or
    multielectrode recordings using frequency difference and frequnency-power amplitude difference on the electodes.

    Signal tracking and identity assiginment is accomplished in four steps:
    1) Extracting possible frequency and amplitude difference distributions.
    2) Esitmate relative error between possible datapoint connections (relative amplitude and frequency error based on
    frequency and amplitude error distribution).
    3) For a data window covering the EOD frequencies detected 10 seconds before the accual datapoint to assigne
    identify temporal identities based on overall error between two datapoints from smalles to largest.
    4) Form tight connections between datapoints where one datapoint is in the timestep that is currently of interest.

    Repeat these steps until the end of the recording.
    The temporal identities are only updated when the timestep of current interest reaches the middle (5 sec.) of the
    temporal identities. This is because no tight connection shall be made without checking the temporal identities.
    The temnporal identities are used to check if the potential connection from the timestep of interest to a certain
    datsapoint is the possibly best or if a connection in the futur will be better. If a future connection is better
    the thight connection is not made.

    Parameters
    ----------
    fundamentals: 2d-arraylike / list
        list of arrays of fundemantal EOD frequnecies. For each timestep/powerspectrum contains one list with the
        respectivly detected fundamental EOD frequnecies.
    signatures: 3d-arraylike / list
        same as fundamentals but for each value in fundamentals contains a list of powers of the respective frequency
        detected of n electrodes used.
    times: array
        respective time vector.
    freq_tolerance: float
        maximum frequency difference between two datapoints to be connected in Hz.
    n_channels: int
        number of channels/electodes used in the analysis.,
    return_tmp_idenities: bool
        only returne temporal identities at a certain timestep. Dependent on ioi_fti and only used to check algorithm.
    ioi_fti: int
        Index Of Interest For Temporal Identities: respective index in fund_v to calculate the temporal identities for.
    a_error_distribution: array
        possible amplitude error distributions for the dataset.
    f_error_distribution: array
        possible frequency error distribution for the dataset.
    fig: mpl.figure
        figure to plot the tracking progress life.
    ax: mpl.axis
        axis to plot the tracking progress life.
    freq_lims: double
        minimum/maximum frequency to be tracked.

    Returns
    -------
    fund_v: array
        flattened fundamtantals array containing all detected EOD frequencies in the recording.
    ident_v: array
        respective assigned identites throughout the tracking progress.
    idx_v: array
        respective index vectro impliing the time of the detected frequency.
    sign_v: 2d-array
        for each fundamental frequency the power of this frequency on the used electodes.
    a_error_distribution: array
        possible amplitude error distributions for the dataset.
    f_error_distribution: array
        possible frequency error distribution for the dataset.
    idx_of_origin_v: array
        for each assigned identity the index of the datapoint on which basis the assignement was made.
    """
    def clean_up(fund_v, ident_v, idx_v, times):
        """
        deletes/replaces with np.nan those identities only consisting from little data points and thus are tracking
        artefacts. Identities get deleted when the proportion of the trace (slope, ratio of detected datapoints, etc.)
        does not fit a real fish.

        Parameters
        ----------
        fund_v: array
            flattened fundamtantals array containing all detected EOD frequencies in the recording.
        ident_v: array
            respective assigned identites throughout the tracking progress.
        idx_v: array
            respective index vectro impliing the time of the detected frequency.
        times: array
            respective time vector.

        Returns
        -------
        ident_v: array
            cleaned up identities vector.

        """
        for ident in np.unique(ident_v[~np.isnan(ident_v)]):
            # discard traces whose frequency jitters too strongly to be a real fish
            if np.median(np.abs(np.diff(fund_v[ident_v == ident]))) >= 0.25:
                ident_v[ident_v == ident] = np.nan
                continue
            # discard traces consisting of too few datapoints
            if len(ident_v[ident_v == ident]) <= 10:
                ident_v[ident_v == ident] = np.nan
                continue

        return ident_v

    def get_a_and_f_error_dist(fund_v, idx_v, sign_v):
        """
        Get the distributions of possible frequency and amplitude errors for the tracking.

        Parameters
        ----------
        fund_v: array
            flattened fundamentals array containing all detected EOD frequencies in the recording.
        idx_v: array
            respective index vector implying the time of the detected frequency.
        sign_v: 2d-array
            for each fundamental frequency the power of this frequency on the used electrodes.

        Returns
        -------
        f_error_distribution: array
            possible frequency error distribution for the dataset.
        a_error_distribution: array
            possible amplitude error distribution for the dataset.
        """
        # get f and amp signature distribution ############### BOOT #######################
        a_error_distribution = np.zeros(20000)  # distribution of amplitude errors
        f_error_distribution = np.zeros(20000)  # distribution of frequency errors
        idx_of_distribution = np.zeros(20000)  # corresponding indices

        b = 0  # loop variable
        next_message = 0.  # feedback

        while b < 20000:
            next_message = include_progress_bar(b, 20000, 'get f and sign dist', next_message)  # feedback

            while True:  # find compare indices to create the initial amp and freq distributions
                r_idx0 = np.random.randint(np.max(idx_v[~np.isnan(idx_v)]))
                r_idx1 = r_idx0 + 1
                if len(sign_v[idx_v == r_idx0]) != 0 and len(sign_v[idx_v == r_idx1]) != 0:
                    break

            r_idx00 = np.random.randint(len(sign_v[idx_v == r_idx0]))
            r_idx11 = np.random.randint(len(sign_v[idx_v == r_idx1]))

            s0 = sign_v[idx_v == r_idx0][r_idx00]  # amplitude signatures
            s1 = sign_v[idx_v == r_idx1][r_idx11]

            f0 = fund_v[idx_v == r_idx0][r_idx00]  # fundamentals
            f1 = fund_v[idx_v == r_idx1][r_idx11]

            if np.abs(f0 - f1) > 10.:  # frequency threshold
                continue

            idx_of_distribution[b] = r_idx0
            a_error_distribution[b] = np.sqrt(np.sum([(s0[k] - s1[k]) ** 2 for k in range(len(s0))]))
            f_error_distribution[b] = np.abs(f0 - f1)
            b += 1

        return f_error_distribution, a_error_distribution

    def get_tmp_identities(i0_m, i1_m, error_cube, fund_v, idx_v, i, ioi_fti, dps, idx_comp_range):
        """
        extract temporal identities for a datasnippted of 2*index compare range of the original tracking algorithm.
        for each data point in the data window finds the best connection within index compare range and, thus connects
        the datapoints based on their minimal error value until no connections are left or possible anymore.

        Parameters
        ----------
        i0_m: 2d-array
            for consecutive timestamps contains for each the indices of the origin EOD frequencies.
        i1_m: 2d-array
            respectively contains the indices of the targen EOD frequencies, laying within index compare range.
        error_cube: 3d-array
            error values for each combination from i0_m and the respective indices in i1_m.
        fund_v: array
            flattened fundamtantals array containing all detected EOD frequencies in the recording.
        idx_v: array
            respective index vectro impliing the time of the detected frequency.
        i: int
            loop variable and current index of interest for the assignment of tight connections.
        ioi_fti: int
            index of interest for temporal identities.
        dps: float
            detections per second. 1. / 'temporal resolution of the tracking'
        idx_comp_range: int
            index compare range for the assignment of two data points to each other.

        Returns
        -------
        tmp_ident_v: array
            for each EOD frequencies within the index compare range for the current time step of interest contains the
            temporal identity.
        errors_to_v: array
            for each assigned temporal identity contains the error value based on which this connection was made.

        """
        next_tmp_identity = 0

        # pad all layers (except the first, which the main loop handles) into one 3d array
        max_shape = np.max([np.shape(layer) for layer in error_cube[1:]], axis=0)
        cp_error_cube = np.full((len(error_cube)-1, max_shape[0], max_shape[1]), np.nan)
        for enu, layer in enumerate(error_cube[1:]):
            cp_error_cube[enu, :np.shape(error_cube[enu+1])[0], :np.shape(error_cube[enu+1])[1]] = layer

        tmp_ident_v = np.full(len(fund_v), np.nan)
        errors_to_v = np.full(len(fund_v), np.nan)

        # sort all error values in ascending order; NaNs sort to the end, so the loop
        # below can stop at the first NaN it encounters
        layers, idx0s, idx1s = np.unravel_index(np.argsort(cp_error_cube, axis=None), np.shape(cp_error_cube))
        layers = layers + 1  # shift back to error_cube indexing (cp_error_cube omits layer 0)

        for layer, idx0, idx1 in zip(layers, idx0s, idx1s):
            if np.isnan(cp_error_cube[layer-1, idx0, idx1]):
                break

            # _____ some control functions _____ ###
            if not ioi_fti:
                if idx_v[i1_m[layer][idx1]] - i > idx_comp_range*2:
                    continue
            else:
                if idx_v[i1_m[layer][idx1]] - idx_v[ioi_fti] > idx_comp_range*2:
                    continue

            # reject connections whose implied frequency slope exceeds 2 Hz/s
            if np.abs(fund_v[i0_m[layer][idx0]] - fund_v[i1_m[layer][idx1]]) / ((idx_v[i1_m[layer][idx1]] - idx_v[i0_m[layer][idx0]]) / dps) > 2.:
                continue

            if np.isnan(tmp_ident_v[i0_m[layer][idx0]]):
                if np.isnan(tmp_ident_v[i1_m[layer][idx1]]):
                    # neither datapoint has a temporal identity yet: open a new one
                    tmp_ident_v[i0_m[layer][idx0]] = next_tmp_identity
                    tmp_ident_v[i1_m[layer][idx1]] = next_tmp_identity
                    errors_to_v[i1_m[layer][idx1]] = cp_error_cube[layer-1][idx0, idx1]
                    next_tmp_identity += 1
                else:
                    # only the target has an identity: the origin adopts it, unless that
                    # identity already covers the origin's timestep
                    if idx_v[i0_m[layer][idx0]] in idx_v[tmp_ident_v == tmp_ident_v[i1_m[layer][idx1]]]:
                        continue
                    tmp_ident_v[i0_m[layer][idx0]] = tmp_ident_v[i1_m[layer][idx1]]
                    errors_to_v[i1_m[layer][idx1]] = cp_error_cube[layer-1][idx0, idx1]

            else:
                if np.isnan(tmp_ident_v[i1_m[layer][idx1]]):
                    # only the origin has an identity: the target adopts it, unless that
                    # identity already covers the target's timestep
                    if idx_v[i1_m[layer][idx1]] in idx_v[tmp_ident_v == tmp_ident_v[i0_m[layer][idx0]]]:
                        continue
                    tmp_ident_v[i1_m[layer][idx1]] = tmp_ident_v[i0_m[layer][idx0]]
                    errors_to_v[i1_m[layer][idx1]] = cp_error_cube[layer-1][idx0, idx1]

                else:
                    # both datapoints already carry an identity
                    if tmp_ident_v[i0_m[layer][idx0]] == tmp_ident_v[i1_m[layer][idx1]]:
                        if np.isnan(errors_to_v[i1_m[layer][idx1]]):
                            errors_to_v[i1_m[layer][idx1]] = cp_error_cube[layer-1][idx0, idx1]
                        continue

                    # merge the two identities only if their traces do not overlap in time
                    idxs_i0 = idx_v[tmp_ident_v == tmp_ident_v[i0_m[layer][idx0]]]
                    idxs_i1 = idx_v[tmp_ident_v == tmp_ident_v[i1_m[layer][idx1]]]

                    if np.any(np.diff(sorted(np.concatenate((idxs_i0, idxs_i1)))) == 0):
                        continue
                    tmp_ident_v[tmp_ident_v == tmp_ident_v[i0_m[layer][idx0]]] = tmp_ident_v[i1_m[layer][idx1]]

                    if np.isnan(errors_to_v[i1_m[layer][idx1]]):
                        errors_to_v[i1_m[layer][idx1]] = cp_error_cube[layer-1][idx0, idx1]

        return tmp_ident_v, errors_to_v

    # _____ parameters and vectors _____ ###

    detection_time_diff = times[1] - times[0]
    dps = 1. / detection_time_diff

    ident_v = np.full(len(fund_v), np.nan)  # fish identities of frequencies
    idx_of_origin_v = np.full(len(fund_v), np.nan)  # index of the datapoint each assignment was based on

    idx_comp_range = int(np.floor(dps * 5.))  # maximum compare range backwards for amplitude signature comparison
    low_freq_th = 400.  # min. frequency tracked
    high_freq_th = 1050.  # max. frequency tracked


    # _____ create initial error cube _____ ###
    error_cube = []  # [fundamental_list_idx, freqs_to_assign, target_freqs]
    i0_m = []
    i1_m = []

    next_message = 0.
    start_idx = idx_v[0]

    f_error_distribution = []
    a_error_distribution = []
    idx_of_error = []

    # live debugging plot of the tracking progress (the trace drawing further below is commented out)
    fig1, ax1 = plt.subplots()
    ax1.plot(times[idx_v], fund_v, 'o', color='grey', alpha=.1)
    tmp_ident_handle = []
    ident_handles = []

    plt.show(block=False)


    # start_idx = 0 if not ioi_fti else idx_v[ioi_fti]  # Index Of Interest for temporal identities
    # _____ initial distributions of amplitude and frequency errors _____ ###
    for i in range(start_idx, int(start_idx + idx_comp_range*2)):
        next_message = include_progress_bar(i - start_idx, int(idx_comp_range*2), 'error dist init', next_message)
        i0_v = np.arange(len(idx_v))[(idx_v == i) & (fund_v >= freq_lims[0]) & (fund_v <= freq_lims[1])]  # indices of fundamentals to assign
        i1_v = np.arange(len(idx_v))[(idx_v > i) & (idx_v <= (i + int(idx_comp_range))) & (fund_v >= freq_lims[0]) & (fund_v <= freq_lims[1])]  # indices of possible targets

        if len(i0_v) == 0 or len(i1_v) == 0:  # if nothing to assign or no targets continue
            continue

        for enu0 in range(len(fund_v[i0_v])):
            if fund_v[i0_v[enu0]] < low_freq_th or fund_v[i0_v[enu0]] > high_freq_th:
                continue
            for enu1 in range(len(fund_v[i1_v])):
                if fund_v[i1_v[enu1]] < low_freq_th or fund_v[i1_v[enu1]] > high_freq_th:
                    continue
                if np.abs(fund_v[i0_v[enu0]] - fund_v[i1_v[enu1]]) >= freq_tolerance:  # freq difference too high
                    continue
                a_error_distribution.append(np.sqrt(np.sum([(sign_v[i0_v[enu0]][k] - sign_v[i1_v[enu1]][k]) ** 2 for k in range(len(sign_v[i0_v[enu0]]))])))
                f_error_distribution.append(np.abs(fund_v[i0_v[enu0]] - fund_v[i1_v[enu1]]))
                idx_of_error.append(idx_v[i1_v[enu1]])

    # Initial error cube
    for i in range(start_idx, int(start_idx + idx_comp_range*2)):
        next_message = include_progress_bar(i - start_idx, int(idx_comp_range*2), 'initial error cube', next_message)
        i0_v = np.arange(len(idx_v))[(idx_v == i) & (fund_v >= freq_lims[0]) & (fund_v <= freq_lims[1])]  # indices of fundamentals to assign
        i1_v = np.arange(len(idx_v))[(idx_v > i) & (idx_v <= (i + int(idx_comp_range))) & (fund_v >= freq_lims[0]) & (fund_v <= freq_lims[1])]  # indices of possible targets

        i0_m.append(i0_v)
        i1_m.append(i1_v)

        if len(i0_v) == 0 or len(i1_v) == 0:  # if nothing to assign or no targets continue
            error_cube.append(np.array([[]]))
            continue
        error_matrix = np.full((len(i0_v), len(i1_v)), np.nan)

        for enu0 in range(len(fund_v[i0_v])):
            if fund_v[i0_v[enu0]] < low_freq_th or fund_v[i0_v[enu0]] > high_freq_th:  # freq to assign out of tracking range
                continue
            for enu1 in range(len(fund_v[i1_v])):
                if fund_v[i1_v[enu1]] < low_freq_th or fund_v[i1_v[enu1]] > high_freq_th:  # target freq out of tracking range
                    continue
                if np.abs(fund_v[i0_v[enu0]] - fund_v[i1_v[enu1]]) >= freq_tolerance:  # freq difference too high
                    continue

                a_error = np.sqrt(np.sum([(sign_v[i0_v[enu0]][j] - sign_v[i1_v[enu1]][j]) ** 2 for j in range(n_channels)]))
                f_error = np.abs(fund_v[i0_v[enu0]] - fund_v[i1_v[enu1]])
                t_error = 1. * np.abs(idx_v[i0_v[enu0]] - idx_v[i1_v[enu1]]) / dps

                error_matrix[enu0, enu1] = estimate_error(a_error, f_error, t_error, a_error_distribution, f_error_distribution)
        error_cube.append(error_matrix)


    cube_app_idx = idx_v[0] + len(error_cube)
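    # from here on the error cube acts as a rolling window: each iteration pops layer 0
    # (the current timestep) and appends a freshly computed layer at cube_app_idx, so
    # the cube always covers the next idx_comp_range * 2 timesteps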

    # _____ actual tracking _____ ###
    next_identity = 0
    next_message = 0.00
    for enu, i in enumerate(np.unique(idx_v)):
        if enu == 0:
            # debugging plot: cumulative distributions of the frequency and amplitude errors
            fig, ax = plt.subplots(1, 2)
            n, h = np.histogram(f_error_distribution, 1000)
            ax[0].plot(h[:-1], np.cumsum(n) / np.sum(n), color='k', linewidth=2)
            n, h = np.histogram(a_error_distribution, 1000)
            ax[1].plot(h[:-1], np.cumsum(n) / np.sum(n), color='k', linewidth=2)
            plt.show(block=False)
        else:
            if len(f_error_distribution) >= 5000:
                # keep only the most recent errors so the distributions adapt over time
                f_error_distribution = f_error_distribution[-4000:]
                a_error_distribution = a_error_distribution[-4000:]
                idx_of_error = idx_of_error[-4000:]
                n, h = np.histogram(f_error_distribution, 1000)
                c = np.random.rand(3)
                ax[0].plot(h[:-1], np.cumsum(n) / np.sum(n), color=c)
                n, h = np.histogram(a_error_distribution, 1000)
                ax[1].plot(h[:-1], np.cumsum(n) / np.sum(n), color=c)
                fig.canvas.draw()

        if i != 0 and (i % int(idx_comp_range * 120)) == 0:  # clean up every 10 minutes
            ident_v = clean_up(fund_v, ident_v, idx_v, times)

        if not return_tmp_idenities:
            next_message = include_progress_bar(enu, len(np.unique(idx_v)), 'tracking', next_message)  # feedback

        if enu % idx_comp_range == 0:
            # update the temporal identities and fit a regression line to each of them,
            # to predict an identity's frequency forward in time
            tmp_ident_v, errors_to_v = get_tmp_identities(i0_m, i1_m, error_cube, fund_v, idx_v, i, ioi_fti, dps, idx_comp_range)
            tmp_ident_regress = {}
            for ident in np.unique(tmp_ident_v[~np.isnan(tmp_ident_v)]):
                slope, intercept, _, _, _ = scp.linregress(idx_v[tmp_ident_v == ident], fund_v[tmp_ident_v == ident])
                tmp_ident_regress.update({'%.0f' % ident: [slope, intercept]})

            # for hand in tmp_ident_handle:
            #     hand.remove()
            # tmp_ident_handle = []
            # for ident in np.unique(tmp_ident_v[~np.isnan(tmp_ident_v)]):
            #     handle, = ax1.plot(times[idx_v][tmp_ident_v == ident], fund_v[tmp_ident_v == ident], color='red')
            #     tmp_ident_handle.append(handle)
            # fig1.canvas.draw()


        # error_cube[0] holds the errors of all connections originating in the current
        # timestep of interest; assign tight connections from smallest error upwards
        idx0s, idx1s = np.unravel_index(np.argsort(error_cube[0], axis=None), np.shape(error_cube[0]))

        for idx0, idx1 in zip(idx0s, idx1s):
            if np.isnan(error_cube[0][idx0, idx1]):
                break

            if freq_lims:
                if fund_v[i0_m[0][idx0]] > freq_lims[1] or fund_v[i0_m[0][idx0]] < freq_lims[0]:
                    continue
                if fund_v[i1_m[0][idx1]] > freq_lims[1] or fund_v[i1_m[0][idx1]] < freq_lims[0]:
                    continue

            if not np.isnan(ident_v[i1_m[0][idx1]]):
                continue

            if not np.isnan(errors_to_v[i1_m[0][idx1]]):
                if errors_to_v[i1_m[0][idx1]] < error_cube[0][idx0, idx1]:
                    continue

            if np.isnan(ident_v[i0_m[0][idx0]]):  # i0 doesn't have an identity
                if 1. * np.abs(fund_v[i0_m[0][idx0]] - fund_v[i1_m[0][idx1]]) / ((idx_v[i1_m[0][idx1]] - idx_v[i0_m[0][idx0]]) / dps) > 2.:
                    continue

                if np.isnan(ident_v[i1_m[0][idx1]]):  # i1 doesn't have an identity
                    ident_v[i0_m[0][idx0]] = next_identity
                    ident_v[i1_m[0][idx1]] = next_identity
                    next_identity += 1
                    a_error_distribution.append(np.sqrt(np.sum([(sign_v[i0_m[0][idx0]][k] - sign_v[i1_m[0][idx1]][k]) ** 2 for k in range(len(sign_v[i0_m[0][idx0]]))])))
                    f_error_distribution.append(np.abs(fund_v[i0_m[0][idx0]] - fund_v[i1_m[0][idx1]]))
                    idx_of_error.append(idx_v[i1_m[0][idx1]])

                else:  # i1 already has an identity
                    continue

            else:  # i0 already has an identity
                if np.isnan(ident_v[i1_m[0][idx1]]):  # i1 doesn't have an identity
                    if idx_v[i1_m[0][idx1]] in idx_v[ident_v == ident_v[i0_m[0][idx0]]]:
                        continue
                    # _____ if idx1 is not the new last point of ident[idx0], check that ...
                    # _____ the temporal identities support the connection ...
                    if not idx_v[i0_m[0][idx0]] == idx_v[ident_v == ident_v[i0_m[0][idx0]]][-1]:  # if i0 is not the last ...
                        if len(ident_v[(ident_v == ident_v[i0_m[0][idx0]]) & (idx_v > idx_v[i0_m[0][idx0]]) & (idx_v < idx_v[i1_m[0][idx1]])]) == 0:  # no datapoint of this identity between i0 and i1
                            next_idx_after_new = np.arange(len(ident_v))[(ident_v == ident_v[i0_m[0][idx0]]) & (idx_v > idx_v[i1_m[0][idx1]])][0]
                            if tmp_ident_v[next_idx_after_new] != tmp_ident_v[i1_m[0][idx1]]:
                                continue
                        elif len(ident_v[(ident_v == ident_v[i0_m[0][idx0]]) & (idx_v > idx_v[i1_m[0][idx1]])]) == 0:  # no datapoint of this identity after i1
                            last_idx_before_new = np.arange(len(ident_v))[(ident_v == ident_v[i0_m[0][idx0]]) & (idx_v < idx_v[i1_m[0][idx1]])][-1]
                            if tmp_ident_v[last_idx_before_new] != tmp_ident_v[i1_m[0][idx1]]:
                                continue
                        else:  # datapoints of this identity both before and after i1
                            next_idx_after_new = np.arange(len(ident_v))[(ident_v == ident_v[i0_m[0][idx0]]) & (idx_v > idx_v[i1_m[0][idx1]])][0]
                            last_idx_before_new = np.arange(len(ident_v))[(ident_v == ident_v[i0_m[0][idx0]]) & (idx_v < idx_v[i1_m[0][idx1]])][-1]
                            if tmp_ident_v[last_idx_before_new] != tmp_ident_v[i1_m[0][idx1]] or tmp_ident_v[next_idx_after_new] != tmp_ident_v[i1_m[0][idx1]]:
                                continue

                    ident_v[i1_m[0][idx1]] = ident_v[i0_m[0][idx0]]
                    a_error_distribution.append(np.sqrt(np.sum([(sign_v[i0_m[0][idx0]][k] - sign_v[i1_m[0][idx1]][k]) ** 2 for k in range(len(sign_v[i0_m[0][idx0]]))])))
                    f_error_distribution.append(np.abs(fund_v[i0_m[0][idx0]] - fund_v[i1_m[0][idx1]]))
                    idx_of_error.append(idx_v[i1_m[0][idx1]])
                else:  # i1 already has an identity
                    continue

            idx_of_origin_v[i1_m[0][idx1]] = i0_m[0][idx0]
        #
        # for hand in ident_handles:
        #     hand.remove()
        # ident_handles = []
        # for ident in np.unique(ident_v[~np.isnan(ident_v)]):
        #     hand, = ax1.plot(times[idx_v][ident_v == ident], fund_v[ident_v == ident], color='green', lw=2)
        #     ident_handles.append(hand)
        #
        # ax1.set_xlim([times[idx_v[~np.isnan(tmp_ident_v)][0]] - 10, times[idx_v[~np.isnan(tmp_ident_v)][-1]] + 10 ])
        # fig1.canvas.draw()

        # plt.waitforbuttonpress()

        i0_m.pop(0)
        i1_m.pop(0)
        error_cube.pop(0)

        i0_v = np.arange(len(idx_v))[(idx_v == cube_app_idx) & (fund_v >= freq_lims[0]) & (fund_v <= freq_lims[1])]  # indices of fundamentals to assign
        i1_v = np.arange(len(idx_v))[(idx_v > cube_app_idx) & (idx_v <= (cube_app_idx + idx_comp_range)) & (fund_v >= freq_lims[0]) & (fund_v <= freq_lims[1])]  # indices of possible targets

        i0_m.append(i0_v)
        i1_m.append(i1_v)

        if len(i0_v) == 0 or len(i1_v) == 0:  # if nothing to assign or no targets continue
            error_cube.append(np.array([[]]))

        else:
            error_matrix = np.full((len(i0_v), len(i1_v)), np.nan)

            for enu0 in range(len(fund_v[i0_v])):
                if fund_v[i0_v[enu0]] < low_freq_th or fund_v[i0_v[enu0]] > high_freq_th:  # freq to assign out of tracking range
                    continue

                for enu1 in range(len(fund_v[i1_v])):
                    if fund_v[i1_v[enu1]] < low_freq_th or fund_v[i1_v[enu1]] > high_freq_th:  # target freq out of tracking range
                        continue
                    if np.abs(fund_v[i0_v[enu0]] - fund_v[i1_v[enu1]]) >= freq_tolerance:  # freq difference too high
                        continue

                    a_error = np.sqrt(
                        np.sum([(sign_v[i0_v[enu0]][j] - sign_v[i1_v[enu1]][j]) ** 2 for j in range(n_channels)]))
                    if not np.isnan(tmp_ident_v[i0_v[enu0]]):
                        # i0 belongs to a temporal identity: predict that identity's frequency
                        # at the target's time from the fitted regression line
                        a = tmp_ident_regress['%.0f' % tmp_ident_v[i0_v[enu0]]][0]
                        b = tmp_ident_regress['%.0f' % tmp_ident_v[i0_v[enu0]]][1]
                        f_error = np.abs((a * idx_v[i1_v[enu1]] + b) - fund_v[i1_v[enu1]])
                    else:
                        f_error = np.abs(fund_v[i0_v[enu0]] - fund_v[i1_v[enu1]])
                    t_error = 1. * np.abs(idx_v[i0_v[enu0]] - idx_v[i1_v[enu1]]) / dps

                    error_matrix[enu0, enu1] = estimate_error(a_error, f_error, t_error, a_error_distribution,
                                                              f_error_distribution)
            error_cube.append(error_matrix)

        cube_app_idx += 1
    ident_v = clean_up(fund_v, ident_v, idx_v, times)

    return fund_v, ident_v, idx_v, sign_v, a_error_distribution, f_error_distribution, idx_of_origin_v
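
# Typical call (sketch; mirrors the commented-out call in the __main__ block below,
# assuming fund_v, idx_v, sign_v and times stem from a preceding powerspectrum analysis):
#
#     fund_v, ident_v, idx_v, sign_v, a_dist, f_dist, idx_of_origin_v = \
#         freq_tracking_v3(fund_v, idx_v, sign_v, times, freq_tolerance=40.,
#                          n_channels=np.shape(sign_v)[1])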

if __name__ == '__main__':

    fund_v = np.load('/home/raab/paper_create/raab2018_tracking_without_tagging/bsp_data/2016-04-10-22:14/fund_v.npy')
    idx_v = np.load('/home/raab/paper_create/raab2018_tracking_without_tagging/bsp_data/2016-04-10-22:14/idx_v.npy')
    times = np.load('/home/raab/paper_create/raab2018_tracking_without_tagging/bsp_data/2016-04-10-22:14/times.npy')
    sign_v = np.load('/home/raab/paper_create/raab2018_tracking_without_tagging/bsp_data/2016-04-10-22:14/sign_v.npy')
    # ident_v = np.load('/home/raab/paper_create/raab2018_tracking_without_tagging/bsp_data/2016-04-10-22:14/ident_v.npy')
    ident_v = np.load('/home/raab/paper_create/raab2018_tracking_without_tagging/bsp_data/2016-04-10-22:14/ident_v2.npy')

    # run the tracking itself (commented out; the ident_v loaded above already contains tracked identities):
    # fund_v, ident_v, idx_v, sign_v, a_error_distribution, f_error_distribution, idx_of_origin_v = freq_tracking_v3(fund_v, idx_v, sign_v, times, freq_tolerance=40., n_channels=np.shape(sign_v)[1])

    a_errors_true = []
    a_errors_false = []
    f_errors_true = []
    f_errors_false = []

    # collect error distributions for pairs of detections less than 0.5 s apart with a
    # frequency difference below 2 Hz, separated into pairs belonging to the same
    # identity (true) and to different identities (false)

    for i in tqdm(range(len(fund_v))):
        for j in range(i, len(fund_v)):
            if idx_v[i] == idx_v[j] or times[idx_v[j]] - times[idx_v[i]] > .5:
                continue
            if np.abs(fund_v[i] - fund_v[j]) > 2.:
                continue

            if ident_v[i] == ident_v[j]:
                f_errors_true.append(np.abs(fund_v[i] - fund_v[j]))
                a_errors_true.append(np.sqrt(np.sum([(sign_v[i][a] - sign_v[j][a]) ** 2 for a in range(len(sign_v[i]))])))
            else:
                f_errors_false.append(np.abs(fund_v[i] - fund_v[j]))
                a_errors_false.append(np.sqrt(np.sum([(sign_v[i][a] - sign_v[j][a]) ** 2 for a in range(len(sign_v[i]))])))



    # plot all detections in grey and overlay each tracked identity in a random color
    fig, ax = plt.subplots()
    ax.plot(times[idx_v], fund_v, 'o', color='grey', alpha=.5)
    for ident in np.unique(ident_v[~np.isnan(ident_v)]):
        ax.plot(times[idx_v[ident_v == ident]], fund_v[ident_v == ident], color=np.random.rand(3), marker='.')
    plt.show()
    embed()
    quit()