import numpy as np
import scipy.signal as sig

def despine(axis, spines=None, hide_ticks=True):
    """Hide the spines of a matplotlib axes object and optionally remove its ticks.

    Args:
        axis (matplotlib.axes.Axes): the axes to modify.
        spines (iterable of str, optional): names of the spines to hide (e.g. "top", "right"). If None, all spines are hidden. Defaults to None.
        hide_ticks (bool, optional): whether to also remove the x- and y-ticks. Defaults to True.
    """
    def hide_spine(spine):
        spine.set_visible(False)

    for spine in axis.spines.keys():
        if spines is not None:
            if spine in spines:
                hide_spine(axis.spines[spine])
        else:
            hide_spine(axis.spines[spine])
    if hide_ticks:
        axis.xaxis.set_ticks([])
        axis.yaxis.set_ticks([])
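

# Illustrative usage sketch (not part of the original module; the helper name
# `_despine_example` is made up): hide the top and right spines of a simple plot.
# Assumes matplotlib is installed.
def _despine_example():
    import matplotlib.pyplot as plt
    fig, ax = plt.subplots()
    ax.plot([0.0, 1.0], [0.0, 1.0])
    despine(ax, spines=["top", "right"], hide_ticks=False)
    return fig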
        

def gaussKernel(sigma, dt):
    """ Creates a Gaussian kernel with a given standard deviation and an integral of 1.

    Args:
        sigma (float): The standard deviation of the kernel.
        dt (float): The temporal resolution of the kernel, given in seconds.

    Returns:
        numpy.ndarray: the kernel, covering the range -4 to +4 sigma
    """
    x = np.arange(-4. * sigma, 4. * sigma, dt)
    y = np.exp(-0.5 * (x / sigma) ** 2) / np.sqrt(2. * np.pi) / sigma
    return y
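

# Quick sanity-check sketch (illustrative only; the helper name `_gauss_kernel_example`
# and the default parameters are made up): the kernel values form a density, so their
# sum multiplied by dt should be close to 1.
def _gauss_kernel_example(sigma=0.005, dt=1.0 / 20000.0):
    kernel = gaussKernel(sigma, dt)
    integral = np.sum(kernel) * dt  # approximately 1.0
    return kernel, integral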


def extract_am(signal):
    """Extract the amplitude modulation from a signal using the Hilbert transform. Performs padding to avoid artefacts at beginning and end.

    Args:
        signal (np.ndarray): the signal

    Returns:
        np.ndarray: the am, i.e. the absolute value of the Hilbert transform.
    """
    # first add some padding to both ends
    front_pad = np.flip(signal[:int(len(signal)/100)])
    back_pad = np.flip(signal[-int(len(signal)/100):])
    padded = np.hstack((front_pad, signal, back_pad))
    # do the hilbert and take abs, cut away the padding
    am = np.abs(sig.hilbert(padded))
    am = am[len(front_pad):-len(back_pad)]
    return am 
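

# Illustrative sketch (helper name and signal parameters are made up): extract the
# envelope of an amplitude-modulated sine wave and compare it to the known modulation.
def _extract_am_example(duration=1.0, dt=1.0 / 20000.0):
    time = np.arange(0.0, duration, dt)
    modulation = 1.0 + 0.5 * np.sin(2.0 * np.pi * 5.0 * time)   # slow 5 Hz envelope
    carrier = np.sin(2.0 * np.pi * 500.0 * time)                # fast 500 Hz carrier
    am = extract_am(modulation * carrier)                        # should track the envelope
    return time, modulation, am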



def firing_rate(spikes, duration, sigma=0.005, dt=1./20000.):
    """Convert spike times to a firing rate using the kernel convolution with a Gaussian kernel

    Args:
        spikes (iterable): list of spike times, times should be in seconds
        duration (float): duration of the trial in seconds
        sigma (float, optional): standard deviation of the Gaussian kernel. Defaults to 0.005 s.
        dt (float, optional): temporal stepsize of the resulting trace, in seconds. Defaults to 1/20000 s.

    Returns:
        np.ndarray: the firing rate
    """
    binary = np.zeros(int(np.round(duration / dt)))
    indices = np.asarray(np.round(np.asarray(spikes) / dt), dtype=int)
    binary[indices[indices < len(binary)]] = 1
    kernel = gaussKernel(sigma, dt)

    rate = np.convolve(kernel, binary, mode="same")
    return rate
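

# Illustrative sketch (helper name and the 50 Hz Poisson spike train are made up):
# convert randomly drawn spike times into a smooth firing-rate estimate.
def _firing_rate_example(duration=2.0, dt=1.0 / 20000.0, rate_hz=50.0):
    rng = np.random.default_rng(42)
    spike_count = rng.poisson(rate_hz * duration)
    spikes = np.sort(rng.uniform(0.0, duration, size=spike_count))
    rate = firing_rate(spikes, duration, sigma=0.005, dt=dt)
    return spikes, rate  # np.mean(rate) should be close to rate_hz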


def spiketrain_distance(spikes, duration, dt, kernel_width=0.001):
    """Calculate the Euclidean distance between spike trains. Firing rates are estimated using the kernel
     convolution technique applying a Gaussian kernel of the given standard deviation.

    Args:
        spikes (list of iterable): list of spike trains. event times are given in seconds.
        duration (float): duration of a trial given in seconds.
        dt (float): stepsize of the recording, given in seconds.
        kernel_width (float, optional): standard deviation of the Gaussian kernel used to estimate the firing rate. Defaults to 0.001.

    Returns:
        np.ndarray: the distances
    """
    # perform some checks
    if not isinstance(spikes, list):
        raise ValueError("spikes must be a list of spike trains, aka iterables of spike times.")
    if len(spikes) > 1 and not np.iterable(spikes[0]):
        raise ValueError("spikes must be a list of spike trains, aka iterables of spike times.")

    rates = np.zeros((len(spikes), int(np.round(duration / dt))))
    for i in range(len(spikes)):
        rates[i, :] = firing_rate(spikes[i], duration, kernel_width, dt)
    
    distances = np.zeros((len(spikes), len(spikes)))
    for i in range(len(spikes)):
        for j in range(i + 1, len(spikes)):  # diagonal stays zero, symmetry fills the lower triangle
            distances[i, j] = np.sqrt(np.sum((rates[i, :] - rates[j, :]) ** 2))
            distances[j, i] = distances[i, j]

    return distances
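

# Illustrative sketch (helper name and the simulated trials are made up): build a few
# noisy spike trains and compute their pairwise distance matrix.
def _spiketrain_distance_example(duration=1.0, dt=1.0 / 20000.0, num_trials=5):
    rng = np.random.default_rng(1)
    spikes = [np.sort(rng.uniform(0.0, duration, size=rng.poisson(100)))
              for _ in range(num_trials)]
    return spiketrain_distance(spikes, duration, dt, kernel_width=0.001)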


def across_group_distance(rates1, rates2, axis=0):
    """Pairwise Euclidean distances between the firing rates of two groups of trials,
    normalized by the number of time samples.

    Args:
        rates1 (np.ndarray): 2-D array of firing rates of the first group.
        rates2 (np.ndarray): 2-D array of firing rates of the second group.
        axis (int, optional): axis along which the trials are arranged. Defaults to 0.

    Returns:
        np.ndarray: distance matrix of shape (trials in rates1, trials in rates2).
    """
    if axis == 1:  # move trials to the first axis so that the indexing below is uniform
        rates1 = rates1.T
        rates2 = rates2.T
    distances = np.zeros((rates1.shape[0], rates2.shape[0]))
    for i in range(distances.shape[0]):
        for j in range(distances.shape[1]):
            distances[i, j] = np.sqrt(np.sum((rates1[i, :] - rates2[j, :]) ** 2)) / rates1.shape[1]

    return distances
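

# Illustrative sketch (helper name and the random rates are made up): compare the firing
# rates of two groups of trials; rows are trials, columns are time samples (axis=0).
def _across_group_example(num_trials=4, num_samples=1000):
    rng = np.random.default_rng(2)
    rates1 = rng.normal(10.0, 1.0, size=(num_trials, num_samples))
    rates2 = rng.normal(12.0, 1.0, size=(num_trials, num_samples))
    return across_group_distance(rates1, rates2, axis=0)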


def within_group_distance(rates, axis=0):
    """Pairwise Euclidean distances between the firing rates of all trials of one group,
    normalized by the number of time samples.

    Args:
        rates (np.ndarray): 2-D array of firing rates.
        axis (int, optional): axis along which the trials are arranged. Defaults to 0.

    Returns:
        np.ndarray: symmetric distance matrix of shape (number of trials, number of trials).
    """
    if axis == 1:  # move trials to the first axis so that the indexing below is uniform
        rates = rates.T
    distances = np.zeros((rates.shape[0], rates.shape[0]))
    for i in range(distances.shape[0]):
        for j in range(i):  # diagonal stays zero, symmetry fills the upper triangle
            distances[i, j] = np.sqrt(np.sum((rates[i, :] - rates[j, :]) ** 2)) / rates.shape[1]
            distances[j, i] = distances[i, j]

    return distances
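

# Illustrative sketch (helper name and the random rates are made up): distances between
# all trials of a single group; here trials are arranged along the columns, hence axis=1.
def _within_group_example(num_trials=4, num_samples=1000):
    rng = np.random.default_rng(3)
    rates = rng.normal(10.0, 1.0, size=(num_samples, num_trials))  # time x trials
    return within_group_distance(rates, axis=1)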