from parser.CellData import icelldata_of_dir, CellData
import numpy as np
import os
import pyrelacs.DataLoader as Dl


def main():
    """Entry point: print the fish metadata summary of the dataset.

    ``eod_info`` and ``fi_recording_times`` are alternative reports that are
    currently not run; add them to ``reports`` to enable them.
    """
    reports = (fish_info,)
    for report in reports:
        report()


def eod_info(data_dir="data/final/"):
    """Print summary statistics of the EOD frequency over all cells in ``data_dir``.

    Every entry of ``data_dir`` (in sorted order) is opened with ``CellData``.
    Cells whose ``get_eod_frequency()`` is NaN are printed by path and
    excluded from the statistics.

    Args:
        data_dir: directory whose entries are cell recording directories.
            Defaults to the previously hard-coded ``"data/final/"``.
    """
    cells = [os.path.join(data_dir, item) for item in sorted(os.listdir(data_dir))]

    eod_freq = []
    for cell in cells:
        eod_f = CellData(cell).get_eod_frequency()
        if np.isnan(eod_f):
            # Report cells without a usable EOD frequency so they can be inspected.
            print(cell)
            continue
        eod_freq.append(eod_f)

    # min()/max() raise ValueError on an empty sequence; report that case instead.
    if not eod_freq:
        print("eod Freq: no valid values found in {}".format(data_dir))
        return

    print("eod Freq: min {}, max {}, mean: {:.2f}, std: {:.2f}".format(
        min(eod_freq), max(eod_freq), np.mean(eod_freq), np.std(eod_freq)))

def _print_value_range(label, values):
    """Print ``"<label>: min <min>, max <max>"`` for ``values``, or a notice if empty.

    ``min``/``max`` raise ``ValueError`` on an empty sequence, so the empty case
    is reported explicitly instead of aborting the whole summary.
    """
    if values:
        print("{}: min {}, max {}".format(label, min(values), max(values)))
    else:
        print("{}: no values".format(label))


def fish_info(data_dir="data/final/"):
    """Print a summary of fish and preparation metadata for all cells in ``data_dir``.

    For every cell directory, ``info.dat`` is loaded with ``Dl.load``, and each
    returned metadata entry is parsed in one of two layouts:

    * flat: ``"CellType"``, ``"Weight"``, ``"Size"``, ``"EOD Frequency"`` and the
      preparation keys are all at the top level of the entry;
    * nested: the cell type is under ``["Cell"]["CellType"]``, the fish data is
      under ``["Subject"]`` (weight optional), and the EOD frequency is under
      ``"CellProperties"`` or ``"Cell properties"`` (either may be absent).
      Preparation details are taken from ``["Preparation"]`` as-is.

    Cells whose type is not ``"P-unit"`` are reported but still included.
    Afterwards the size and weight ranges, EOD frequency statistics, and the
    distinct anaesthesia/immobilization settings are printed.

    NOTE(review): the summary indexes every preparation entry by
    ``'Anaesthetic'``, ``'AnaestheticDose'``, ``'LocalAnaesthesia'`` and
    ``'Immobilization'``. This assumes nested ``"Preparation"`` entries carry the
    same keys; confirm this against the data.

    Args:
        data_dir: directory whose entries are cell recording directories.
            Defaults to the previously hard-coded ``"data/final/"``.
    """
    cells = [os.path.join(data_dir, item) for item in sorted(os.listdir(data_dir))]
    # Preparation keys copied from the flat layout, e.g. 'Anaesthetic': 'MS 222',
    # 'AnaestheticDose': '120mg/l', 'LocalAnaesthetic': 'Lidocaine',
    # 'Immobilization': 'Tubocurarin'.
    preparation_keys = ('LocalAnaesthesia', 'AnaestheticDose', 'Anaesthetic',
                        'LocalAnaesthetic', 'Anaesthesia', 'Immobilization')

    cell_type = []
    weight = []
    size = []
    eod_freq = []
    preparation = []
    for cell in cells:
        info_file = os.path.join(cell, "info.dat")
        for metadata in Dl.load(info_file):
            flat_layout = "CellType" in metadata
            cell_type.append(metadata["CellType"] if flat_layout else metadata["Cell"]["CellType"])
            if cell_type[-1] != "P-unit":
                print("not P-unit?", cell)

            # The slices strip trailing unit suffixes before float conversion:
            # one character for the weight, two for the size and EOD frequency
            # (presumably "g", "cm" and "Hz"; confirm against the info.dat files).
            if flat_layout:
                weight.append(float(metadata["Weight"][:-1]))
                size.append(float(metadata["Size"][:-2]))
                eod_freq.append(float(metadata["EOD Frequency"][:-2]))
                preparation.append({key: metadata[key] for key in preparation_keys})
            else:
                subject = metadata["Subject"]
                if "Weight" in subject:
                    weight.append(float(subject["Weight"][:-1]))
                size.append(float(subject["Size"][:-2]))
                # The section name differs between recordings; use the first one present.
                for properties_key in ("CellProperties", "Cell properties"):
                    if properties_key in metadata:
                        eod_freq.append(float(metadata[properties_key]["EOD Frequency"][:-2]))
                        break
                preparation.append(metadata["Preparation"])

    _print_value_range("Size", size)
    _print_value_range("weight", weight)
    if eod_freq:
        print("eod Freq: min {}, max {}, mean: {:.2f}, std: {:.2f}".format(
            min(eod_freq), max(eod_freq), np.mean(eod_freq), np.std(eod_freq)))
    else:
        print("eod Freq: no values")
    print("anaesthetics:", np.unique([x['Anaesthetic'] for x in preparation]))
    print("anaesthetic dosages:", np.unique([x['AnaestheticDose'] for x in preparation]))
    print("local anaesthetic:", np.unique([x['LocalAnaesthesia'] for x in preparation]))
    print("Immobilization:", np.unique([x['Immobilization'] for x in preparation]))


def fi_recording_times(data_dirs=("data/invivo/", "data/invivo_bursty/")):
    """Print the distinct FI-curve stimulus timing parameters over all cells.

    Cells are enumerated with ``icelldata_of_dir(..., test_for_v1_trace=False)``
    for each directory in ``data_dirs``. The values of
    ``get_recording_times()`` are stacked into an array whose first four columns
    are read as time_start, stimulus_start, stimulus_duration and
    after_stimulus_duration. (This assumes each call returns that 4-element
    sequence; confirm against ``CellData``.) Distinct values of each column are
    printed, plus the number of cells per distinct stimulus duration.

    Args:
        data_dirs: directories to scan. Defaults to the previously hard-coded
            in-vivo and bursty in-vivo data directories.
    """
    recording_times = []
    for data_dir in data_dirs:
        for cell_data in icelldata_of_dir(data_dir, test_for_v1_trace=False):
            # time_start, stimulus_start, stimulus_duration, after_stimulus_duration
            recording_times.append(cell_data.get_recording_times())

    # An empty list becomes a 1-D array, and the column slicing below would raise IndexError.
    if not recording_times:
        print("Fi-curve stimulus recording times: no cells found")
        return

    recording_times = np.array(recording_times)

    time_starts = recording_times[:, 0]
    stimulus_starts = recording_times[:, 1]
    stimulus_durations = recording_times[:, 2]
    after_durations = recording_times[:, 3]

    print("Fi-curve stimulus recording times:")
    print("time_starts:", np.unique(time_starts))
    print("stimulus_starts:", np.unique(stimulus_starts))
    unique_durations, duration_counts = np.unique(stimulus_durations, return_counts=True)
    print("stimulus_durations:", unique_durations)

    for d, count in zip(unique_durations, duration_counts):
        print("cells with stimulus duration {}: {}".format(d, count))

    print("after_durations:", np.unique(after_durations))


def sampling_intervals(data_dirs=("data/invivo/", "data/invivo_bursty/")):
    """Print the distinct sampling intervals of all cells in ``data_dirs``.

    Cells are enumerated with ``icelldata_of_dir(..., test_for_v1_trace=False)``
    for each directory, and ``get_sampling_interval()`` is collected from each.

    Args:
        data_dirs: directories to scan. Defaults to the previously hard-coded
            in-vivo and bursty in-vivo data directories.
    """
    intervals = [
        cell_data.get_sampling_interval()
        for data_dir in data_dirs
        for cell_data in icelldata_of_dir(data_dir, test_for_v1_trace=False)
    ]

    print(np.unique(intervals))


# Script entry point: run main() only when executed directly, not when imported.
if __name__ == '__main__':
    main()