from CellData import icelldata_of_dir
from Baseline import BaselineCellData
from FiCurve import FICurveCellData
import os


def main():
    """Run the full overview, writing the table and plots below ``cells/``."""
    output_dir = "cells/"

    # Alternative entry point that only produces the plots:
    # plot_visualizations("cells/")
    full_overview("cells/master_table.csv", output_dir)


def move_rejected_cell_data():
    """Move every cell with a negative f_inf slope out of ``invivo_data/``.

    All cells are scanned first and the moves happen afterwards, so the
    directory is not modified while it is still being iterated. Rejected cell
    directories end up in ``rejected_cells/negative_slope_f_inf/``, which is
    created on demand.
    """
    # Manual resume knob: cells whose 1-based index is below this are skipped.
    jump_to = 0
    cell_list = []
    for count, d in enumerate(icelldata_of_dir("invivo_data/"), start=1):
        if count < jump_to:
            continue
        print(d.get_data_path())
        # NOTE(review): `base` is not used afterwards; the load is kept in
        # case it has side effects the FI-curve analysis relies on - confirm.
        base = BaselineCellData(d)
        base.load_values(d.get_data_path())
        ficurve = FICurveCellData(d, d.get_fi_contrasts(), d.get_data_path())

        if ficurve.get_f_inf_slope() < 0:
            print("negative f_inf slope")
            cell_list.append(os.path.abspath(d.get_data_path()))

    destination_dir = os.path.abspath("rejected_cells/negative_slope_f_inf/")
    for c in cell_list:
        if os.path.exists(c):
            print("Source: ", c)
            destination = os.path.join(destination_dir, os.path.basename(c))
            print("destination: ", destination)
            print()
            # os.rename does not create missing parent directories.
            os.makedirs(destination_dir, exist_ok=True)
            os.rename(c, destination)
    print("Number: " + str(len(cell_list)))


def plot_visualizations(folder_path):
    """Save the baseline and FI-curve plots of every cell in ``invivo_data/``.

    Each cell gets its own sub-directory of ``folder_path``, named after the
    last component of the cell's data path.

    :param folder_path: directory below which the per-cell plot folders are
        created; a trailing path separator is optional.
    """
    for cell_data in icelldata_of_dir("invivo_data/"):
        name = os.path.split(cell_data.get_data_path())[-1]
        print(name)

        # The empty last component keeps the trailing separator on save_path.
        save_path = os.path.join(folder_path, name, "")
        os.makedirs(save_path, exist_ok=True)

        baseline = BaselineCellData(cell_data)
        baseline.plot_baseline(save_path)
        baseline.plot_serial_correlation(10, save_path)
        baseline.plot_polar_vector_strength(save_path)
        baseline.plot_interspike_interval_histogram(save_path)

        ficurve = FICurveCellData(cell_data, cell_data.get_fi_contrasts())
        ficurve.plot_fi_curve(save_path)


def full_overview(save_path_table, folder_path):
    """Write a summary table of all cells and save their overview plots.

    For every cell in ``invivo_data/`` one comma-separated row is written to
    ``save_path_table`` with the baseline and FI-curve values, and the
    baseline and FI-curve plots are saved into a per-cell sub-directory of
    ``folder_path``.

    :param save_path_table: path of the table file; it is opened in write
        mode, so an existing file is overwritten.
    :param folder_path: directory below which the per-cell plot folders are
        created; a trailing path separator is optional.
    """
    with open(save_path_table, "w") as table:
        table.write("Name, Path, Baseline Frequency Hz,Vector Strength, serial correlation lag=1,"
                    " serial correlation lag=2, burstiness, coefficient of variation,"
                    " fi-curve inf slope, fi-curve zero slope at straight, contrast at fi-curve zero straight\n")

        # add contrasts, f-inf values, f_zero_values
        # Manual resume knob: cells whose 1-based index is below this are skipped.
        first_cell = 0
        for count, cell_data in enumerate(icelldata_of_dir("invivo_data/"), start=1):
            if count < first_cell:
                continue
            save_dir = cell_data.get_data_path()
            name = os.path.split(save_dir)[-1]

            baseline = BaselineCellData(cell_data)
            # Store freshly computed baseline values if none could be loaded.
            if not baseline.load_values(save_dir):
                baseline.save_values(save_dir)

            fields = [name, save_dir]
            fields.append("{:.1f}".format(baseline.get_baseline_frequency()))
            fields.append("{:.2f}".format(baseline.get_vector_strength()))
            sc = baseline.get_serial_correlation(2)
            fields.append("{:.2f}".format(sc[0]))
            fields.append("{:.2f}".format(sc[1]))
            fields.append("{:.2f}".format(baseline.get_burstiness()))
            fields.append("{:.2f}".format(baseline.get_coefficient_of_variation()))

            ficurve = FICurveCellData(cell_data, cell_data.get_fi_contrasts(), save_dir)

            fields.append("{:.2f}".format(ficurve.get_f_inf_slope()))
            fields.append("{:.2f}".format(ficurve.get_f_zero_fit_slope_at_straight()))
            # Last header column ("contrast at fi-curve zero straight").
            fields.append("{:.2f}".format(ficurve.f_zero_fit[3]))

            # One row per cell: the newline is added once, after the last
            # column, so all eleven values stay on the same line.
            table.write(",".join(fields) + "\n")

            print(name)
            # The empty last component keeps the trailing separator on save_path.
            save_path = os.path.join(folder_path, name, "")
            os.makedirs(save_path, exist_ok=True)

            baseline.plot_baseline(save_path)
            baseline.plot_serial_correlation(10, save_path)
            baseline.plot_polar_vector_strength(save_path)
            baseline.plot_interspike_interval_histogram(save_path)

            ficurve.plot_fi_curve(save_path)


# Only run the analysis when this file is executed as a script.
if __name__ == "__main__":
    main()