from parser.CellData import CellData
import pyrelacs.DataLoader as Dl
from thunderfish.eventdetection import detect_peaks

import os
import numpy as np
import matplotlib.pyplot as plt

# When True, main() recomputes and overwrites the cached matching results
# (fi_traces_contrasts.npy / fi_traces_contrasts_similarity.npy) even if they exist.
TEST_SIMILARITY: bool = True
# When True, main() redetects spikes and overwrites redetected_spikes.npy /
# fi_time_v1_traces.npy even if they exist.
REDETECT_SPIKES: bool = True

# Percentiles and scale factor used by test_detected_spiketimes() to build its
# "last 600 ms" detection threshold.
TOP_PERCENTILE: int = 95
BOTTOM_PERCENTILE: int = 5
FACTOR: float = 0.5

# Strange cells:

# 2012-07-12-ap-invivo-1 # cell with a few traces with max similarity < 0.1
# 2012-12-13-af-invivo-1 # cell with MANY traces with max similarity < 0.1
# 2012-12-21-ak-invivo-1 # a few
# 2012-12-21-an-invivo-1 # a few
# 2013-02-21-ae-invivo-1 # a few
# 2013-02-21-ag-invivo-1 # a few

# 2014-06-06-ac-invivo-1 # a lot below 0.4, but with a good margin above the second-highest similarity


def _as_object_array(items):
    """Pack *items* into a 1-D object array with exactly one element per item.

    ``np.array`` on a list of arrays builds a multi-dimensional array when all
    lengths happen to match and raises for ragged input on NumPy >= 1.24.
    Filling a pre-allocated object array element by element keeps one entry per
    trace in both cases, so ``np.save``/``np.load(allow_pickle=True)`` round-trip
    the per-trace structure.
    """
    packed = np.empty(len(items), dtype=object)
    for k, item in enumerate(items):
        packed[k] = item
    return packed


def main():
    """Redetect FI-curve spikes for every cell and match them to the stored spike trains.

    For each cell directory in ``data/final/``:

    1. Spikes are redetected in the FICurve voltage traces using the cell's
       percentile pair from ``data/fi_thresholds.tsv``. Alternatively, the
       cached ``redetected_spikes.npy`` / ``fi_time_v1_traces.npy`` are loaded
       when they exist and ``REDETECT_SPIKES`` is False.
    2. Every redetected spike train is matched against the spike trains from
       ``CellData.get_fi_spiketimes``. The best contrast index and the
       (best, second best) similarity are saved per trace. Alternatively, the
       cached results are loaded when they exist and ``TEST_SIMILARITY`` is
       False.

    NOTE: the function currently only runs :func:`test_fi_trace` and returns
    early; the pipeline below runs once that early return is removed.
    """
    test_fi_trace()
    return  # The pipeline below is currently disabled.

    # find_and_save_best_threshold()
    # return

    directory = "data/final/"
    skip_to = False
    skip_to_cell = "2012-12-13-af-invivo-1"
    threshold_file_path = "data/fi_thresholds.tsv"
    thresholds_dict = load_fi_thresholds(threshold_file_path)

    for cell in sorted(os.listdir(directory)):
        if skip_to:
            # Resume a long run: ignore every cell before skip_to_cell.
            if cell != skip_to_cell:
                continue
            skip_to = False

        cell_dir = directory + cell  # e.g. "data/final/2012-04-20-af-invivo-1"
        print(cell_dir)

        cell_data = CellData(cell_dir)
        before = cell_data.get_delay()
        after = cell_data.get_after_stimulus_duration()
        step = cell_data.get_sampling_interval()

        spikes_file = cell_dir + "/redetected_spikes.npy"
        traces_file = cell_dir + "/fi_time_v1_traces.npy"
        if os.path.exists(spikes_file) and not REDETECT_SPIKES:
            spikes = np.load(spikes_file, allow_pickle=True)
            traces = np.load(traces_file, allow_pickle=True)
        else:
            if cell not in thresholds_dict:
                raise KeyError("No FI threshold pair for cell {} in {}; run find_and_save_best_threshold() first."
                               .format(cell, threshold_file_path))
            spikes, traces = get_redetected_spikes(cell_dir, before, after, step, thresholds_dict[cell])
            # Traces (and spike trains) are ragged across trials, so save them as object arrays.
            np.save(spikes_file, _as_object_array(spikes), allow_pickle=True)
            np.save(traces_file, _as_object_array(traces), allow_pickle=True)
            print("redetection finished")

        contrasts_file = cell_dir + "/fi_traces_contrasts.npy"
        similarity_file = cell_dir + "/fi_traces_contrasts_similarity.npy"
        if os.path.exists(contrasts_file) and not TEST_SIMILARITY:
            trace_contrasts = np.load(contrasts_file, allow_pickle=True)
            trace_max_similarity = np.load(similarity_file, allow_pickle=True)
        else:
            cell_spiketrains = cell_data.get_fi_spiketimes()

            # -1 marks traces that have not been matched.
            trace_contrasts = np.full(len(traces), -1, dtype=int)
            trace_max_similarity = np.full((len(traces), 2), -1.0)
            for i, spiketrain in enumerate(spikes):
                _, max_idx, maxima = find_matching_spiketrain(spiketrain, cell_spiketrains, step)
                trace_contrasts[i] = max_idx[0]   # contrast index of the best match
                trace_max_similarity[i] = maxima  # (best, second best) similarity

            np.save(contrasts_file, trace_contrasts, allow_pickle=True)
            np.save(similarity_file, trace_max_similarity, allow_pickle=True)
            print("similarity test finished")


def test_fi_trace(data_dir="data/final/"):
    """Summarise how well the redetected FI traces matched the stored spike trains.

    For every cell directory in *data_dir*, this loads the cached results written
    by :func:`main`:
    ``fi_time_v1_traces.npy``, ``fi_traces_contrasts.npy`` and
    ``fi_traces_contrasts_similarity.npy``.

    A trace counts as "good" when its best similarity beats the second best by
    more than 0.15; otherwise it counts as "bad". For cells with at least one bad
    trace, the good-trace counts per contrast are evaluated and the cell is
    printed when fewer than seven contrasts have >= 7 good trials.

    Args:
        data_dir: directory containing one sub-directory per cell.

    Returns:
        ``(full_count, contrast_trials_below_three)`` — the two counts that are
        also printed at the end.
    """
    min_margin = 0.15  # best match must beat the runner-up by more than this

    full_count = 0
    contrast_trials_below_three = 0
    # Relative margin (best - second) / best per trace; collected for ad-hoc inspection.
    differences_max_second_max = []
    for cell in sorted(os.listdir(data_dir)):
        cell_dir = os.path.join(data_dir, cell)
        traces = np.load(os.path.join(cell_dir, "fi_time_v1_traces.npy"), allow_pickle=True)
        trace_contrasts = np.load(os.path.join(cell_dir, "fi_traces_contrasts.npy"), allow_pickle=True)
        trace_max_similarity = np.load(os.path.join(cell_dir, "fi_traces_contrasts_similarity.npy"),
                                       allow_pickle=True)

        count_good = 0
        count_bad = 0
        contrast_trials = {}  # contrast index -> number of good traces
        for i in range(len(traces)):
            best, second = trace_max_similarity[i][0], trace_max_similarity[i][1]
            # The relative margin is undefined when nothing matched (best == 0).
            differences_max_second_max.append((best - second) / best if best > 0 else np.nan)

            # Report borderline traces just above the acceptance margin.
            if second + min_margin < best < second + 0.2:
                print("max sim: {:.2f}, {:.2f}".format(best, second))

            if best > second + min_margin:
                count_good += 1
                contrast = trace_contrasts[i]
                contrast_trials[contrast] = contrast_trials.get(contrast, 0) + 1
            else:
                count_bad += 1

        if count_bad > 0:
            over_seven = sum(1 for trials in contrast_trials.values() if trials >= 7)
            below_three = sum(1 for trials in contrast_trials.values() if trials < 3)

            if over_seven < 7:
                full_count += 1
                print(cell)
                print(contrast_trials)
                print("good:", count_good, "bad:", count_bad)
            if below_three > 1:
                contrast_trials_below_three += 1

    print("Cells less than 7 trials in seven contrasts:", full_count)
    print("Cells less than 3 trials in a contrast:", contrast_trials_below_three)
    return full_count, contrast_trials_below_three


def get_ith_trace(cell_dir, i, before=0.2, after=0.8):
    """Return the *i*-th (0-based) FICurve trace of a cell that has no preduration.

    Traces whose metadata reports a non-zero ``preduration`` are skipped and do
    not count towards *i*.

    Args:
        cell_dir: directory of the cell's recording.
        i: index among the FICurve traces without preduration.
        before: data before the stimulus, passed to ``Dl.iload_traces``.
        after: data after the stimulus, passed to ``Dl.iload_traces``.

    Returns:
        ``(time, v1, eod, local_eod, stimulus)``: the time axis and the first
        four channels of the trace, in that order.

    Raises:
        IndexError: if the cell has fewer than ``i + 1`` traces without preduration.
    """
    remaining = i
    for info, _key, time, x in Dl.iload_traces(cell_dir, repro="FICurve", before=before, after=after):
        # Depending on the recording, the preduration lives in different metadata
        # sections. The value carries a two-character unit suffix that is cut off
        # before parsing (presumably "ms" — verify against the data files).
        if '----- Control --------------------------------------------------------' in info[0].keys():
            pre_duration = float(
                info[0]["----- Pre-Intensities ------------------------------------------------"]["preduration"][:-2])
        elif "preduration" in info[0].keys():
            pre_duration = float(info[0]["preduration"][:-2])
        elif len(info) == 2 and "preduration" in info[1].keys():
            pre_duration = float(info[1]["preduration"][:-2])
        else:
            pre_duration = 0
        if pre_duration != 0:
            continue

        if remaining > 0:
            remaining -= 1
            continue

        v1, eod, local_eod, stimulus = x[0], x[1], x[2], x[3]
        return time, v1, eod, local_eod, stimulus

    raise IndexError("{} has fewer than {} FICurve traces without preduration.".format(cell_dir, i + 1))


def load_fi_thresholds(threshold_file_path):
    """Load the per-cell spike-detection percentile pairs.

    The file is tab separated with one cell per line:
    ``<cell name>\\t<bottom percentile>\\t<top percentile>``, optionally followed
    by a trailing tab (which is what find_and_save_best_threshold writes).
    Blank or whitespace-only lines are ignored.

    Args:
        threshold_file_path: path of the TSV file.

    Returns:
        dict mapping cell name -> ``[bottom_percentile, top_percentile]`` as
        floats. The dict is empty if the file does not exist.
    """
    thresholds_dict = {}
    if not os.path.exists(threshold_file_path):
        return thresholds_dict

    with open(threshold_file_path, "r", encoding="utf-8") as threshold_file:
        for line in threshold_file:
            stripped = line.strip()
            if not stripped:
                continue  # e.g. an empty line left by a manual edit
            fields = stripped.split("\t")
            thresholds_dict[fields[0]] = [float(fields[1]), float(fields[2])]

    return thresholds_dict


def _print_threshold_pairs(threshold_pairs, colors, extra_pair=None):
    """Print the selectable threshold pairs with their index and plot colour."""
    print("Thresholds are:\n ")
    for idx, (pair, color) in enumerate(zip(threshold_pairs, colors)):
        print("{}: {} - {}".format(idx, color, pair))
    if extra_pair is not None:
        # plot_test_thresholds draws the extra pair in black; any index
        # >= len(threshold_pairs) selects it.
        print("{}: {} - {}".format(len(threshold_pairs), "black", extra_pair))


def _choose_threshold_pair(example_trace_pairs, threshold_pairs, colors, step_size, standard_top_percentile):
    """Interactively choose a ``(bottom, top)`` percentile pair for one cell.

    The user can enter:

    * a bottom percentile (``"45"``) or a pair (``"45, 97"``) to preview it as
      an extra pair;
    * ``"ok <index>"`` to accept one of the listed pairs;
    * ``"stop"`` to abort.

    Returns:
        The accepted ``(bottom, top)`` tuple, or None when the user typed "stop".
    """
    prompt = "Choose: 'ok', 'stop', or a number (bottom threshold 0-100)"
    retry_prompt = "Choose: 'ok', 'stop', or a number (bottom threshold 0-100) or two numbers: bot, top"
    extra_pair = None  # local to this cell; never carried over from a previous one

    _print_threshold_pairs(threshold_pairs, colors)
    plot_test_thresholds(example_trace_pairs, threshold_pairs, colors, step_size)
    response = input(prompt)

    while True:
        response = response.strip()
        if response == "stop":
            return None

        if response.lower().startswith("ok"):
            parts = response.split()
            if len(parts) == 1:
                print("please specify an index:")
                response = input(prompt)
                continue
            try:
                idx = int(parts[1])
            except ValueError:
                idx = -1
            if 0 <= idx < len(threshold_pairs):
                return tuple(threshold_pairs[idx])
            if idx >= len(threshold_pairs) and extra_pair is not None:
                return extra_pair
            print("{} could not be parsed as number or ok please try again.".format(response))
            _print_threshold_pairs(threshold_pairs, colors, extra_pair)
            response = input(prompt)
            continue  # evaluate the new answer from the top

        parts = response.split(",")
        try:
            if len(parts) == 1:
                candidate = (float(parts[0]), standard_top_percentile)
            elif len(parts) == 2:
                candidate = (float(parts[0]), float(parts[1]))
            else:
                raise ValueError(response)
            # np.percentile rejects values outside 0-100, and bottom >= top gives
            # a non-positive detection threshold.
            if not 0 <= candidate[0] < candidate[1] <= 100:
                raise ValueError(response)
        except ValueError:
            print("{} could not be parsed as number or ok please try again.".format(response))
            _print_threshold_pairs(threshold_pairs, colors, extra_pair)
            response = input(retry_prompt)
            continue

        extra_pair = candidate
        plot_test_thresholds(example_trace_pairs, threshold_pairs, colors, step_size, extra_pair=extra_pair)
        _print_threshold_pairs(threshold_pairs, colors, extra_pair)
        response = input(prompt)


def _save_fi_thresholds(threshold_file_path, thresholds_dict):
    """Write *thresholds_dict* in the tab-separated format read by load_fi_thresholds."""
    with open(threshold_file_path, "w", encoding="utf-8") as threshold_file:
        for name in sorted(thresholds_dict):
            # The trailing tab keeps the format of existing threshold files.
            threshold_file.write("{}\t{}\t{}\t\n".format(name, thresholds_dict[name][0], thresholds_dict[name][1]))


def find_and_save_best_threshold(base_path="data/final/", threshold_file_path="data/fi_thresholds.tsv",
                                 re_choose_thresholds=False):
    """Interactively choose the spike-detection percentile pair for each cell.

    For every cell in *base_path* that still needs a threshold, one example
    trace per matched contrast is plotted with the candidate thresholds, and the
    user picks a pair. A cell needs a threshold when it has none yet, when its
    stored bottom percentile is below 10, or when *re_choose_thresholds* is set.
    All choices are written to *threshold_file_path*, including when the user
    types "stop" or an exception aborts the loop.

    Args:
        base_path: directory with one sub-directory per cell.
        threshold_file_path: TSV file read by load_fi_thresholds and rewritten here.
        re_choose_thresholds: offer every cell again, even if already done.
    """
    thresholds_dict = load_fi_thresholds(threshold_file_path)

    def needs_threshold(item):
        if "thresholds" in item:
            return False
        if re_choose_thresholds or item not in thresholds_dict:
            return True
        # Cells with a very low stored bottom percentile are offered again.
        return thresholds_dict[item][0] < 10

    items = sorted(os.listdir(base_path))
    print("cells to do:", sum(1 for item in items if needs_threshold(item)))

    # Starting assumptions shown for every cell.
    standard_top_percentile = 95
    threshold_pairs = [(40, 95), (50, 95), (60, 95)]
    colors = ["blue", "orange", "red"]

    try:
        for item in items:
            if "thresholds" in item:
                continue
            if not needs_threshold(item):
                print("Already done:", item)
                continue

            print(item)
            cell_dir = os.path.join(base_path, item)
            step_size = CellData(cell_dir).get_sampling_interval()

            trace_pairs = np.load(os.path.join(cell_dir, "fi_time_v1_traces.npy"), allow_pickle=True)
            trace_contrasts = np.load(os.path.join(cell_dir, "fi_traces_contrasts.npy"), allow_pickle=True)

            # First trace of each matched contrast, ordered by contrast index.
            examples = {}
            for contrast, trace_pair in zip(trace_contrasts, trace_pairs):
                examples.setdefault(contrast, trace_pair)
            example_trace_pairs = [examples[contrast] for contrast in sorted(examples)]
            if not example_trace_pairs:
                print("No FI traces found, skipping", item)
                continue

            chosen = _choose_threshold_pair(example_trace_pairs, threshold_pairs, colors, step_size,
                                            standard_top_percentile)
            if chosen is None:
                break
            thresholds_dict[item] = [chosen[0], chosen[1]]
    finally:
        _save_fi_thresholds(threshold_file_path, thresholds_dict)


def plot_test_thresholds(trace_pairs, threshold_pairs, colors, step_size, extra_pair=None):
    """Plot FI traces together with the spikes found by several percentile thresholds.

    There is one subplot per ``(time, v1)`` trace, in a grid with at most four
    rows. For each ``(bottom, top)`` percentile pair:

    * the detection threshold is ``percentile(top) - percentile(bottom)`` of
      the last 600 ms of the trace;
    * it is drawn as a horizontal line at ``median + threshold``;
    * the peaks found by ``detect_peaks`` are drawn as an event row above the
      trace.

    *extra_pair*, if given, is drawn first and in black; the other pairs use
    *colors*. An empty *trace_pairs* plots nothing.
    """
    if len(trace_pairs) == 0:
        return

    ncols = int(np.ceil(len(trace_pairs) / 4))
    nrows = int(np.ceil(len(trace_pairs) / ncols))
    # squeeze=False keeps `axes` 2-D even for a single row or column of subplots.
    fig, axes = plt.subplots(nrows, ncols, sharex="all", squeeze=False, figsize=(12, 12))

    pairs_and_colors = list(zip(threshold_pairs, colors))
    if extra_pair is not None:
        pairs_and_colors.insert(0, (extra_pair, "black"))

    last_600_ms = int(np.rint(0.6 / step_size))

    for i, (time, v1) in enumerate(trace_pairs):
        ax = axes[i // ncols, i % ncols]
        v1_max = np.max(v1)
        v1_median = np.median(v1)

        ax.plot(time, v1)
        ax.plot((time[0], time[-1]), (v1_median, v1_median), color="black")

        v1_part = v1[-last_600_ms:]
        for line_offset, ((bot_perc, top_perc), color) in enumerate(pairs_and_colors):
            threshold = np.percentile(v1_part, top_perc) - np.percentile(v1_part, bot_perc)
            ax.plot((time[0], time[-1]), (v1_median + threshold, v1_median + threshold), color=color)
            peaks, _ = detect_peaks(v1, threshold=threshold)
            ax.eventplot([time[idx] for idx in peaks], colors=color, lineoffsets=v1_max + line_offset)

    plt.show()
    plt.close(fig)


def test_detected_spiketimes(traces, redetected, spiketimes, step):
    """Plot one FI trace with its redetected spikes and the reference spike trains.

    Event rows above the trace, from bottom to top:

    * *redetected* (red);
    * the peaks found with the module-level percentile threshold of the last
      600 ms (green);
    * each reference train in *spiketimes* (black).

    Horizontal lines mark the median (black), the median plus the full-trace
    threshold (black) and the median plus the last-600-ms threshold (grey).
    Both thresholds are also printed.
    """
    time, v1 = traces[0], traces[1]
    top_of_trace = max(v1)

    plt.plot(time, v1)
    plt.eventplot(redetected, colors="red", lineoffsets=top_of_trace + 1)

    baseline = np.median(v1)
    tail = v1[-int(np.rint(0.6 / step)):]
    # NOTE(review): operator precedence makes this top - (bottom * FACTOR),
    # not (top - bottom) * FACTOR — confirm which one is intended.
    threshold_last_600 = np.percentile(tail, TOP_PERCENTILE) - np.percentile(tail, BOTTOM_PERCENTILE) * FACTOR
    threshold_normal = np.percentile(v1, 94.5) - np.percentile(v1, 50)
    print("threshold full time  : {:.2f}".format(threshold_normal))
    print("threshold last 600 ms: {:.2f}".format(threshold_last_600))

    peak_indices, _ = detect_peaks(v1, threshold=threshold_last_600)
    plt.eventplot([time[k] for k in peak_indices], colors="green", lineoffsets=top_of_trace + 2)

    x_span = (time[0], time[-1])
    horizontal_lines = (
        (baseline, "black"),
        (baseline + threshold_normal, "black"),
        (baseline + threshold_last_600, "grey"),
    )
    for level, line_color in horizontal_lines:
        plt.plot(x_span, (level, level), color=line_color)

    for row, reference_train in enumerate(spiketimes):
        plt.eventplot(reference_train, colors="black", lineoffsets=top_of_trace + 3 + row)

    plt.show()
    plt.close()


def plot_percentiles(trace):
    """Show the percentile curve of *trace* from 0 % to 100 % in 0.5 % steps."""
    percent_steps = np.arange(0, 100.1, 0.5)
    plt.plot(percent_steps, np.percentile(trace, percent_steps))
    plt.show()
    plt.close()


def get_unsorted_spiketimes(fi_file):
    """Return the first column of every data block in *fi_file*, divided by 1000.

    One array per block, in file order. The division presumably converts spike
    times from ms to s — confirm against the file's column units.
    """
    return [block_data[:, 0] / 1000 for _metadata, _key, block_data in Dl.iload(fi_file)]


def find_matching_spiketrain(redetected, cell_spiketrains, step_size):
    """Find the stored spike train that best matches a redetected spike train.

    Spike times are binned to sample indices with ``rint(t / step_size)``. A
    stored spike counts as matched when a redetected spike lies within +-1
    sample of it. The similarity of a stored train is the fraction of its
    spikes that are matched; a stored train without spikes has similarity 0.

    Args:
        redetected: spike times of the redetected train.
        cell_spiketrains: nested sequence indexed as ``[contrast][trial]`` that
            yields the stored spike times.
        step_size: sampling interval used for binning, in the same time unit
            as the spike times.

    Returns:
        ``(similarity, max_idx, (maximum, second_max))``, where:

        * ``similarity`` is a ``(n_contrasts, max_trials)`` array, zero-padded
          for contrasts with fewer trials;
        * ``max_idx`` is the ``(contrast, trial)`` index of the first best
          match, or ``(-1, -1)`` if there are no stored trains;
        * ``maximum`` is the best similarity;
        * ``second_max`` is the second-largest entry of ``similarity``, or 0.0
          if it has fewer than two entries.
    """
    redetected_idx = np.rint(np.asarray(redetected, dtype=float) / step_size).astype(int)
    # Allow a jitter of one sample in either direction.
    tolerated_idx = np.concatenate((redetected_idx - 1, redetected_idx, redetected_idx + 1))

    n_trials = max((len(contrast_list) for contrast_list in cell_spiketrains), default=0)
    similarity = np.zeros((len(cell_spiketrains), n_trials))
    maximum = -1
    max_idx = (-1, -1)
    for i, contrast_list in enumerate(cell_spiketrains):
        for j, cell_spiketrain in enumerate(contrast_list):
            cell_idx = np.rint(np.asarray(cell_spiketrain, dtype=float) / step_size).astype(int)
            if cell_idx.size > 0:
                similarity[i, j] = np.count_nonzero(np.isin(cell_idx, tolerated_idx)) / cell_idx.size
            # else: a train without spikes cannot be matched; similarity stays 0.
            if similarity[i, j] > maximum:
                maximum = similarity[i, j]
                max_idx = (i, j)

    sorted_similarity = np.sort(similarity, axis=None)
    second_max = sorted_similarity[-2] if sorted_similarity.size >= 2 else 0.0
    if maximum < 0.5:
        print("Identification: max_sim: {:.2f} vs {:.2f} second max; Diff: {} worked".format(maximum, second_max, maximum - second_max))
    return similarity, max_idx, (maximum, second_max)


def _has_nonzero_preduration(info):
    """Return True if a FICurve trace's metadata reports a non-zero preduration.

    Depending on the recording, the preduration is stored in one of three places:

    * the "Pre-Intensities" section of ``info[0]``, when a "Control" section
      is present;
    * directly in ``info[0]``;
    * in ``info[1]``.

    The value carries a two-character unit suffix that is stripped before
    parsing (presumably "ms" — verify). Metadata without any preduration entry
    counts as zero.
    """
    if '----- Control --------------------------------------------------------' in info[0].keys():
        raw = info[0]["----- Pre-Intensities ------------------------------------------------"]["preduration"]
    elif "preduration" in info[0].keys():
        raw = info[0]["preduration"]
    elif len(info) == 2 and "preduration" in info[1].keys():
        raw = info[1]["preduration"]
    else:
        return False
    return float(raw[:-2]) != 0


def get_redetected_spikes(cell_dir, before, after, step, threshold_pair):
    """Redetect spikes in all FICurve traces of a cell that have no preduration.

    The detection threshold of each trace is ``percentile(top) -
    percentile(bottom)`` of its last ``after - 0.2`` seconds (about 600 ms for
    ``after == 0.8``), passed to ``detect_peaks``.

    Skipping and warnings:

    * traces with a non-zero preduration are skipped;
    * traces longer than 15 s are skipped with a message;
    * traces longer than 3 s produce a warning.

    Args:
        cell_dir: directory of the cell's recording.
        before: time before the stimulus, passed to ``Dl.iload_traces``.
        after: time after the stimulus, passed to ``Dl.iload_traces``; must be
            at least 0.8.
        step: sampling interval of the traces.
        threshold_pair: ``(bottom, top)`` percentiles for the threshold.

    Returns:
        ``(spikes, traces)``:

        * ``spikes`` is a 1-D object array holding one array of spike times
          per trace;
        * ``traces`` is a list of ``[time, v1]`` array pairs in the same order.

    Raises:
        ValueError: if *after* is shorter than 0.8.
    """
    if after < 0.8:
        print("Why the f is the after stimulus time shorter than 0.8s ???")
        raise ValueError("Safety error: check where the after stimulus time comes from.")

    # Number of samples at the end of each trace used to estimate the threshold.
    last_about_600_ms = int(np.rint((after - 0.2) / step))
    bottom_percentile, top_percentile = threshold_pair[0], threshold_pair[1]

    spikes_list = []
    traces = []
    for info, _key, time, x in Dl.iload_traces(cell_dir, repro="FICurve", before=before, after=after):
        if _has_nonzero_preduration(info):
            continue

        v1 = x[0]  # channel order: v1, eod, local_eod, stimulus
        if len(v1) > 15 / step:
            print("Skipping Fi-Curve trace longer than 15 seconds!")
            continue
        if len(v1) > 3 / step:
            print("Warning: A FI-Curve trace is longer than 3 seconds.")

        tail = v1[-last_about_600_ms:]
        threshold = np.percentile(tail, top_percentile) - np.percentile(tail, bottom_percentile)

        peaks, _ = detect_peaks(v1, threshold=threshold)
        spikes_list.append(np.array([time[idx] for idx in peaks]))
        traces.append([np.array(time), np.array(v1)])

    # Build the result element by element. np.array(spikes_list) would create a
    # 2-D array when all traces have the same spike count, and it raises for
    # ragged lists on NumPy >= 1.24.
    spikes = np.empty(len(spikes_list), dtype=object)
    for k, trace_spikes in enumerate(spikes_list):
        spikes[k] = trace_spikes
    return spikes, traces


if __name__ == "__main__":
    # Run the analysis only when executed as a script, not when imported.
    main()