from ModelFit import get_best_fit
import numpy as np
import os
import pandas
import matplotlib.pyplot as plt
import statsmodels.api as sm


def main(folder="results/final_1/"):
    """Fit a Gamma GLM relating per-cell behaviour variables to the fit error.

    For every cell found in ``folder`` the behaviour variables and the fit
    error are collected via ``get_variables``. The error is then regressed on
    the behaviour variables (plus an intercept) with a Gamma-family GLM. The
    p-values and the full model summary are printed.

    Args:
        folder: Directory containing one entry per cell. Defaults to the
            previously hard-coded ``"results/final_1/"``.
    """
    # Column order of the behaviour matrix; also used as the DataFrame column names.
    variable_order = ['Burstiness', 'baseline_frequency', 'coefficient_of_variation',
                      'f_inf_slope', 'f_zero_slope', 'serial_correlation', 'vector_strength']
    # behaviour: len(cells) x len(variable_order) array; error: one value per cell.
    behaviour, error = get_variables(folder, variable_order)

    df_behaviour = pandas.DataFrame(behaviour, columns=variable_order)
    # statsmodels' GLM does not add an intercept on its own. Without this
    # column the model is forced through zero on the link scale.
    exog = sm.add_constant(df_behaviour)
    gamma_glm = sm.GLM(error, exog, family=sm.families.Gamma())
    fitted_model = gamma_glm.fit()

    print(fitted_model.pvalues)
    print(fitted_model.summary())


def get_variables(folder, order):
    """Collect behaviour variables and fit errors for every cell in ``folder``.

    Entries of ``folder`` are visited in sorted order. For each entry:

    - the best fit is loaded with ``get_best_fit``;
    - its error value is read;
    - the behaviour values named in ``order`` are picked out.

    Args:
        folder: Directory whose entries are the individual cells.
        order: Behaviour-variable names; selects the columns and their order.

    Returns:
        A tuple ``(behaviour, errors)``:

        - ``behaviour``: float64 array of shape ``(n_cells, len(order))``.
        - ``errors``: float64 array of shape ``(n_cells,)``.
    """
    behaviour_rows = []
    error_values = []
    for cell in sorted(os.listdir(folder)):
        # os.path.join works whether or not ``folder`` ends with a separator,
        # unlike plain string concatenation.
        fit = get_best_fit(os.path.join(folder, cell))
        error_values.append(fit.get_error_value())
        # Only the first value returned by get_behaviour_values() is used here.
        cell_behaviour, _ = fit.get_behaviour_values()
        behaviour_rows.append([cell_behaviour[name] for name in order])

    # reshape keeps the 2D (n_cells, len(order)) shape even when folder is empty.
    behaviour = np.array(behaviour_rows, dtype=np.float64).reshape(
        len(behaviour_rows), len(order))
    return behaviour, np.array(error_values, dtype=np.float64)


def till_shorthand():
    """Quick-reference notes for the statsmodels GLM workflow; does nothing.

    Calling this function has no effect (it returns ``None``). The notes::

        logit_GLM = sm.GLM(df[<result>], df[<params>], family=sm.families.Binomial())
        fitted_model = logit_GLM.fit()

        fitted_model.predict()
        # --> predicted values for df[<result>] based on the fitted model

        fitted_model.params
        # --> coefficients of the params (ordered as in the input)

        fitted_model.pvalues
        # --> self-explanatory

        fitted_model.summary()
        # OR
        fitted_model.summary2()
        # --> overview of the fitted model

    Note: ``sm.GLM`` does not add an intercept by itself. Wrap the exog
    data in ``sm.add_constant(...)`` if one is wanted.
    """


if __name__ == "__main__":
    # Run the analysis only when executed as a script, not when imported.
    main()