from models.LIFACnoise import LifacNoiseModel
from stimuli.SinusoidalStepStimulus import SinusoidalStepStimulus
from CellData import CellData
from Baseline import get_baseline_class
from FiCurve import get_fi_curve_class
from AdaptionCurrent import Adaption
import numpy as np
from warnings import warn
from scipy.optimize import minimize
import time
from helperFunctions import plot_errors

import matplotlib.pyplot as plt


class Fitter:

    def __init__(self):

        self.base_model = LifacNoiseModel({"step_size": 0.00005})

        self.best_parameters_found = []
        self.smallest_error = np.inf

        #
        self.fi_contrasts = []
        self.recording_times = []
        self.eod_freq = 0
        self.data_sampling_interval = -1

        self.sc_max_lag = 2

        # values to be replicated:
        self.isi_bins = np.array(0)
        self.baseline_freq = 0
        self.vector_strength = -1
        self.serial_correlation = []
        self.coefficient_of_variation = 0
        self.burstiness = -1

        self.f_inf_values = []
        self.f_inf_slope = 0

        self.f_zero_values = []
        # self.f_zero_slopes = []
        self.f_zero_slope_at_straight = 0
        self.f_zero_straight_contrast = 0
        self.f_zero_fit = []
        self.f_zero_curve_contrast = 0
        self.f_zero_curve_contrast_idx = -1
        self.f_zero_curve_freq = np.array([])
        self.f_zero_curve_time = np.array([])

        self.errors = []

        # counts how often the cost_function was called
        self.counter = 0

    def set_data_reference_values(self, cell_data: CellData):
        self.eod_freq = cell_data.get_eod_frequency()
        self.data_sampling_interval = cell_data.get_sampling_interval()
        self.recording_times = cell_data.get_recording_times()

        data_baseline = get_baseline_class(cell_data)
        data_baseline.load_values(cell_data.get_data_path())
        self.baseline_freq = data_baseline.get_baseline_frequency()
        self.isi_bins = calculate_histogram_bins(data_baseline.get_interspike_intervals())
        # plt.close()
        # plt.plot(self.isi_bins)
        # plt.show()
        # plt.close()
        self.vector_strength = data_baseline.get_vector_strength()
        self.serial_correlation = data_baseline.get_serial_correlation(self.sc_max_lag)
        self.coefficient_of_variation = data_baseline.get_coefficient_of_variation()
        self.burstiness = data_baseline.get_burstiness()

        contrasts = np.array(cell_data.get_fi_contrasts())
        fi_curve = get_fi_curve_class(cell_data, contrasts, save_dir=cell_data.get_data_path())
        self.f_inf_slope = fi_curve.get_f_inf_slope()

        if self.f_inf_slope < 0:
            contrasts = contrasts * -1
            # print("old contrasts:", cell_data.get_fi_contrasts())
            # print("new contrasts:", contrasts)

            fi_curve = get_fi_curve_class(cell_data, contrasts, save_dir=cell_data.get_data_path())

        self.fi_contrasts = fi_curve.stimulus_values
        self.f_inf_values = fi_curve.f_inf_frequencies
        self.f_inf_slope = fi_curve.get_f_inf_slope()

        self.f_zero_values = fi_curve.f_zero_frequencies
        self.f_zero_fit = fi_curve.f_zero_fit
        # self.f_zero_slopes = [fi_curve.get_f_zero_fit_slope_at_stimulus_value(c) for c in self.fi_contrasts]
        self.f_zero_slope_at_straight = fi_curve.get_f_zero_fit_slope_at_straight()
        self.f_zero_slope_at_zero = fi_curve.get_f_zero_fit_slope_at_stimulus_value(0)
        self.f_zero_straight_contrast = self.f_zero_fit[3]

        max_contrast = max(contrasts)
        test_contrast = 0.5 * max_contrast
        diff_contrasts = np.abs(contrasts - test_contrast)

        self.f_zero_curve_contrast_idx = np.argmin(diff_contrasts)
        self.f_zero_curve_contrast = contrasts[self.f_zero_curve_contrast_idx]
        times, freqs = fi_curve.get_mean_time_and_freq_traces()
        self.f_zero_curve_freq = freqs[self.f_zero_curve_contrast_idx]
        self.f_zero_curve_time = times[self.f_zero_curve_contrast_idx]

        # around 1/3 of the value at straight
        # self.f_zero_slope = fi_curve.get_fi_curve_slope_at(fi_curve.get_f_zero_and_f_inf_intersection())

        # adaption = Adaption(fi_curve)
        # self.tau_a = adaption.get_tau_real()

    def fit_model_to_data(self, data: CellData, start_parameters, fit_routine_func: callable):
        self.set_data_reference_values(data)
        return fit_routine_func(start_parameters)

    def fit_routine(self, start_parameters, error_weights=None):
        self.counter = 0
        # fit only v_offset, mem_tau, input_scaling, dend_tau

        x0 = np.array([start_parameters["mem_tau"], start_parameters["noise_strength"],
                       start_parameters["input_scaling"], start_parameters["tau_a"], start_parameters["delta_a"],
                       start_parameters["dend_tau"], start_parameters["refractory_period"]])
        initial_simplex = create_init_simples(x0, search_scale=3)

        # error_list = [error_bf, error_vs, error_sc, error_cv, error_bursty,
        #               error_f_inf, error_f_inf_slope, error_f_zero, error_f_zero_slope_at_straight, error_f0_curve]

        fmin = minimize(fun=self.cost_function_all,
                        args=(error_weights,), x0=x0, method="Nelder-Mead",
                        options={"initial_simplex": initial_simplex, "xatol": 0.001, "maxfev": 600, "maxiter": 400})

        return fmin, self.base_model.get_parameters()

    def cost_function_all(self, X, error_weights=None):
        # tau mins:
        tau_min = 0.001
        for i in (0, 3, 5):
            if X[i] < tau_min:
                print("tried too small tau value")
                return 1000 + abs(X[i] - tau_min) * 10000

        for i in (1, 2, 4, 6):
            if X[i] < 0:
                print("tried negative parameter value")
                return 1000 + abs(X[i]) * 10000

        if X[6] > 1.05/self.eod_freq:  # refractory period shouldn't be larger than one eod period
            print("tried too large ref period")
            return 1000 + abs(X[6]) * 10000
        self.base_model.set_variable("mem_tau", X[0])
        self.base_model.set_variable("noise_strength", X[1])
        self.base_model.set_variable("input_scaling", X[2])
        self.base_model.set_variable("tau_a", X[3])
        self.base_model.set_variable("delta_a", X[4])
        self.base_model.set_variable("dend_tau", X[5])
        self.base_model.set_variable("refractory_period", X[6])

        base_stimulus = SinusoidalStepStimulus(self.eod_freq, 0)
        # find right v-offset
        test_model = self.base_model.get_model_copy()
        test_model.set_variable("noise_strength", 0)

        # time1 = time.time()
        v_offset = test_model.find_v_offset(self.baseline_freq, base_stimulus)
        self.base_model.set_variable("v_offset", v_offset)
        # time2 = time.time()
        # print("time taken for finding v_offset: {:.2f}s".format(time2-time1))

        error_list = self.calculate_errors(error_weights)
        # print("sum: {:.2f}, ".format(sum(error_list)))
        if sum(error_list) < self.smallest_error:
            self.smallest_error = sum(error_list)
            self.best_parameters_found = X
        return sum(error_list)

    def fit_routine_no_dend_tau(self, start_parameters, error_weights=None):
        self.counter = 0
        # fit all except dend_tau
        self.base_model.parameters["dend_tau"] = 0

        x0 = np.array([start_parameters["mem_tau"], start_parameters["noise_strength"],
                       start_parameters["input_scaling"], start_parameters["tau_a"],
                       start_parameters["delta_a"], start_parameters["refractory_period"]])
        initial_simplex = create_init_simples(x0, search_scale=3)

        # error_list = [error_bf, error_vs, error_sc, error_cv, error_bursty,
        #               error_f_inf, error_f_inf_slope, error_f_zero, error_f_zero_slope_at_straight, error_f0_curve]

        fmin = minimize(fun=self.cost_function_no_dend_tau,
                        args=(error_weights,), x0=x0, method="Nelder-Mead",
                        options={"initial_simplex": initial_simplex, "xatol": 0.001, "maxfev": 600, "maxiter": 400})

        return fmin, self.base_model.get_parameters()

    def cost_function_no_dend_tau(self, X, error_weights=None):
        # tau mins:
        tau_min = 0.001
        for i in (0, 3):
            if X[i] < tau_min:
                print("tried too small tau value")
                return 1000 + abs(X[i] - tau_min) * 10000

        for i in (1, 2, 4, 5):
            if X[i] < 0:
                print("tried negative parameter value")
                return 1000 + abs(X[i]) * 10000

        if X[5] > 1.05/self.eod_freq:  # refractory period shouldn't be larger than one eod period
            print("tried too large ref period")
            return 1000 + abs(X[5]) * 10000
        self.base_model.set_variable("mem_tau", X[0])
        self.base_model.set_variable("noise_strength", X[1])
        self.base_model.set_variable("input_scaling", X[2])
        self.base_model.set_variable("tau_a", X[3])
        self.base_model.set_variable("delta_a", X[4])
        self.base_model.set_variable("refractory_period", X[5])

        base_stimulus = SinusoidalStepStimulus(self.eod_freq, 0)
        # find right v-offset
        test_model = self.base_model.get_model_copy()
        test_model.set_variable("noise_strength", 0)

        # time1 = time.time()
        v_offset = test_model.find_v_offset(self.baseline_freq, base_stimulus)
        self.base_model.set_variable("v_offset", v_offset)



        # time2 = time.time()
        # print("time taken for finding v_offset: {:.2f}s".format(time2-time1))

        error_list = self.calculate_errors(error_weights)
        # print("sum: {:.2f}, ".format(sum(error_list)))
        if sum(error_list) < self.smallest_error:
            self.smallest_error = sum(error_list)
            self.best_parameters_found = X
        return sum(error_list)

    def fit_routine_no_ref_period(self, start_parameters, error_weights=None):
        self.counter = 0
        # fit all except ref_period

        self.base_model.set_variable("refractory_period", 0)

        x0 = np.array([start_parameters["mem_tau"], start_parameters["noise_strength"],
                       start_parameters["input_scaling"], start_parameters["tau_a"], start_parameters["delta_a"],
                       start_parameters["dend_tau"]])
        initial_simplex = create_init_simples(x0, search_scale=3)

        # error_list = [error_bf, error_vs, error_sc, error_cv, error_bursty,
        #               error_f_inf, error_f_inf_slope, error_f_zero, error_f_zero_slope_at_straight, error_f0_curve]

        fmin = minimize(fun=self.cost_function_no_ref_period,
                        args=(error_weights,), x0=x0, method="Nelder-Mead",
                        options={"initial_simplex": initial_simplex, "xatol": 0.001, "maxfev": 600, "maxiter": 400})

        return fmin, self.base_model.get_parameters()

    def cost_function_no_ref_period(self, X, error_weights=None):
        # tau mins:
        tau_min = 0.001
        for i in (0, 3, 5):
            if X[i] < tau_min:
                print("tried too small tau value")
                return 1000 + abs(X[i] - tau_min) * 10000

        for i in (1, 2, 4):
            if X[i] < 0:
                print("tried negative parameter value")
                return 1000 + abs(X[i]) * 10000

        self.base_model.set_variable("mem_tau", X[0])
        self.base_model.set_variable("noise_strength", X[1])
        self.base_model.set_variable("input_scaling", X[2])
        self.base_model.set_variable("tau_a", X[3])
        self.base_model.set_variable("delta_a", X[4])
        self.base_model.set_variable("dend_tau", X[5])

        base_stimulus = SinusoidalStepStimulus(self.eod_freq, 0)
        # find right v-offset
        test_model = self.base_model.get_model_copy()
        test_model.set_variable("noise_strength", 0)

        # time1 = time.time()
        v_offset = test_model.find_v_offset(self.baseline_freq, base_stimulus)
        self.base_model.set_variable("v_offset", v_offset)
        # time2 = time.time()
        # print("time taken for finding v_offset: {:.2f}s".format(time2-time1))

        error_list = self.calculate_errors(error_weights)
        # print("sum: {:.2f}, ".format(sum(error_list)))
        if sum(error_list) < self.smallest_error:
            self.smallest_error = sum(error_list)
            self.best_parameters_found = X
        return sum(error_list)

    def fit_routine_no_dend_tau_and_no_ref_period(self, start_parameters, error_weights=None):
        self.counter = 0
        # fit all except dend_tau and ref_period
        self.base_model.parameters["refractory_period"] = 0
        self.base_model.parameters["dend_tau"] = 0

        x0 = np.array([start_parameters["mem_tau"], start_parameters["noise_strength"],
                       start_parameters["input_scaling"], start_parameters["tau_a"], start_parameters["delta_a"]])
        initial_simplex = create_init_simples(x0, search_scale=3)

        # error_list = [error_bf, error_vs, error_sc, error_cv, error_bursty,
        #               error_f_inf, error_f_inf_slope, error_f_zero, error_f_zero_slope_at_straight, error_f0_curve]

        fmin = minimize(fun=self.cost_function_no_dend_tau_and_no_ref_period,
                        args=(error_weights,), x0=x0, method="Nelder-Mead",
                        options={"initial_simplex": initial_simplex, "xatol": 0.001, "maxfev": 600, "maxiter": 400})

        return fmin, self.base_model.get_parameters()

    def cost_function_no_dend_tau_and_no_ref_period(self, X, error_weights=None):
        # tau mins:
        tau_min = 0.001
        for i in (0, 3):
            if X[i] < tau_min:
                print("tried too small tau value")
                return 1000 + abs(X[i] - tau_min) * 10000

        for i in (1, 2, 4):
            if X[i] < 0:
                print("tried negative parameter value")
                return 1000 + abs(X[i]) * 10000

        self.base_model.set_variable("mem_tau", X[0])
        self.base_model.set_variable("noise_strength", X[1])
        self.base_model.set_variable("input_scaling", X[2])
        self.base_model.set_variable("tau_a", X[3])
        self.base_model.set_variable("delta_a", X[4])

        base_stimulus = SinusoidalStepStimulus(self.eod_freq, 0)
        # find right v-offset
        test_model = self.base_model.get_model_copy()
        test_model.set_variable("noise_strength", 0)

        # time1 = time.time()
        v_offset = test_model.find_v_offset(self.baseline_freq, base_stimulus)
        self.base_model.set_variable("v_offset", v_offset)
        # time2 = time.time()
        # print("time taken for finding v_offset: {:.2f}s".format(time2-time1))

        error_list = self.calculate_errors(error_weights)
        # print("sum: {:.2f}, ".format(sum(error_list)))
        if sum(error_list) < self.smallest_error:
            self.smallest_error = sum(error_list)
            self.best_parameters_found = X
        return sum(error_list)

    def calculate_errors(self, error_weights=None, model=None):
        if model is None:
            model = self.base_model

        # time1 = time.time()
        model_baseline = get_baseline_class(model, self.eod_freq, trials=3)
        baseline_freq = model_baseline.get_baseline_frequency()
        vector_strength = model_baseline.get_vector_strength()
        serial_correlation = model_baseline.get_serial_correlation(self.sc_max_lag)
        coefficient_of_variation = model_baseline.get_coefficient_of_variation()
        burstiness = model_baseline.get_burstiness()
        # time2 = time.time()
        isi_bins = calculate_histogram_bins(model_baseline.get_interspike_intervals())
        # print("Time taken for all baseline parameters: {:.2f}".format(time2-time1))

        # time1 = time.time()
        fi_curve_model = get_fi_curve_class(model, self.fi_contrasts, self.eod_freq, trials=8)
        f_zeros = fi_curve_model.get_f_zero_frequencies()
        f_infinities = fi_curve_model.get_f_inf_frequencies()
        f_infinities_slope = fi_curve_model.get_f_inf_slope()
        # f_zero_slopes = [fi_curve_model.get_f_zero_fit_slope_at_stimulus_value(x) for x in self.fi_contrasts]
        f_zero_slope_at_straight = fi_curve_model.get_f_zero_fit_slope_at_stimulus_value(self.f_zero_straight_contrast)

        # time2 = time.time()

        # print("Time taken for all fi-curve parameters: {:.2f}".format(time2 - time1))

        # calculate errors with reference values
        error_bf = abs((baseline_freq - self.baseline_freq) / self.baseline_freq)
        error_vs = abs((vector_strength - self.vector_strength) / 0.01)
        error_cv = abs((coefficient_of_variation - self.coefficient_of_variation) / 0.05)
        error_bursty = (abs(burstiness - self.burstiness) / 0.2)
        error_hist = np.sqrt(np.mean((isi_bins - self.isi_bins) ** 2)) / 10
        # print("error hist: {:.2f}".format(error_hist))
        # print("Burstiness: cell {:.2f}, model: {:.2f}, error: {:.2f}".format(self.burstiness, burstiness, error_bursty))

        error_sc = 0
        for i in range(self.sc_max_lag):
            error_sc += abs((serial_correlation[i] - self.serial_correlation[i]) / 0.1)
        # error_sc = error_sc / self.sc_max_lag

        error_f_inf_slope = abs((f_infinities_slope - self.f_inf_slope) / abs(self.f_inf_slope+1)) * 25
        error_f_inf = calculate_list_error(f_infinities, self.f_inf_values)

        # error_f_zero_slopes = calculate_list_error(f_zero_slopes, self.f_zero_slopes)
        error_f_zero_slope_at_straight = abs(self.f_zero_slope_at_straight - f_zero_slope_at_straight) \
                                            / abs(self.f_zero_slope_at_straight+1)
        error_f_zero = calculate_list_error(f_zeros, self.f_zero_values) / 10

        error_f0_curve = self.calculate_f0_curve_error(model, fi_curve_model) / 20

        error_list = [error_vs, error_sc, error_cv, error_hist, error_bursty,
                      error_f_inf, error_f_inf_slope, error_f_zero, error_f_zero_slope_at_straight, error_f0_curve]

        self.errors.append(error_list)
        if error_weights is not None and len(error_weights) == len(error_list):
            for i in range(len(error_weights)):
                error_list[i] = error_list[i] * error_weights[i]
        elif error_weights is not None:
            warn("Error: weights had different length than errors and were ignored!")
        if sum(error_list) < 0:
            print("Error negative: ", error_list)
        if np.isnan(sum(error_list)):
            print("--------SOME ERROR VALUE(S) IS/ARE NaN:")
            print(error_list)
            return [50 for e in error_list]
            # raise ValueError("Some error value(s) is/are NaN!")
        return error_list

    def calculate_f0_curve_error(self, model, fi_curve_model, dendritic_delay=0.005):
        buffer = 0.00
        test_duration = 0.05

        # prepare model frequency curve:
        times, freqs = fi_curve_model.get_mean_time_and_freq_traces()
        freq_prediction = np.array(freqs[self.f_zero_curve_contrast_idx])
        time_prediction = np.array(times[self.f_zero_curve_contrast_idx])

        if len(time_prediction) == 0:
            return 200
        stimulus_start = fi_curve_model.get_stimulus_start() - time_prediction[0]

        model_start_idx = int((stimulus_start - buffer) / fi_curve_model.get_sampling_interval())
        model_end_idx = int((stimulus_start + buffer + test_duration) / model.get_sampling_interval())

        idx_offset = int(dendritic_delay / model.get_sampling_interval())

        model_start_idx += idx_offset
        model_end_idx += idx_offset

        if len(time_prediction) == 0 or len(time_prediction) < model_end_idx \
                or time_prediction[0] > fi_curve_model.get_stimulus_start():
            error_f0_curve = 200
            return error_f0_curve

        model_curve = freq_prediction[model_start_idx:model_end_idx]

        # prepare cell frequency_curve:

        stimulus_start = self.recording_times[1] - self.f_zero_curve_time[0]
        cell_start_idx = int((stimulus_start - buffer) / self.data_sampling_interval)
        cell_end_idx = int((stimulus_start + buffer + test_duration) / self.data_sampling_interval)

        if round(model.get_sampling_interval() % self.data_sampling_interval, 4) == 0:
            step_cell = int(round(model.get_sampling_interval() / self.data_sampling_interval))
        else:
            raise ValueError("Model sampling interval is not a multiple of data sampling interval.")

        cell_curve = self.f_zero_curve_freq[cell_start_idx:cell_end_idx:step_cell]
        # plt.close()
        # plt.plot(cell_curve)
        # plt.plot(model_curve)
        # plt.savefig("./figures/f_zero_curve_error_{}.png".format(time.strftime("%H:%M:%S")))
        # plt.close()

        if len(cell_curve) < len(model_curve):
            model_curve = model_curve[:len(cell_curve)]
        elif len(model_curve) < len(cell_curve):
            cell_curve = cell_curve[:len(model_curve)]

        error_f0_curve = np.sqrt(np.mean((model_curve - cell_curve) ** 2))

        return error_f0_curve

    # def calculate_f0_curve_error_new(self, model, fi_curve_model):
    #     buffer = 0.05
    #     test_duration = 0.05
    #
    #     times, freqs = fi_curve_model.get_mean_time_and_freq_traces()
    #     freq_prediction = np.array(freqs[self.f_zero_curve_contrast_idx])
    #     time_prediction = np.array(times[self.f_zero_curve_contrast_idx])
    #
    #     if len(time_prediction) == 0:
    #         return 200
    #     stimulus_start = fi_curve_model.get_stimulus_start() - time_prediction[0]
    #
    #     model_start_idx = int((stimulus_start - buffer) / model.get_sampling_interval())
    #     model_end_idx = int((stimulus_start + buffer + test_duration) / model.get_sampling_interval())
    #
    #     if len(time_prediction) == 0 or len(time_prediction) < model_end_idx \
    #             or time_prediction[0] > fi_curve_model.get_stimulus_start():
    #         error_f0_curve = 200
    #         return error_f0_curve
    #
    #     model_curve = np.array(freq_prediction[model_start_idx:model_end_idx])
    #
    #     # prepare cell frequency_curve:
    #
    #     stimulus_start = self.recording_times[1] - self.f_zero_curve_time[0]
    #     cell_start_idx = int((stimulus_start - buffer) / self.data_sampling_interval)
    #     cell_end_idx = int((stimulus_start - buffer + test_duration) / self.data_sampling_interval)
    #
    #     if round(model.get_sampling_interval() % self.data_sampling_interval, 4) == 0:
    #         step_cell = int(round(model.get_sampling_interval() / self.data_sampling_interval))
    #     else:
    #         raise ValueError("Model sampling interval is not a multiple of data sampling interval.")
    #
    #     cell_curve = self.f_zero_curve_freq[cell_start_idx:cell_end_idx:step_cell]
    #     cell_time = self.f_zero_curve_time[cell_start_idx:cell_end_idx:step_cell]
    #     cell_curve_std = np.std(self.f_zero_curve_freq)
    #     model_curve_std = np.std(freq_prediction)
    #
    #     model_limit = self.baseline_freq + model_curve_std
    #     cell_limit = self.baseline_freq + cell_curve_std
    #
    #     cell_full_precicion = np.array(self.f_zero_curve_freq[cell_start_idx:cell_end_idx])
    #     cell_points_above = cell_full_precicion > cell_limit
    #     cell_area_above = sum(cell_full_precicion[cell_points_above]) * self.data_sampling_interval
    #
    #     model_points_above = model_curve > model_limit
    #     model_area_above = sum(model_curve[model_points_above]) * model.get_sampling_interval()
    #
    #     # plt.close()
    #     # plt.plot(cell_time, cell_curve, color="blue")
    #     # plt.plot((cell_time[0], cell_time[-1]), (cell_limit, cell_limit),
    #     #          color="lightblue", label="area: {:.2f}".format(cell_area_above))
    #     #
    #     # plt.plot(time_prediction[model_start_idx:model_end_idx], model_curve, color="orange")
    #     # plt.plot((time_prediction[model_start_idx], time_prediction[model_end_idx]), (model_limit, model_limit),
    #     #          color="red", label="area: {:.2f}".format(model_area_above))
    #     # plt.legend()
    #     # plt.title("Error: {:.2f}".format(abs(model_area_above - cell_area_above) / 0.02))
    #     # plt.savefig("./figures/f_zero_curve_error_{}.png".format(time.strftime("%H:%M:%S")))
    #     # plt.close()
    #
    #     return abs(model_area_above - cell_area_above)


def calculate_list_error(fit, reference):
    """Return the mean normed quadratic error between paired entries of `fit` and `reference`.

    Entry idx of `fit` is compared with entry idx of `reference` through
    normed_quadratic_freq_error; only the first len(reference) entries of `fit` are used,
    so `fit` must be at least that long. An empty `reference` raises ZeroDivisionError.
    """
    total = sum(normed_quadratic_freq_error(fit[idx], ref_value)
                for idx, ref_value in enumerate(reference))
    return total / len(reference)


def calculate_histogram_bins(isis, step=0.1, max_isi=50):
    """Count interspike intervals in fixed-width, half-open bins.

    The intervals are multiplied by 1000 before binning (seconds -> milliseconds if the
    input is in seconds). Bin k covers [b_k, b_k + step) with b_k taken from
    np.arange(0, max_isi, step); intervals outside [0, max_isi) and NaNs are not counted.

    Args:
        isis: sequence/array of interspike intervals.
        step: bin width after scaling (default 0.1).
        max_isi: upper end of the binned range after scaling (default 50).

    Returns:
        np.ndarray of integer counts, one per bin.
    """
    isis_scaled = np.sort(np.ravel(np.array(isis) * 1000))
    bin_starts = np.arange(0, max_isi, step)
    # For sorted data, count(b <= x < b + step) == #(x < b + step) - #(x < b).
    lower = np.searchsorted(isis_scaled, bin_starts, side="left")
    upper = np.searchsorted(isis_scaled, bin_starts + step, side="left")
    return upper - lower


def normed_quadratic_freq_error(fit, ref, factor=2):
    """Return ((|fit - ref| / factor) ** 2) / ref.

    The squared, factor-scaled deviation of `fit` from `ref`, relative to `ref`.
    """
    scaled_deviation = abs(fit - ref) / factor
    return scaled_deviation ** 2 / ref


def abs_freq_error(diff, factor=10):
    """Return the magnitude of `diff` divided by `factor`."""
    magnitude = abs(diff)
    return magnitude / factor


def create_init_simples(x0, search_scale=3.):
    """Build an initial Nelder-Mead simplex with len(x0) + 1 vertices around x0.

    Two vertices place the first coordinate at x0[0] / search_scale and at
    x0[0] * search_scale. Each further coordinate i adds one vertex that keeps
    x0[:i] unchanged and uses x0[i] / search_scale. Every vertex created before
    coordinate i receives x0[i] * search_scale.

    Returns:
        A list of len(x0) + 1 lists, each of length len(x0).
    """
    n_dims = len(x0)
    vertices = [[x0[0] / search_scale], [x0[0] * search_scale]]
    for dim_idx in range(1, n_dims):
        up_scaled = x0[dim_idx] * search_scale
        for existing in vertices:
            existing.append(up_scaled)
        vertices.append(list(x0[:dim_idx]) + [x0[dim_idx] / search_scale])
    return vertices


if __name__ == "__main__":
    # This module only defines the Fitter; fitting runs are started from another script.
    print("use run_fitter.py to run the Fitter.")