from CellData import CellData, icelldata_of_dir
from os import listdir
import os


def main():
    """Script entry point: recalculate the baseline spike times of all cells.

    To interactively (re)choose the per-cell detection thresholds instead,
    call ``choose_thresholds()`` here.
    """
    precalculate_baseline_spiketimes()


def precalculate_baseline_spiketimes(threshold_file_path="data/thresholds.tsv", base_path="data/final/"):
    """Recalculate the baseline spikes of every cell using its stored detection parameters.

    Reads the tab-separated threshold file written by ``choose_thresholds``
    (rows of ``name<TAB>threshold[<TAB>min_length<TAB>step_size]``). For every
    cell yielded by ``icelldata_of_dir(base_path, test_for_v1_trace=False)``,
    it calls ``get_base_spikes(..., re_calculate=True)`` with that cell's
    parameters. Cells without a row in the file are reported and skipped.

    Args:
        threshold_file_path: Path of the thresholds TSV file. If it does not
            exist, every cell is reported as missing.
        base_path: Directory handed to ``icelldata_of_dir``.
    """
    # choose_thresholds starts from these values. It can also write rows that
    # contain only a threshold, so such rows are completed with the same
    # defaults instead of raising IndexError.
    default_min_length = 5000
    default_step_size = 1000

    thresholds_dict = {}
    if os.path.exists(threshold_file_path):
        with open(threshold_file_path, "r") as threshold_file:
            for raw_line in threshold_file:
                fields = raw_line.strip().split('\t')
                if len(fields) < 2:
                    # Blank line (e.g. a trailing newline at the end of the file).
                    continue
                name = fields[0]
                thresh = float(fields[1])
                min_length = int(fields[2]) if len(fields) > 2 else default_min_length
                step_size = int(fields[3]) if len(fields) > 3 else default_step_size
                thresholds_dict[name] = (thresh, min_length, step_size)

    for cell_data in icelldata_of_dir(base_path, test_for_v1_trace=False):
        name = os.path.basename(cell_data.get_data_path())

        params = thresholds_dict.get(name)
        if params is None:
            print("key missing: {}".format(name))
            continue

        # NOTE(review): choose_thresholds treats a stored threshold of 99 as
        # "not chosen yet"; such cells are still processed here with
        # threshold 99. Confirm that is intended.
        thresh, min_length, split_step_size = params
        cell_data.get_base_spikes(threshold=thresh, min_length=min_length, split_step=split_step_size,
                                  re_calculate=True)


# Prompt shown while interactively tuning one cell's spike-detection parameters.
_THRESHOLD_PROMPT = ("Choose: 'ok', 'stop', or a number (threshold) or three numbers "
                     "(threshold, minlength, step_size) separated with commas")
# Report of the parameters the spikes were last computed with.
_PARAMETER_REPORT = "Threshold was {:.2f}, Min Length was {:.0f}, Split step size was {:.0f}"

# Starting assumptions for a cell whose parameters are being chosen. They are
# also used to complete threshold-file rows that lack a min length / step size.
_DEFAULT_THRESHOLD = 2.5
_DEFAULT_MIN_SPLIT_LENGTH = 5000
_DEFAULT_SPLIT_STEP_SIZE = 1000

# A stored threshold equal to this value means the cell is offered for choosing again.
_UNCHOSEN_THRESHOLD = 99


def _load_threshold_table(threshold_file_path):
    """Read the thresholds TSV into ``{name: [thresh, min_length, step_size]}``.

    Blank lines are ignored. Rows that lack min length and/or step size are
    completed with the defaults. A missing file yields an empty dict.
    """
    thresholds = {}
    if not os.path.exists(threshold_file_path):
        return thresholds

    with open(threshold_file_path, "r") as threshold_file:
        for raw_line in threshold_file:
            fields = raw_line.strip().split('\t')
            if len(fields) < 2:
                continue  # blank line
            name = fields[0]
            thresh = float(fields[1])
            if len(fields) > 2:
                print("Already done:", name)
            else:
                print("Already done (no len/step): ", name)
            min_length = int(fields[2]) if len(fields) > 2 else _DEFAULT_MIN_SPLIT_LENGTH
            step_size = int(fields[3]) if len(fields) > 3 else _DEFAULT_SPLIT_STEP_SIZE
            thresholds[name] = [thresh, min_length, step_size]
    return thresholds


def _save_threshold_table(threshold_file_path, thresholds_dict):
    """Write one ``name<TAB>thresh<TAB>min_length<TAB>step_size`` row per entry, sorted by name."""
    with open(threshold_file_path, "w") as threshold_file:
        for name in sorted(thresholds_dict):
            thresh, min_length, step_size = thresholds_dict[name]
            threshold_file.write("{}\t{}\t{}\t{}\n".format(name, thresh, min_length, step_size))


def _needs_choosing(item, thresholds_dict, re_choose_thresholds):
    """Return True if directory entry ``item`` should be offered to the user."""
    if "thresholds" in item:
        return False
    entry = thresholds_dict.get(item)
    return re_choose_thresholds or entry is None or entry[0] == _UNCHOSEN_THRESHOLD


def _parse_parameters(response, min_split_length, split_step_size):
    """Parse ``"thresh"`` or ``"thresh,min_length,step_size"``.

    A single number only replaces the threshold; the current min length and
    step size are kept. Returns the new (thresh, min_length, step_size)
    triple, or None if the response is malformed (wrong number of fields or
    a non-numeric field).
    """
    parts = response.split(",")
    try:
        if len(parts) == 1:
            return float(parts[0]), min_split_length, split_step_size
        if len(parts) == 3:
            return float(parts[0]), int(parts[1]), int(parts[2])
    except ValueError:
        pass
    return None


def _choose_parameters_interactively(data, thresh, min_split_length, split_step_size):
    """Let the user tune one cell's spike-detection parameters.

    Spikes are recomputed (``only_first=True``) each time valid parameters are
    entered. Returns the accepted (thresh, min_split_length, split_step_size),
    or None if the user typed 'stop'.
    """
    data.get_base_spikes(thresh, min_length=min_split_length, split_step=split_step_size, re_calculate=True,
                         only_first=True)
    print(_PARAMETER_REPORT.format(thresh, min_split_length, split_step_size))

    while True:
        response = input(_THRESHOLD_PROMPT).strip()
        if response == "ok":
            return thresh, min_split_length, split_step_size
        if response == "stop":
            return None

        parsed = _parse_parameters(response, min_split_length, split_step_size)
        if parsed is None:
            # Keep the previous parameters untouched on malformed input.
            print("{} could not be parsed as number or ok please try again.".format(response))
        else:
            thresh, min_split_length, split_step_size = parsed
            data.get_base_spikes(thresh, min_length=min_split_length, split_step=split_step_size,
                                 re_calculate=True, only_first=True)
        print(_PARAMETER_REPORT.format(thresh, min_split_length, split_step_size))


def choose_thresholds(base_path="data/final/", threshold_file_path="data/thresholds.tsv",
                      re_choose_thresholds=False):
    """Interactively choose spike-detection parameters for the cells in ``base_path``.

    A directory entry is offered to the user when any of these holds:
    ``re_choose_thresholds`` is true, it has no row in the threshold file, or
    its stored threshold is 99. Entries whose name contains "thresholds" are
    always ignored. For each offered cell that has a V1 trace, the baseline
    spikes are computed. The user can accept them ('ok'), abort ('stop'), or
    enter a new threshold or ``threshold,min_length,step_size``.

    At the end the threshold file is rewritten with one
    ``name<TAB>threshold<TAB>min_length<TAB>step_size`` row per known cell,
    sorted by name. Typing 'stop' ends the session but still saves what was
    chosen so far.

    Args:
        base_path: Directory whose entries are the cells to process.
        threshold_file_path: Tab-separated file read at start and rewritten at the end.
        re_choose_thresholds: If true, offer every cell again, even those already chosen.
    """
    thresholds_dict = _load_threshold_table(threshold_file_path)

    # The same predicate drives both the count and the processing loop, so they agree.
    todo = [item for item in sorted(listdir(base_path))
            if _needs_choosing(item, thresholds_dict, re_choose_thresholds)]
    print("cells to do:", len(todo))

    for item in todo:
        print(item)
        item_path = os.path.join(base_path, item)
        data = CellData(item_path)

        trace = data.get_base_traces(trace_type=data.V1)
        if len(trace) == 0:
            print("NO V1 TRACE FOUND: ", item_path)
            continue

        chosen = _choose_parameters_interactively(data, _DEFAULT_THRESHOLD, _DEFAULT_MIN_SPLIT_LENGTH,
                                                  _DEFAULT_SPLIT_STEP_SIZE)
        if chosen is None:  # user typed 'stop'
            break
        thresholds_dict[item] = list(chosen)

    _save_threshold_table(threshold_file_path, thresholds_dict)


# Run the pipeline only when executed as a script, not when imported.
if __name__ == "__main__":
    main()