from ModelFit import get_best_fit, ModelFit
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from Baseline import BaselineCellData


# Root output directory for everything this script exports. Some paths below are
# built by plain string concatenation (e.g. SAVE_DIR + cell + "/"), so keep the
# trailing slash.
SAVE_DIR = "results/lab_rotation/"


def main():
    """Script entry point.

    Uncomment the export step(s) to run. With every step commented out, as
    below, running the module does nothing.
    """
    # Folder with one fit-result sub-directory per cell; consumed by the
    # export steps below when they are enabled.
    res_folder = "results/final_2/"

    # Export steps (enable as needed):
    # save_model_parameters(res_folder)
    # save_cell_info(res_folder)
    # test_save_cell_info()


def save_model_parameters(res_folder):
    """Collect the best-fit model of every cell and write a parameter table.

    Cells are processed in sorted order. For each one, ``get_best_fit`` is
    called (with ``use_comparable_error=False``), and the cell name, the EOD
    frequency of its cell data and its final parameters are recorded. The table
    is written to ``SAVE_DIR/models.csv`` via ``save_csv``.

    Parameters
    ----------
    res_folder : str
        Folder containing one fit-result sub-directory per cell. Entries that
        are not directories (stray files) are skipped. A trailing slash is
        optional.
    """
    cells = []
    eod_freqs = []
    parameters = []

    for cell in sorted(os.listdir(res_folder)):
        # os.path.join works whether or not res_folder ends with a separator.
        cell_dir = os.path.join(res_folder, cell)
        if not os.path.isdir(cell_dir):
            continue

        model = get_best_fit(cell_dir, use_comparable_error=False)
        cells.append(cell)
        eod_freqs.append(model.get_cell_data().get_eod_frequency())
        parameters.append(model.get_final_parameters())

    # open(..., "w") inside save_csv fails if the output directory is missing.
    os.makedirs(SAVE_DIR, exist_ok=True)
    save_csv(os.path.join(SAVE_DIR, "models.csv"), cells, eod_freqs, parameters)


def test_save_cell_info():
    """Visually inspect the files written by ``save_cell_info``.

    For every sub-directory of SAVE_DIR (sorted), show:

    1. a plot of ``f_inf`` ('o') and ``f_zero`` ('+') against ``contrast``
       read from ``fi_curve_info.csv``;
    2. one inter-spike-interval histogram per
       ``baseline_spikes_trial_<n>.npy`` file, for n = 1, 2, ..., stopping at
       the first missing trial number.

    Each figure is shown (blocking) and then closed.
    """
    for entry in sorted(os.listdir(SAVE_DIR)):
        folder = "{}{}/".format(SAVE_DIR, entry)
        if not os.path.isdir(folder):
            continue

        fi_curve = pd.read_csv(folder + "fi_curve_info.csv")
        contrast = fi_curve["contrast"]
        plt.plot(contrast, fi_curve["f_inf"], 'o')
        plt.plot(contrast, fi_curve["f_zero"], '+')
        plt.show()
        plt.close()

        trial = 1
        while True:
            trial_path = folder + "baseline_spikes_trial_{}.npy".format(trial)
            if not os.path.exists(trial_path):
                break
            # Scaled by 1000 (presumably s -> ms, matching the 0-50 bin range).
            isis = np.diff(np.load(trial_path) * 1000)
            plt.hist(isis, bins=np.arange(0, 50, 0.1))
            plt.show()
            plt.close()
            trial += 1


def save_cell_info(res_folder):
    """Export per-cell f-I curve values and baseline spike trains.

    For every cell sub-directory of ``res_folder`` (sorted), the best fit is
    loaded via ``get_best_fit`` (``use_comparable_error=False``), and the
    following files are written to ``SAVE_DIR/<cell>/``:

    * ``fi_curve_info.csv`` with columns ``contrast``, ``f_inf`` and ``f_zero``
      (plus the unnamed DataFrame index column that ``to_csv`` writes);
    * ``baseline_spikes_trial_<n>.npy`` with 1-based ``n``, one file per entry
      returned by the cell data's ``get_base_spikes()``.

    Parameters
    ----------
    res_folder : str
        Folder containing one fit-result sub-directory per cell. Entries that
        are not directories are skipped. A trailing slash is optional.
    """
    for cell in sorted(os.listdir(res_folder)):
        cell_dir = os.path.join(res_folder, cell)
        if not os.path.isdir(cell_dir):
            continue

        fit = get_best_fit(cell_dir, use_comparable_error=False)

        save_path = os.path.join(SAVE_DIR, cell)
        # makedirs also creates SAVE_DIR itself when missing (os.mkdir would
        # raise), and exist_ok avoids the exists()/mkdir() race.
        os.makedirs(save_path, exist_ok=True)

        # fi-curve: one row per contrast
        cell_data = fit.get_cell_data()
        f_zeros = fit.get_cell_f_zero_values()
        f_infs = fit.get_cell_f_inf_values()
        contrasts = cell_data.get_fi_contrasts()

        data_array = np.array([contrasts, f_infs, f_zeros]).T
        fi_frame = pd.DataFrame(data_array, columns=["contrast", "f_inf", "f_zero"])
        fi_frame.to_csv(os.path.join(save_path, "fi_curve_info.csv"))

        # baseline spike trains: one .npy file per trial, numbered from 1
        for trial, spike_list in enumerate(cell_data.get_base_spikes(), start=1):
            trial_file = "baseline_spikes_trial_{}.npy".format(trial)
            np.save(os.path.join(save_path, trial_file), np.array(spike_list))


def save_csv(file, cells, eod_freqs, parameters):
    """Write one row per cell (name, EOD frequency, parameters) to a CSV file.

    Parameters
    ----------
    file : str
        Output path; an existing file is overwritten.
    cells : sequence
        Cell names, written as the ``cell`` column.
    eod_freqs : sequence of float
        EOD frequency per cell, written with two decimals as ``EODf``.
    parameters : sequence of dict
        Parameter dict per cell. The columns are the sorted keys of the first
        dict, and every dict must contain them. ``refractory_period`` and
        ``step_size`` are renamed to ``ref_period`` and ``deltat`` in the
        header only. If empty, only the ``cell,EODf`` header is written.

    Raises
    ------
    ValueError
        If ``cells``, ``eod_freqs`` and ``parameters`` differ in length
        (checked before the file is opened, so no partial file is left).
    """
    if not len(cells) == len(eod_freqs) == len(parameters):
        raise ValueError(
            "cells, eod_freqs and parameters must have equal lengths, got "
            "{}, {} and {}".format(len(cells), len(eod_freqs), len(parameters))
        )

    keys = sorted(parameters[0].keys()) if parameters else []
    header_names = {"refractory_period": "ref_period", "step_size": "deltat"}
    header = ["cell", "EODf"] + [header_names.get(k, k) for k in keys]

    with open(file, "w") as out:
        out.write(",".join(header) + "\n")
        for cell, eod_freq, params in zip(cells, eod_freqs, parameters):
            row = ["{}".format(cell), "{:.2f}".format(eod_freq)]
            row.extend("{}".format(params[k]) for k in keys)
            out.write(",".join(row) + "\n")


if __name__ == '__main__':
    main()