from models.LIFACnoise import LifacNoiseModel
from CellData import CellData, icelldata_of_dir
from FiCurve import FICurve
from AdaptionCurrent import Adaption
from stimuli.SinusAmplitudeModulation import SinusAmplitudeModulationStimulus
import helperFunctions as hF
import numpy as np
from scipy.optimize import curve_fit, minimize
import functions as fu
import time
import matplotlib.pyplot as plt


def main():
    """Fit the model to a hand-written set of target values and print the result.

    Swap the body for a call to ``run_test_with_fixed_model()`` to fit against
    markers generated by a reference model instead.
    """
    # Target markers the fitted model is asked to reproduce.
    eod_frequency = 700
    target_baseline = 1400
    target_serial_corr = [-0.3]
    target_vector_strength = 1
    contrast_values = [0.7, 0.8, 0.9, 1, 1.1, 1.2, 1.3]
    target_f_inf = [1370, 1380, 1390, 1400, 1410, 1420, 1430]
    target_f_inf_slope = 100

    # Adaption values that are held fixed during the fit.
    fixed_delta_a = 0.02
    fixed_tau_a = 0.01

    result = Fitter().fit_model_to_values(
        eod_frequency,
        target_baseline,
        target_serial_corr,
        target_vector_strength,
        contrast_values,
        target_f_inf,
        target_f_inf_slope,
        fixed_delta_a,
        fixed_tau_a,
    )
    _, fitted = result

    print("calculated parameters:", fitted, sep="\n")

def run_with_real_data():
    """Fit the model to every cell yielded by ``icelldata_of_dir("./data/")``.

    For each cell the optimisation result and the fitted model parameters are
    printed, followed by the wall-clock time the fit took.
    """
    for celldata in icelldata_of_dir("./data/"):
        start_time = time.time()

        # The constructor's only argument is the simulation step size; the
        # cell data has to be handed to fit_model_to_data(), which requires it.
        fitter = Fitter()
        fmin, parameters = fitter.fit_model_to_data(celldata)

        print(fmin)
        print(parameters)
        end_time = time.time()

        print('Fitting of cell took function took {:.3f} s'.format((end_time - start_time)))


def run_test_with_fixed_model():
    """Round-trip check: fit against markers produced by a known model.

    A reference ``LifacNoiseModel`` is built from fixed parameters, its
    baseline and f-infinity markers are computed, and a fresh ``Fitter`` is
    asked to reproduce them. Fitted and reference parameters are printed so
    they can be compared by eye.
    """
    a_tau = 10
    a_delta = 0.08

    # Key order is kept as-is because the dict is printed at the end.
    parameters = dict(
        mem_tau=5,
        delta_a=a_delta,
        input_scaling=100,
        v_offset=80,
        threshold=1,
        v_base=0,
        step_size=0.00005,
        tau_a=a_tau,
        a_zero=0,
        v_zero=0,
        noise_strength=0.5,
    )
    reference_model = LifacNoiseModel(parameters)

    eod_freq = 750
    modulation_freq = 10
    contrasts = np.arange(0.5, 1.51, 0.1)

    # Markers of the reference model that the fit should recover.
    baseline_markers = reference_model.calculate_baseline_markers(eod_freq)
    baseline_freq, vector_strength, serial_correlation = baseline_markers
    fi_markers = reference_model.calculate_fi_markers(contrasts, eod_freq, modulation_freq)
    f_infinities, f_infinities_slope = fi_markers

    _, fit_parameters = Fitter().fit_model_to_values(
        eod_freq,
        baseline_freq,
        serial_correlation,
        vector_strength,
        contrasts,
        f_infinities,
        f_infinities_slope,
        a_delta,
        a_tau,
    )

    print("calculated parameters:", fit_parameters, sep="\n")
    print("ref parameters:", parameters, sep="\n")


class Fitter:
    """Fit a ``LifacNoiseModel`` to baseline and f-infinity markers of a cell.

    Nelder-Mead optimises ``mem_tau``, ``noise_strength`` and
    ``input_scaling``. ``v_offset`` is re-fitted to the target baseline
    frequency on every cost evaluation, while ``tau_a`` and ``delta_a`` are
    held fixed at the values stored in ``a_tau`` / ``a_delta``.
    """

    def __init__(self, step_size=None):
        """Create the model to fit.

        :param step_size: value for the model's ``step_size`` parameter;
            ``None`` selects the default of 0.0005.
        """
        if step_size is None:
            step_size = 0.0005
        self.model = LifacNoiseModel({"step_size": step_size})

        self.fi_contrasts = []
        self.eod_freq = 0

        self.modulation_frequency = 10
        self.sc_max_lag = 1

        # expected values the model has to replicate
        self.baseline_freq = 0
        self.vector_strength = -1
        self.serial_correlation = []

        self.f_infinities = []
        self.f_infinities_slope = 0

        # fixed values needed to fit model
        self.a_tau = 0
        self.a_delta = 0

        # number of cost function evaluations so far
        self.counter = 0

    def calculate_needed_values_from_data(self, data: "CellData"):
        """Read all target markers and the fixed adaption values from ``data``.

        :param data: cell recording the model should be fitted to.
        """
        self.eod_freq = data.get_eod_frequency()

        self.baseline_freq = data.get_base_frequency()
        self.vector_strength = data.get_vector_strength()
        self.serial_correlation = data.get_serial_correlation(self.sc_max_lag)

        fi_curve = FICurve(data, contrast=True)
        self.fi_contrasts = fi_curve.stimulus_value
        self.f_infinities = fi_curve.f_infinities
        self.f_infinities_slope = fi_curve.get_f_infinity_slope()

        # delta_a is taken as the ratio of the two f-I curve slopes.
        f_zero_slope = fi_curve.get_fi_curve_slope_of_straight()
        self.a_delta = f_zero_slope / self.f_infinities_slope

        adaption = Adaption(data, fi_curve)
        self.a_tau = adaption.get_tau_real()

    @staticmethod
    def _relative_error(value, target):
        """Return the absolute deviation of ``value`` from ``target`` relative to ``target``."""
        return abs((value - target) / target)

    @staticmethod
    def _mean_relative_error(values, targets):
        """Return the mean of the pairwise relative errors of ``values`` w.r.t. ``targets``."""
        errors = [Fitter._relative_error(value, target) for value, target in zip(values, targets)]
        return sum(errors) / len(errors)

    def _apply_parameters(self, X, tau_a, delta_a):
        """Write a parameter set into the model, fit ``v_offset`` and return the baseline stimulus."""
        self.model.set_variable("mem_tau", X[0])
        self.model.set_variable("noise_strength", X[1])
        self.model.set_variable("input_scaling", X[2])
        self.model.set_variable("tau_a", tau_a)
        self.model.set_variable("delta_a", delta_a)

        # only eod with amplitude 1 and no modulation
        base_stimulus = SinusAmplitudeModulationStimulus(self.eod_freq, 0, 0)

        # minimize the difference in baseline_freq first by fitting v_offset
        v_offset = self.model.find_v_offset(self.baseline_freq, base_stimulus)
        self.model.set_variable("v_offset", v_offset)
        return base_stimulus

    # mem_tau, (threshold?), (v_offset), noise_strength, input_scaling
    def cost_function(self, X, tau_a=10, delta_a=3, error_scaling=()):
        """Return the summed relative error of the model for the parameters ``X``.

        :param X: sequence ``[mem_tau, noise_strength, input_scaling]``.
        :param tau_a: fixed value written to the model's ``tau_a``.
        :param delta_a: fixed value written to the model's ``delta_a``.
        :param error_scaling: currently unused; kept so existing callers
            keep working.
        :return: sum of the relative errors of baseline frequency, vector
            strength, serial correlation (first entry only), f-infinity slope
            and the mean relative f-infinity error.
        """
        base_stimulus = self._apply_parameters(X, tau_a, delta_a)

        _, spiketimes = self.model.simulate_fast(base_stimulus, 30)

        baseline_freq = hF.mean_freq_of_spiketimes_after_time_x(spiketimes, 5)

        # Spike times folded onto a single EOD period.
        eod_period = 1 / self.eod_freq
        relative_spiketimes = np.array([s % eod_period for s in spiketimes if s > 0])
        eod_durations = np.full((len(relative_spiketimes)), eod_period)
        vector_strength = hF.__vector_strength__(relative_spiketimes, eod_durations)
        serial_correlation = hF.calculate_serial_correlation(np.array(spiketimes), self.sc_max_lag)

        f_infinities = []
        for contrast in self.fi_contrasts:
            stimulus = SinusAmplitudeModulationStimulus(self.eod_freq, contrast, self.modulation_frequency)
            _, spiketimes = self.model.simulate_fast(stimulus, 1)

            if len(spiketimes) < 2:
                # Too few spikes to compute a frequency: count as silent.
                f_infinities.append(0)
            else:
                f_infinities.append(hF.mean_freq_of_spiketimes_after_time_x(spiketimes, 0.5))

        popt, _ = curve_fit(fu.line, self.fi_contrasts, f_infinities, maxfev=10000)
        f_infinities_slope = popt[0]

        error_bf = self._relative_error(baseline_freq, self.baseline_freq)
        error_vs = self._relative_error(vector_strength, self.vector_strength)
        error_sc = self._relative_error(serial_correlation[0], self.serial_correlation[0])
        error_f_inf_slope = self._relative_error(f_infinities_slope, self.f_infinities_slope)
        # Normalise by the target values, like every other error term. The
        # model value can legitimately be 0 (silent model, see above) and
        # must therefore not be used as the divisor.
        error_f_inf = self._mean_relative_error(f_infinities, self.f_infinities)

        self.counter += 1
        errors = [error_bf, error_vs, error_sc, error_f_inf_slope, error_f_inf]
        print("Cost function run times:", self.counter, "error sum:", sum(errors), errors)
        return sum(errors)

    def fit_model_to_data(self, data: "CellData"):
        """Extract the target markers from ``data`` and fit the model to them.

        :return: same as :meth:`fit_model`.
        """
        self.calculate_needed_values_from_data(data)
        return self.fit_model()

    def fit_model_to_values(self, eod_freq, baseline_freq, sc, vs, fi_contrasts, fi_inf_values, fi_inf_slope, a_delta, a_tau):
        """Store the given target markers and fixed adaption values, then fit.

        :param sc: serial correlations; only the first entry enters the cost.
        :param vs: target vector strength.
        :param fi_contrasts: contrasts at which ``fi_inf_values`` were measured.
        :return: same as :meth:`fit_model`.
        """
        self.eod_freq = eod_freq
        self.baseline_freq = baseline_freq
        self.serial_correlation = sc
        self.vector_strength = vs
        self.fi_contrasts = fi_contrasts
        self.f_infinities = fi_inf_values
        self.f_infinities_slope = fi_inf_slope
        self.a_delta = a_delta
        self.a_tau = a_tau

        return self.fit_model()

    def fit_model(self):
        """Run the Nelder-Mead fit on the stored target values.

        :return: tuple ``(fmin, parameters)`` with the scipy optimisation
            result and the model parameters after applying ``fmin.x``.
        """
        x0 = np.array([20, 15, 75])
        # One row per simplex vertex: [mem_tau, noise_strength, input_scaling].
        init_simplex = np.array([[2, 1, 10], [40, 100, 140], [20, 50, 70], [150, 1, 200]])
        fmin = minimize(fun=self.cost_function, x0=x0, args=(self.a_tau, self.a_delta), method="Nelder-Mead", options={"initial_simplex": init_simplex})

        # The model still holds the parameters of the last cost evaluation,
        # which is generally not the optimum; write the optimum back so the
        # returned parameters belong to fmin.x.
        self._apply_parameters(fmin.x, self.a_tau, self.a_delta)

        return fmin, self.model.get_parameters()


if __name__ == '__main__':
    main()