import os
import subprocess

import numpy as np
from scipy.optimize import curve_fit


def safe_get_val(dictionary: dict, key, default=None):
    """Look up ``key`` in ``dictionary`` without raising on a missing key.

    Args:
        dictionary (dict): mapping to read from.
        key: key to look up.
        default: value returned when ``key`` is absent. Defaults to None.

    Returns:
        ``dictionary[key]`` if the key is present, otherwise ``default``.
    """
    if key not in dictionary:
        return default
    return dictionary[key]


def results_check(results, id, text="ID"):
    """Ensure that a lookup produced exactly one result.

    Args:
        results: sized collection of matches (anything supporting ``len``).
        id: the identifier that was looked up; only used in the error message.
        text (str, optional): label for the identifier in the error message.
            Defaults to "ID".

    Raises:
        ValueError: if ``results`` is empty or contains more than one element.
    """
    count = len(results)
    if count == 1:
        return
    if count == 0:
        message = "%s %s does not exist!"
    else:
        message = "%s %s is not unique!"
    raise ValueError(message % (text, id))


def zero_crossings(x, t, interpolate=False):
    """Get the times at which the signal ``x`` crosses zero from below.

    An upward crossing is registered at sample ``i`` (``i >= 1``) when
    ``x[i - 1] < 0`` and ``x[i] >= 0``. The first sample can never be a
    crossing because it has no predecessor.

    Args:
        x (array-like): the sampled signal.
        t (array-like): the sample times of ``x``, same length as ``x``.
            Sampling does not need to be uniform.
        interpolate (bool, optional): if True, estimate each crossing time by
            linear interpolation between samples ``i - 1`` and ``i``;
            otherwise report ``t[i]``. Defaults to False.

    Returns:
        numpy.ndarray: the times of the upward zero crossings (empty if there
        are none, including for inputs shorter than two samples).
    """
    x = np.asarray(x)
    t = np.asarray(t)
    # Indices i >= 1 with x[i-1] < 0 <= x[i].
    idx = np.nonzero((x[:-1] < 0) & (x[1:] >= 0))[0] + 1
    if not interpolate:
        return t[idx]
    before = x[idx - 1]
    after = x[idx]
    # Fraction of the sampling interval that lies between the crossing and t[i].
    # Because before < 0 <= after, the denominator is strictly positive and the
    # fraction is in [0, 1), so no small-value guard is needed.
    frac = after / (after - before)
    return t[idx] - frac * (t[idx] - t[idx - 1])


def unzip_if_needed(dataset, tracename='trace-1.raw'):
    """Make sure an uncompressed trace file exists in ``dataset``.

    If ``<dataset>/<tracename>`` already exists, nothing is done. Otherwise, if
    ``<dataset>/<tracename>.gz`` exists, it is decompressed by running the
    external ``gunzip`` program on it. If neither file exists, the function
    returns without doing anything.

    Args:
        dataset (str): directory that contains the trace file.
        tracename (str, optional): name of the uncompressed trace file.
            Defaults to 'trace-1.raw'.

    Raises:
        subprocess.CalledProcessError: if ``gunzip`` exits with a non-zero status.
    """
    file_name = os.path.join(dataset, tracename)
    if os.path.exists(file_name):
        return
    gz_name = file_name + '.gz'
    if os.path.exists(gz_name):
        print("\tunzip: %s" % tracename)
        subprocess.check_call(["gunzip", gz_name])


def gaussian_kernel(sigma, dt):
    """Return a sampled Gaussian kernel (normal probability density).

    The kernel is evaluated at integer multiples of ``dt`` from ``-4 * sigma``
    to ``+4 * sigma`` (the half width is rounded to the nearest sample). It is
    therefore symmetric, has an odd number of samples, and its peak lies
    exactly on the center sample. This avoids the half-sample shift that an
    asymmetric, even-length kernel causes in ``np.convolve(..., mode='same')``.

    Args:
        sigma: standard deviation of the Gaussian, in the same units as ``dt``.
            Must be positive.
        dt: sampling interval. Must be positive.

    Returns:
        numpy.ndarray: the kernel values. ``kernel.sum() * dt`` is close to 1
        when ``dt`` is small compared to ``sigma``.

    Raises:
        ValueError: if ``sigma`` or ``dt`` is not positive.
    """
    if sigma <= 0 or dt <= 0:
        raise ValueError("sigma and dt must be positive (got sigma=%s, dt=%s)" % (sigma, dt))
    half_width = int(np.round(4.0 * sigma / dt))
    # Integer sample offsets times dt give an exactly symmetric grid around 0.
    x = np.arange(-half_width, half_width + 1) * dt
    y = np.exp(-0.5 * (x / sigma) ** 2) / np.sqrt(2. * np.pi) / sigma
    return y



class BoltzmannFit:
    """
    Class representing a fit of a Boltzmann function to some data.

    The y values are averaged for every unique x value, and the Boltzmann
    function (see :meth:`boltzmann`) is fitted to these averages with
    :func:`scipy.optimize.curve_fit`. The fit is performed in the constructor.
    """

    def __init__(self, xvalues: np.ndarray, yvalues: np.ndarray, initial_params=None):
        """
        Constructor. Takes the x and the y data and tries to fit a Boltzmann to it.

        :param xvalues: array-like of x (e.g. contrast) values
        :param yvalues: array-like of y (e.g. firing rate) values, one per x value
        :param initial_params: sequence of initial parameters ``(y_max, slope, inflection)``,
            default None (or an empty sequence) to autogenerate
        :raises ValueError: if xvalues and yvalues have different lengths
        """
        # Convert to arrays: the boolean-mask indexing used in __do_fit silently
        # selects the wrong element when given plain lists.
        xvalues = np.asarray(xvalues)
        yvalues = np.asarray(yvalues)
        # Explicit check instead of ``assert``, which is stripped under ``python -O``.
        if len(xvalues) != len(yvalues):
            raise ValueError("xvalues and yvalues must have the same length (%i != %i)"
                             % (len(xvalues), len(yvalues)))
        self.__xvals = xvalues
        self.__yvals = yvalues
        self.__fit_params = None
        self.__initial_params = initial_params
        self.__x_sorted = np.unique(self.__xvals)
        self.__y_avg = None
        self.__y_err = None
        self.__do_fit()

    @staticmethod
    def boltzmann(x, y_max, slope, inflection):
        r"""
        The underlying Boltzmann function.

        .. math::
            f(x) = \frac{y_{max}}{1 + \exp(-k (x - x_0))}

        where :math:`k` is ``slope`` and :math:`x_0` is ``inflection``.

        :param x: The x values.
        :param y_max: The maximum value.
        :param slope: The slope parameter k
        :param inflection: the position of the inflection point.
        :return: the y values.
        """
        y = y_max / (1 + np.exp(-slope * (x - inflection)))
        return y

    def __do_fit(self):
        """Average y per unique x and fit the Boltzmann function to those averages."""
        self.__y_avg = np.zeros(self.__x_sorted.shape)
        self.__y_err = np.zeros(self.__x_sorted.shape)
        for i, c in enumerate(self.__x_sorted):
            selected = self.__yvals[self.__xvals == c]
            self.__y_avg[i] = np.mean(selected)
            self.__y_err[i] = np.std(selected)
        # ``is not None`` instead of truthiness: a numpy array of initial parameters
        # has an ambiguous truth value and would raise. An empty sequence still
        # means "autogenerate", as before.
        if self.__initial_params is not None and len(self.__initial_params) > 0:
            p = self.__initial_params
        else:
            p = [np.max(self.__y_avg), 0, 0]
        self.__fit_params, _ = curve_fit(self.boltzmann, self.__x_sorted, self.__y_avg, p0=p)

    @property
    def slope(self) -> float:
        r"""
        The slope of the Boltzmann at its inflection point, i.e.

        .. math::
            s = \frac{y_{max} \cdot k}{4}

        :return: the slope.
        """
        return self.__fit_params[0] * self.__fit_params[1] / 4

    @property
    def parameters(self):
        """ fit parameters
        :return: The fit parameters ``(y_max, slope, inflection)`` as returned by curve_fit.
        """
        return self.__fit_params

    @property
    def x_data(self):
        """ The x data sorted and unique used for fitting.
        :return: the x data
        """
        return self.__x_sorted

    @property
    def y_data(self):
        """
        The y data used for fitting: the mean of the y values for each unique x value,
        ordered like :attr:`x_data`.
        :return: the average and the standard deviation of the y data
        """
        return self.__y_avg, self.__y_err

    def solve(self, xvalues=None):
        """
        Evaluate the fitted Boltzmann function.

        :param xvalues: scalar or array-like x value(s) to evaluate at; defaults to
            the sorted unique x data.
        :return: the fitted y value(s)
        """
        # ``is None`` instead of ``not xvalues``: the latter replaced a legitimate 0
        # with the x data and raised for numpy arrays with more than one element.
        if xvalues is None:
            xvalues = self.__x_sorted
        return self.boltzmann(np.asarray(xvalues), *self.__fit_params)