from models.LIFACnoise import LifacNoiseModel
from Baseline import BaselineModel
from FiCurve import FICurveModel
import numpy as np
import matplotlib.pyplot as plt
import copy
import os


# A swept parameter ranges from value / SEARCH_WIDTH to value * SEARCH_WIDTH.
SEARCH_WIDTH = 1.1
# Number of steps the sweep range is split into (nominally SEARCH_PRECISION + 1 tested values).
SEARCH_PRECISION = 1
# Stimulus contrasts passed to FICurveModel: -0.4 ... 0.4 in steps of 0.1 (0.45 only bounds the range).
CONTRASTS = np.arange(-0.4, 0.45, 0.1)


def main():
    """Sweep several parameters of one model parameter set and plot their effects.

    Each parameter in ``parameters_to_test`` is varied with
    ``test_parameter_effect``. The collected results are plotted into
    ``./figures/variable_effect/``, which is created if it does not exist.
    """
    # Two alternative parameter sets (origin not shown here); select the one to
    # analyse below. Key order is kept exactly as in the original dicts.
    parameter_sets = {
        "set_1": {'threshold': 1,
                  'step_size': 5e-05,
                  'a_zero': 2,
                  'delta_a': 0.2032269898801589,
                  'mem_tau': 0.011314027210564803,
                  'noise_strength': 0.056724809998220195,
                  'v_zero': 0,
                  'v_base': 0,
                  'tau_a': 0.05958195972016753,
                  'input_scaling': 119.81500448274554,
                  'dend_tau': 0.0027746086464721723,
                  'v_offset': -24.21875},
        "set_2": {'v_offset': -15.234375, 'input_scaling': 64.94152780134829, 'step_size': 5e-05, 'a_zero': 2,
                  'threshold': 1, 'v_base': 0, 'delta_a': 0.04763179657857666, 'tau_a': 0.07891848949732623,
                  'mem_tau': 0.004828473985707999, 'noise_strength': 0.017132801387559883,
                  'v_zero': 0, 'dend_tau': 0.0015230454266819539},
    }
    model_parameters = parameter_sets["set_2"]

    parameters_to_test = ["input_scaling", "dend_tau", "mem_tau", "noise_strength", "v_offset", "delta_a", "tau_a"]
    effect_data = []
    for p in parameters_to_test:
        print("Working on parameter " + p)
        effect_data.append(test_parameter_effect(model_parameters, p))

    # savefig does not create missing directories, so make sure the target exists.
    save_dir = "./figures/variable_effect/"
    os.makedirs(save_dir, exist_ok=True)
    plot_effects(effect_data, save_dir)


def test_parameter_effect(model_parameters, test_parameter):
    """Sweep one model parameter around its current value and record the model's response.

    ``test_parameter`` is varied from ``value / SEARCH_WIDTH`` to
    ``value * SEARCH_WIDTH`` in ``SEARCH_PRECISION`` equal steps, i.e.
    ``SEARCH_PRECISION + 1`` values including both end points. For each value:

    - a ``LifacNoiseModel`` is built;
    - it is evaluated with an ``FICurveModel`` over ``CONTRASTS`` and with a
      ``BaselineModel`` (both called with ``600, trials=1``);
    - the f-point detection plots are saved to
      ``./figures/f_point_detection/<test_parameter>_<value>/``.

    :param model_parameters: dict of model parameters. It is deep-copied, so
        the caller's dict is left untouched.
    :param test_parameter: key of ``model_parameters`` to vary.
    :return: ParameterEffectData with the tested values. For each value it
        holds the f-I curve object, baseline frequency, vector strength,
        serial correlation (requested with argument 2), coefficient of
        variation, f_inf slope and frequencies, and f_zero fit slope (at
        stimulus 0.1) and frequencies.
    """
    model_parameters = copy.deepcopy(model_parameters)
    start_value = model_parameters[test_parameter]

    # linspace instead of arange(start, end + step, step): with a float step
    # arange can emit an extra value past `end` due to rounding, and it fails
    # for a start value of 0 (zero step). linspace always yields exactly
    # SEARCH_PRECISION + 1 values, both end points included.
    values = np.linspace(start_value / SEARCH_WIDTH, start_value * SEARCH_WIDTH, SEARCH_PRECISION + 1)

    fi_curves = []
    bf, vs, sc, cv = [], [], [], []
    f_inf_s, f_inf_v, f_zero_s, f_zero_v = [], [], [], []

    for value in values:
        model_parameters[test_parameter] = value
        model = LifacNoiseModel(model_parameters)

        fi_curve = FICurveModel(model, CONTRASTS, 600, trials=1)
        fi_curves.append(fi_curve)
        f_inf_s.append(fi_curve.get_f_inf_slope())
        f_inf_v.append(fi_curve.get_f_inf_frequencies())
        f_zero_s.append(fi_curve.get_f_zero_fit_slope_at_stimulus_value(0.1))
        f_zero_v.append(fi_curve.get_f_zero_frequencies())

        # makedirs also creates the parent ./figures/f_point_detection/ if needed.
        detection_save_path = "./figures/f_point_detection/{}_{:.4f}/".format(test_parameter, value)
        os.makedirs(detection_save_path, exist_ok=True)
        fi_curve.plot_f_point_detections(detection_save_path)

        baseline = BaselineModel(model, 600, trials=1)
        bf.append(baseline.get_baseline_frequency())
        vs.append(baseline.get_vector_strength())
        sc.append(baseline.get_serial_correlation(2))
        cv.append(baseline.get_coefficient_of_variation())

    return ParameterEffectData(fi_curves, list(values), test_parameter, bf, vs, sc, cv,
                               f_inf_s, f_inf_v, f_zero_s, f_zero_v)


def plot_effects(par_effect_data_list, save_path=None):
    """Plot every measured characteristic against the swept values of every tested parameter.

    The figure is a grid with one row per characteristic in ``names`` and one
    column per ParameterEffectData in ``par_effect_data_list``. Rows share the
    y-label and columns share the x-axis.

    :param par_effect_data_list: non-empty sequence of ParameterEffectData,
        one per tested parameter.
    :param save_path: directory in which "variable_effect_master_plot.png" is
        saved; it is created if missing. If None, the figure is shown instead.
    :raises ValueError: if ``par_effect_data_list`` is empty.
    """
    if len(par_effect_data_list) == 0:
        raise ValueError("par_effect_data_list must contain at least one entry")

    names = ("bf", "vs", "sc", "cv", "f_inf_s", "f_inf_v", "f_zero_s", "f_zero_v", "f_zero_fit_x_0")
    n_rows = len(names)
    n_cols = len(par_effect_data_list)

    # squeeze=False keeps `axes` 2-D even for a single tested parameter, so
    # axes[i, j] works for any column count. The height scales with the number
    # of rows (characteristics), not with the number of columns.
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(32, 4 * n_rows), sharex="col", squeeze=False)

    for j, ped in enumerate(par_effect_data_list):
        ranges = ((0, max(ped.get_data("bf")) * 1.1), (0, 1), (-1, 1), (0, 1),
                  (0, max(ped.get_data("f_inf_s")) * 1.1), (0, 800),
                  (0, max(ped.get_data("f_zero_s")) * 1.1), (0, 10000), (-0.5, max(ped.get_data("f_zero_fit_x_0"))))

        for i, (name, y_range) in enumerate(zip(names, ranges)):
            ax = axes[i, j]
            ax.plot(ped.values, ped.get_data(name))

            if name == "f_zero_v":
                ax.set_yscale('log')
                # A log axis cannot start at 0 (matplotlib warns and ignores such
                # a bottom limit), so leave the bottom on auto and fix only the top.
                bottom = y_range[0] if y_range[0] > 0 else None
                ax.set_ylim(bottom, y_range[1])
            else:
                ax.set_ylim(y_range)

            if j == 0:
                ax.set_ylabel(name)

            if i == 0:
                ax.set_title(ped.test_parameter)

    plt.tight_layout()
    if save_path is not None:
        if save_path:
            os.makedirs(save_path, exist_ok=True)
        plt.savefig(os.path.join(save_path, "variable_effect_master_plot.png"))
    else:
        plt.show()
    plt.close(fig)


class ParameterEffectData:
    """Container for the measurements collected while sweeping one model parameter.

    Each per-value sequence (``fi_curves``, ``bf``, ``vs``, ``sc``, ``cv``,
    ``f_inf_s``, ``f_inf_v``, ``f_zero_s``, ``f_zero_v``) is expected to be
    aligned with ``values``: entry k belongs to ``values[k]``.
    """

    # Every name accepted by get_data().
    data_names = ("bf", "vs", "sc", "cv", "f_inf_s", "f_inf_v", "f_zero_s", "f_zero_v", "f_zero_fit_x_0")

    def __init__(self, fi_curves, values, test_parameter, bf, vs, sc, cv, f_inf_s, f_inf_v, f_zero_s, f_zero_v):
        self.fi_curves = fi_curves
        self.values = values
        self.test_parameter = test_parameter
        self.bf = bf
        self.vs = vs
        self.sc = sc
        self.cv = cv
        self.f_inf_s = f_inf_s
        self.f_inf_v = f_inf_v
        self.f_zero_s = f_zero_s
        self.f_zero_v = f_zero_v

    def get_data(self, name):
        """Return the per-value data series called ``name``.

        :param name: one of ``data_names``. ``"f_zero_fit_x_0"`` is derived
            from the stored f-I curves; every other name returns the stored
            attribute of the same name (the list itself, not a copy).
        :raises ValueError: if ``name`` is not in ``data_names``.
        """
        if name == "f_zero_fit_x_0":
            # Element 3 of each curve's `f_zero_fit` is the x_0 value.
            return [fi.f_zero_fit[3] for fi in self.fi_curves]
        if name in self.data_names:
            return getattr(self, name)
        raise ValueError("Unknown attribute name!")


# Run the parameter sweep only when executed as a script, not when imported.
if __name__ == '__main__':
    main()