from FiCurve import FICurve
from CellData import CellData
from models.LIFACnoise import LifacNoiseModel
import helperFunctions as hF
import functions as fu
import numpy as np


# Directory of the recorded cell the model is compared against (loaded via CellData).
CELL_PATH = "data/2012-06-27-ah-invivo-1/"

# Alternative parameter set, from a fit focused on improving the model's f0 response:
# MODEL_PARAMETERS = {'step_size': 5e-05, 'mem_tau': 0.042542178602690675, 'v_base': 0, 'v_zero': 0, 'threshold': 1, 'v_offset': -111.328125, 'input_scaling': 388.9439738592248, 'delta_a': 0.05513136301255167, 'tau_a': 0.1017720885626184, 'a_zero': 2, 'noise_strength': 0.01740931732483443}

# Parameter set in use: fit with fixed adaptation, focused on the serial
# correlation and the f_inf slope.
MODEL_PARAMETERS = {
    'step_size': 5e-05,
    'mem_tau': 0.03547683648372142,
    'v_base': 0,
    'v_zero': 0,
    'threshold': 1,
    'v_offset': -43.75,
    'input_scaling': 162.97353975832954,
    'delta_a': 0.024625305808413798,
    'tau_a': 0.07632538029074364,
    'a_zero': 2,
    'noise_strength': 0.00408169739163286,
}

# Passed as `savepath` to FICurve.plot_fi_curve when plot_fi_curve(save=True).
SAVE_PATH = "results/test/"


def main():
    """Load the cell and build the model, then print and plot their comparison."""
    data = CellData(CELL_PATH)
    curve = FICurve(data)
    lifac_model = LifacNoiseModel(MODEL_PARAMETERS)

    # Text report first, then the f-I curve plot.
    for report in (print_characteristics, plot_fi_curve):
        report(data, curve, lifac_model)


def print_characteristics(cell_data: CellData, fi_curve: FICurve, model, sc_max_lag: int = 1):
    """Print the cell's characteristics ("expected") next to the model's ("current").

    Compared quantities: base frequency, serial correlation for lags
    1..sc_max_lag, vector strength, f_inf slope and values, f_zero slope and values.

    :param cell_data: recorded cell providing EOD frequency and baseline markers
    :param fi_curve: f-I curve of the cell; its ``stimulus_value`` contrasts are
        also the contrasts at which the model's f-I curve is computed
    :param model: model offering ``calculate_baseline_markers`` and ``calculate_fi_curve``
    :param sc_max_lag: number of serial-correlation lags to compare (default 1)
    """
    eod_freq = cell_data.get_eod_frequency()
    fi_contrasts = fi_curve.stimulus_value

    # Cell (expected) characteristics.
    cell_bf = cell_data.get_base_frequency()
    cell_sc = cell_data.get_serial_correlation(sc_max_lag)
    cell_vs = cell_data.get_vector_strength()

    cell_f_inf_slope = fi_curve.get_f_infinity_slope()
    cell_f_inf_values = fi_curve.f_infinities

    cell_f_zero_slope = fi_curve.get_fi_curve_slope_of_straight()
    cell_f_zero_values = fi_curve.f_zeros

    # Model (current) characteristics, evaluated at the cell's EOD frequency and contrasts.
    baseline_freq, vector_strength, serial_correlation = model.calculate_baseline_markers(eod_freq, sc_max_lag)

    # The model's baseline frequencies per contrast are not compared here.
    _, f_zeros, f_infinities = model.calculate_fi_curve(fi_contrasts, eod_freq)
    # The first parameter of the clipped-line fit is taken as the f_inf slope.
    f_infinities_slope = hF.fit_clipped_line(fi_contrasts, f_infinities)[0]

    f_zeros_fit = hF.fit_boltzmann(fi_contrasts, f_zeros)
    f_zero_slope = fu.full_boltzmann_straight_slope(f_zeros_fit[0], f_zeros_fit[1], f_zeros_fit[2], f_zeros_fit[3])

    print_value("Base frequency", cell_bf, baseline_freq)
    for lag_index in range(sc_max_lag):
        print_value("Serial correlation lag {}".format(lag_index + 1),
                    cell_sc[lag_index], serial_correlation[lag_index])
    print_value("Vector strength", cell_vs, vector_strength)
    print_value("f_inf slope", cell_f_inf_slope, f_infinities_slope)
    print_f_values("f_inf values", cell_f_inf_values, f_infinities)
    print_value("f_zero slope", cell_f_zero_slope, f_zero_slope)
    print_f_values("f_zero values", cell_f_zero_values, f_zeros)


def print_value(name, cell_value, model_value):
    """Print one cell value ("expected") next to the model value ("current").

    The last column is the relative error ``(model - cell) / |cell|`` as a
    plain fraction (0.1 means the model is 10 % above the cell value).

    Dividing by ``|cell|`` keeps the sign meaningful for negative reference
    values such as serial correlations: positive always means the model value
    is larger. A reference value of 0 has no relative error, so ``nan`` is
    printed instead of raising ZeroDivisionError.
    """
    if cell_value == 0:
        rel_error = float("nan")
    else:
        rel_error = (model_value - cell_value) / abs(cell_value)
    print("{} - expected: {:.2f}, current: {:.2f}, %-error: {:.3f}".format(name,
        cell_value, model_value, rel_error))


def print_f_values(name, cell_values, model_values):
    """Print cell and model f-values side by side with their mean relative error.

    Both value arrays are printed rounded to whole numbers. The mean error is
    the average of ``(model - cell) / |cell|``, as a plain fraction (0.1 means
    the model is 10 % too high on average), taken over comparable pairs only:

    * only the first ``min(len(cell_values), len(model_values))`` pairs are
      paired up, and the mean divides by the number of pairs actually compared;
    * pairs whose cell value is 0 are skipped, because a relative error is
      undefined there and would otherwise make the mean inf/nan;
    * if no pair is comparable (e.g. empty input), ``nan`` is printed.
    """
    cell_values = np.array(cell_values)
    model_values = np.array(model_values)

    n_pairs = min(len(cell_values), len(model_values))
    expected = cell_values[:n_pairs].astype(float)
    current = model_values[:n_pairs].astype(float)
    comparable = expected != 0

    if np.any(comparable):
        rel_errors = (current[comparable] - expected[comparable]) / np.abs(expected[comparable])
        mean_perc_error = float(np.mean(rel_errors))
    else:
        mean_perc_error = float("nan")

    print("{}:\nexpected:".format(name), np.around(cell_values), "\ncurrent: ", np.around(model_values),
          "\nmean %-error: {:.3f}".format(mean_perc_error))


def plot_fi_curve(cell_data, fi_curve, model, save=False):
    """Plot the cell's f-I curve with the model's values passed as comparison curves.

    The model's f-I curve is computed at ``cell_data.get_fi_contrasts()`` and
    ``cell_data.get_eod_frequency()``, then handed to ``fi_curve.plot_fi_curve``
    via its ``comp_f_baselines``, ``comp_f_zeros`` and ``comp_f_infs`` arguments.

    :param save: when True, ``SAVE_PATH`` is additionally passed as ``savepath``
    """
    contrasts = cell_data.get_fi_contrasts()
    eod_frequency = cell_data.get_eod_frequency()
    model_baselines, model_f_zeros, model_f_infs = model.calculate_fi_curve(contrasts, eod_frequency)

    plot_kwargs = dict(comp_f_baselines=model_baselines,
                       comp_f_zeros=model_f_zeros,
                       comp_f_infs=model_f_infs)
    if save:
        plot_kwargs["savepath"] = SAVE_PATH
    fi_curve.plot_fi_curve(**plot_kwargs)

if __name__ == "__main__":
    # Run the comparison only when executed as a script, not when imported.
    main()