from CellData import icelldata_of_dir, CellData
from DataParserFactory import DatParser
import numpy as np
import os
import matplotlib.pyplot as plt
import pyrelacs.DataLoader as Dl

# Folder for the cached .npy screening results of test_kraken_files(), relative
# to the working directory; the trailing slash lets file names be appended directly.
data_save_path = "test_routines/test_files/"
# True: test_kraken_files() re-parses the raw recordings and rewrites the cache.
# False: it works from the cached .npy files in data_save_path instead.
read = False


def main():
    """Script entry point: run the species census of the in-vivo cell directories.

    ``test_kraken_files`` is the other routine in this module; bind it here
    instead to screen the kraken recordings.
    """
    routine = test_for_species
    routine()


def test_for_species(directory="invivo_data/"):
    """Count the recorded cells in *directory* per fish species and print the tally.

    Every sub-directory of *directory* whose name does not contain
    ``"thresh"`` is treated as one cell and must hold an ``info.dat`` file.
    For each metadata entry loaded from it, the species is taken from the
    top-level ``"Species"`` key, or otherwise from
    ``metadata["Subject"]["Species"]``. A cell whose ``info.dat`` cannot be
    read (``OSError``) or has no species entry (``KeyError``) is reported
    under "errors" instead of aborting the whole scan.

    :param directory: folder containing one sub-directory per recorded cell
        (default ``"invivo_data/"``, relative to the working directory).
    """
    sorted_cells = {}
    error_cells = []
    for cell in os.listdir(directory):
        # NOTE(review): "thresh" entries are excluded -- presumably they are
        # not regular cell recordings; confirm.
        if "thresh" in cell:
            continue
        cell_path = os.path.join(directory, cell)
        if not os.path.isdir(cell_path):
            # Stray files next to the cell folders are not cells.
            continue
        info_file = os.path.join(cell_path, "info.dat")

        # Gather all species first so a failure halfway through the metadata
        # never leaves the cell half-registered.
        try:
            species_found = []
            for metadata in Dl.load(info_file):
                if "Species" in metadata:
                    species_found.append(metadata["Species"])
                else:
                    species_found.append(metadata["Subject"]["Species"])
        except (OSError, KeyError):
            error_cells.append(cell_path)
            continue

        for species in species_found:
            sorted_cells.setdefault(species, []).append(cell_path)

    print("Errors:", len(error_cells))
    for species, cells in sorted_cells.items():
        print("{}: {}".format(species, len(cells)))

    print()
    print("errors:")
    for cell in error_cells:
        print(cell)

def _object_array(items):
    """Return a 1-D object array whose i-th element is ``items[i]``, unchanged.

    The per-recording results are ragged (different numbers of baseline runs
    or contrasts, ``[]`` for failed recordings). ``np.array`` rejects such
    input with ``ValueError`` on NumPy >= 1.24. Even with ``dtype=object`` it
    would stack equal-length entries into a 2-D array. Filling the array
    element by element avoids both problems.
    """
    packed = np.empty(len(items), dtype=object)
    for index, item in enumerate(items):
        packed[index] = item
    return packed


def _count_intensities_with_trials(contrasts, min_trials):
    """Count the f-I curve rows whose trial number (column 1) is at least *min_trials*."""
    return sum(1 for row in contrasts if row[1] >= min_trials)


def _collect_kraken_files(directory, min_contrasts, min_trials, min_baseline):
    """Parse every recording directory inside *directory* and judge it.

    A recording is accepted when its longest baseline is at least
    *min_baseline* and at least *min_contrasts* f-I intensities have
    *min_trials* or more trials. Recordings for which ``DatParser`` raises
    ``RuntimeError`` are kept in the listing as rejected, with empty data.

    :return: four parallel lists: directory paths, baseline lengths,
        f-I contrasts and accepted flags.
    """
    files, baseline, ficurve, accepted = [], [], [], []
    for entry in os.listdir(directory):
        data_dir = os.path.join(directory, entry)
        if not os.path.isdir(data_dir):
            continue
        print(data_dir)
        try:
            parser = DatParser(data_dir)
            baseline_lengths = parser.get_baseline_length()
            contrasts = parser.get_fi_curve_contrasts()
        except RuntimeError:
            baseline_lengths, contrasts, good = [], [], False
        else:
            # default=0: a recording without any baseline simply fails the check.
            baseline_good = max(baseline_lengths, default=0) >= min_baseline
            fi_curve_good = _count_intensities_with_trials(contrasts, min_trials) >= min_contrasts
            good = baseline_good and fi_curve_good
        print("good" if good else "bad")
        files.append(data_dir)
        baseline.append(baseline_lengths)
        ficurve.append(contrasts)
        accepted.append(good)
    return files, baseline, ficurve, accepted


def _analyse_thresholds(baseline, ficurve, min_contrasts, min_trials, min_baseline):
    """Print how many cells meet the thresholds in the cached data and in ``data/``.

    The first number counts cached kraken recordings that meet all three
    thresholds. The second line reports the same check over the local test
    cells yielded by ``icelldata_of_dir("data/", False)``.
    """
    print("min_baseline: {:}, min_contrasts: {:}, min_trials: {:}".format(min_baseline, min_contrasts, min_trials))
    ints_with_enough_trials = []
    for lengths, contrasts in zip(baseline, ficurve):
        if len(lengths) == 0 or max(lengths) < min_baseline:
            continue
        if len(contrasts) == 0:
            continue
        ints_with_enough_trials.append(_count_intensities_with_trials(contrasts, min_trials))
    # Count directly, with no upper bound on the number of good intensities.
    print(sum(1 for n in ints_with_enough_trials if n >= min_contrasts))

    count = 0
    all_cells = 0
    for cell_data in icelldata_of_dir("data/", False):
        all_cells += 1
        if max(cell_data.get_baseline_length(), default=0) < min_baseline:
            continue
        contrasts = cell_data.get_fi_curve_contrasts_with_trial_number()
        if _count_intensities_with_trials(contrasts, min_trials) >= min_contrasts:
            count += 1

    print("Fullfilled by {:} of {:} test cells".format(count, all_cells))


def test_kraken_files(analyse_thresholds=False, min_contrasts=7, min_trials=7, min_baseline=30):
    """Screen the in-vivo recordings for cells with enough baseline and f-I curve data.

    With the module flag ``read`` set, every sub-directory of
    ``/mnt/invivo_data/`` is parsed and judged. The results are cached as
    ``files``, ``baseline``, ``ficurve`` and ``accepted`` ``.npy`` files in
    ``data_save_path``, and the number of accepted recordings is printed.

    Otherwise the cached results are loaded and the number of accepted
    recordings is printed. Their paths are written, one per line, to
    ``test_routines/data_files.txt``. The threshold report on the cached and
    local test cells runs only when *analyse_thresholds* is True.

    :param analyse_thresholds: also print the threshold report after writing the file list.
    :param min_contrasts: minimum number of f-I intensities with enough trials.
    :param min_trials: minimum number of trials for an intensity to count.
    :param min_baseline: minimum length of the longest baseline recording.
    """
    if read:
        files, baseline, ficurve, accepted = _collect_kraken_files(
            "/mnt/invivo_data/", min_contrasts, min_trials, min_baseline)
        np.save(os.path.join(data_save_path, "files"), np.array(files))
        np.save(os.path.join(data_save_path, "baseline"), _object_array(baseline))
        np.save(os.path.join(data_save_path, "ficurve"), _object_array(ficurve))
        np.save(os.path.join(data_save_path, "accepted"), np.array(accepted))
        print("Total good:", sum(accepted))
        return

    # allow_pickle is required for the object arrays; the cache is written by
    # this script itself, so it is trusted input.
    files = np.load(os.path.join(data_save_path, "files.npy"), allow_pickle=True)
    baseline = np.load(os.path.join(data_save_path, "baseline.npy"), allow_pickle=True)
    ficurve = np.load(os.path.join(data_save_path, "ficurve.npy"), allow_pickle=True)
    accepted = np.load(os.path.join(data_save_path, "accepted.npy"), allow_pickle=True)

    print(np.sum(accepted))
    with open("test_routines/data_files.txt", "w") as out:
        for path, good in zip(files, accepted):
            if good:
                out.write(path + "\n")

    if analyse_thresholds:
        _analyse_thresholds(baseline, ficurve, min_contrasts, min_trials, min_baseline)


if __name__ == '__main__':
    # Run the routine selected in main() only when executed as a script, not on import.
    main()