import numpy as np
import os
import subprocess
from scipy.optimize import curve_fit
from IPython import embed

def spike_times_to_rate(spike_times, time_axis, kernel_width=0.005):
    """Convert spike times to a rate by means of kernel convolution.

    Every spike is assigned to the sample of ``time_axis`` it falls into and
    the resulting binary train is convolved with the Gaussian kernel returned
    by ``gaussian_kernel``.

    Args:
        spike_times (numpy.ndarray): the spike times in seconds.
        time_axis (numpy.ndarray): the time axis with a proper resolution and extent (in seconds).
        kernel_width (float, optional): the standard deviation of the Gaussian kernel. Defaults to 0.005.

    Returns:
        numpy.ndarray: the firing rate in Hz.
    """
    spike_times = np.asarray(spike_times, dtype=float)
    # The sampling interval is taken as the average step of the time axis.
    dt = np.mean(np.diff(time_axis))
    binary = np.zeros(time_axis.shape)
    # Use floor instead of integer truncation: truncation rounds towards zero,
    # which would move spikes that occur up to one sample *before* the start
    # of the time axis into the first bin.
    spike_indices = np.floor((spike_times - time_axis[0]) / dt).astype(int)
    in_range = (spike_indices >= 0) & (spike_indices < len(binary))
    # Spikes outside of the time axis are dropped; several spikes that fall
    # into the same bin are represented by a single 1.
    binary[spike_indices[in_range]] = 1
    g = gaussian_kernel(kernel_width, dt)
    rate = np.convolve(binary, g, mode='same')
    return rate


def safe_get_val(dictionary: dict, key, default=None):
    """Look up a key in a dictionary and fall back to a default value.

    Args:
        dictionary (dict): the dictionary to read from.
        key: the key to look up.
        default (optional): the value returned if the key is missing. Defaults to None.

    Returns:
        the value stored under ``key`` or ``default`` if the key does not exist.
    """
    if key not in dictionary:
        return default
    return dictionary[key]


def results_check(results, id, text="ID"):
    """Make sure that a query result contains exactly one entry.

    Args:
        results: the sized collection of results to check.
        id: the identifier that was searched for; only used in the error message.
        text (str, optional): label put in front of the identifier in the error message. Defaults to "ID".

    Raises:
        ValueError: if ``results`` is empty or contains more than one entry.
    """
    count = len(results)
    if count == 1:
        return
    if count == 0:
        raise ValueError("%s %s does not exist!" % (text, id))
    raise ValueError("%s %s is not unique!" % (text, id))


def zero_crossings(x, t, interpolate=False):
    """Get the times at which a signal x crosses zero in upward direction.

    A crossing is detected at sample ``i`` if ``x[i] >= 0`` and ``x[i-1] < 0``.
    The first sample has no predecessor and is never reported as a crossing.

    Args:
        x (numpy.ndarray): the signal.
        t (numpy.ndarray): the time axis belonging to ``x``.
        interpolate (bool, optional): if True, the crossing time is estimated by linear
            interpolation between the sample before and the sample at the crossing. The
            sampling interval is taken from the first two entries of ``t``. Defaults to False.

    Returns:
        numpy.ndarray: the times of the upward zero crossings.
    """
    x = np.asarray(x)
    t = np.asarray(t)
    # Compare each sample with its predecessor; slicing also works for empty
    # and single-sample inputs, which simply yield no crossings.
    xings = np.where((x[1:] >= 0) & (x[:-1] < 0))[0] + 1
    if not interpolate:
        return t[xings]
    if len(xings) == 0:
        return np.zeros(0)

    dt = t[1] - t[0]
    after = x[xings]
    # after >= 0 > before, hence the local slope is always positive.
    m = (after - x[xings - 1]) / dt
    # Samples that are (almost) exactly zero are taken as the crossing itself.
    return np.where(after > 0.001, t[xings] - after / m, t[xings])


def unzip_if_needed(dataset, tracename='trace-1.raw'):
    """Unzip a gzipped trace file of a dataset unless the unzipped file already exists.

    Nothing happens if the trace file exists or if there is no ``.gz`` version
    of it. Otherwise the external ``gunzip`` command is called on the
    compressed file.

    Args:
        dataset (str): path of the folder that contains the trace file.
        tracename (str, optional): name of the trace file. Defaults to 'trace-1.raw'.
    """
    raw_path = os.path.join(dataset, tracename)
    zipped_path = raw_path + '.gz'
    if not os.path.exists(raw_path) and os.path.exists(zipped_path):
        print("\tunzip: %s" % tracename)
        subprocess.check_call(["gunzip", zipped_path])


def gaussian_kernel(sigma, dt):
    """Creates a gaussian kernel with the integral of one.

    The kernel is sampled with a stepsize of ``dt`` on the half-open range
    from -4 * sigma up to (excluding) 4 * sigma.

    Args:
        sigma (float): the standard deviation of the kernel.
        dt (float): the stepsize with which the kernel is sampled, same unit as sigma.

    Returns:
        numpy.ndarray: the values of the Gaussian probability density at the sampled positions.
    """
    half_width = 4. * sigma
    positions = np.arange(-half_width, half_width, dt)
    z = positions / sigma
    kernel = np.exp(-0.5 * z ** 2)
    return kernel / np.sqrt(2. * np.pi) / sigma


class BoltzmannFit:
    """
    Class representing a fit of a Boltzmann function to some data.

    The y values are averaged for each unique x value; the Boltzmann function
    is fitted to these averages when the object is created.
    """

    def __init__(self, xvalues: np.ndarray, yvalues: np.ndarray, initial_params=None):
        """
        Constructor. Takes the x and the y data and tries to fit a Boltzmann to it.

        :param xvalues: numpy array of x (e.g. contrast) values
        :param yvalues: numpy array of y (e.g. firing rate) values
        :param initial_params: list of initial parameters (y_max, slope, inflection),
                               default None to autogenerate
        """
        assert len(xvalues) == len(yvalues), "xvalues and yvalues must have the same length"
        # Convert to arrays so that the boolean masking in the fit also works for lists.
        self.__xvals = np.asarray(xvalues)
        self.__yvals = np.asarray(yvalues)
        self.__fit_params = None
        self.__initial_params = initial_params
        self.__x_sorted = np.unique(self.__xvals)
        self.__y_avg = None
        self.__y_err = None
        self.__do_fit()

    @staticmethod
    def boltzmann(x, y_max, slope, inflection):
        """
        The underlying Boltzmann function.
        .. math::
            f(x) = y_max / (1 + exp{-slope*(x-inflection)})

        :param x: The x values.
        :param y_max: The maximum value.
        :param slope: The slope parameter k
        :param inflection: the position of the inflection point.
        :return: the y values.
        """
        y = y_max / (1 + np.exp(-slope * (x - inflection)))
        return y

    def __do_fit(self):
        """
        Averages the y data for each unique x value and fits the Boltzmann to the averages.
        """
        self.__y_avg = np.zeros(self.__x_sorted.shape)
        self.__y_err = np.zeros(self.__x_sorted.shape)
        for i, c in enumerate(self.__x_sorted):
            y_at_c = self.__yvals[self.__xvals == c]
            self.__y_avg[i] = np.mean(y_at_c)
            self.__y_err[i] = np.std(y_at_c)
        # Explicit None/length test: the truth value of a numpy array with
        # more than one element is ambiguous and raises a ValueError.
        if self.__initial_params is not None and len(self.__initial_params) > 0:
            p = self.__initial_params
        else:
            p = [np.max(self.__y_avg), 0, 0]
        self.__fit_params, _ = curve_fit(self.boltzmann, self.__x_sorted, self.__y_avg, p)

    @property
    def slope(self) -> float:
        r"""
        The slope of the linear part of the Boltzmann, i.e.
        .. math::
            s = f_max $\cdot$ k / 4
        :return: the slope.
        """
        return self.__fit_params[0] * self.__fit_params[1] / 4

    @property
    def parameters(self):
        """ fit parameters
        :return: The fit parameters (y_max, slope, inflection).
        """
        return self.__fit_params

    @property
    def x_data(self):
        """ The x data sorted and unique used for fitting.
        :return: the x data
        """
        return self.__x_sorted

    @property
    def y_data(self):
        """
        the Y data used for fitting, i.e. the y values averaged for each unique x value, sorted by the x data.
        :return: the average and the standard deviation of the y data
        """
        return self.__y_avg, self.__y_err

    def solve(self, xvalues=None):
        """
        Evaluates the fitted Boltzmann function.

        :param xvalues: the x values (scalar or array-like) at which the fit is evaluated,
                        default None to use the sorted unique x data used for fitting.
        :return: the y values of the fit at the given x values.
        """
        # Only None selects the default; "not xvalues" would fail for arrays
        # and would wrongly treat a scalar 0 as missing.
        if xvalues is None:
            xvalues = self.__x_sorted
        return self.boltzmann(np.asarray(xvalues, dtype=float), *self.__fit_params)

class StimSpikesFile:
    """
    Reader for a stimspikes text file.

    The file is parsed once on construction into a map from
    (run index, trial) to a tuple of (trial duration, list of spike times).
    """

    # File names that are looked up if a folder instead of a file is given.
    # The first entry is used for the error message if none of them exists.
    _FILE_NAMES = ("stimspikes1.dat", "stimspikes-1.dat")

    def __init__(self, filename):
        """
        Constructor. Locates and parses the stimspikes file.

        :param filename: path of the stimspikes file or of the folder containing it.
        :raises ValueError: if the file does not exist.
        """
        if not any(name in filename for name in self._FILE_NAMES):
            candidates = [os.path.join(filename, name) for name in self._FILE_NAMES]
            filename = next((c for c in candidates if os.path.exists(c)), candidates[0])
        if not os.path.exists(filename):
            raise ValueError("StimSpikesFile: the given file %s does not exist!" % filename)
        self._filename = filename
        self._data_map = self.__parse_file(filename)

    @staticmethod
    def _parse_duration(value):
        """
        Converts a duration with unit, e.g. '0.5sec', '500ms' or '0.5s', to seconds.

        :param value: the duration string.
        :return: the duration in seconds, 0.0 if the unit is not recognized.
        """
        value = value.strip()
        if value.endswith("sec"):
            return float(value[:-3])
        if value.endswith("ms"):
            return float(value[:-2]) / 1000
        if value.endswith("s"):
            return float(value[:-1])
        return 0.0

    def __parse_file(self, filename):
        """
        Parses the file line by line.

        Lines containing 'duration:', 'index:' or 'trial:' update the current
        duration, run index or trial number. Non-empty lines without a '#' are
        read as one number each, divided by 1000 (presumably milliseconds to
        seconds, to be confirmed against the file format) and collected as the
        data of the current trial.

        :param filename: path of the file.
        :return: dictionary mapping (index, trial) to (trial_duration, trial_data).
        """
        index_map = {}
        trial_data = []
        trial_duration = 0.0
        index = 0
        trial = 0

        def header_value(line):
            # drop the leading comment character and return the text after the last colon
            return line[1:].strip().split(":")[-1].strip()

        with open(filename, 'r') as f:
            for line in f:
                line = line.strip()
                if "duration:" in line:
                    trial_duration = self._parse_duration(header_value(line))
                if "index:" in line:
                    # a new run starts, store the data collected so far
                    if trial_data:
                        index_map[(index, trial)] = (trial_duration, trial_data)
                        trial_data = []
                    index = int(header_value(line))
                if "trial:" in line:
                    # a new trial starts, store the data collected so far
                    if trial_data:
                        index_map[(index, trial)] = (trial_duration, trial_data)
                        trial_data = []
                    trial = int(header_value(line))
                if line and "#" not in line:
                    trial_data.append(float(line) / 1000)
        # the last trial is not followed by a header line, store it here
        index_map[(index, trial)] = (trial_duration, trial_data)
        return index_map

    def get(self, run_index, trial_index):
        """
        Returns the data of a trial.

        :param run_index: the run index.
        :param trial_index: the trial number.
        :return: tuple of trial duration and the list of spike times, (None, None) if
                 there is no entry for the given run and trial.
        """
        key = (run_index, trial_index)
        if key not in self._data_map:
            print("Data not found for run %i and trial %i:" % (run_index, trial_index))
            return None, None
        return self._data_map[key]