import numpy as np
import matplotlib.pyplot as plt
import scipy.io as spio
import scipy.stats as spst
import scipy as sp
from IPython import embed


def set_axis_fontsize(axis, label_size, tick_label_size=None, legend_size=None):
    """
        Sets axis label, tick label and legend font sizes to the desired size.

    :param axis: the axes object
    :param label_size: the font size of the x and y axis labels
    :param tick_label_size: the font size of the major tick labels. If None, label_size is used.
    :param legend_size: the font size used in the legend (if the axes has one). If None, label_size is used.

    """
    if tick_label_size is None:
        tick_label_size = label_size
    if legend_size is None:
        legend_size = label_size
    axis.xaxis.get_label().set_fontsize(label_size)
    axis.yaxis.get_label().set_fontsize(label_size)
    # tick_params is the public API for tick label styling; the per-tick
    # ``tick.label`` attribute was deprecated in matplotlib 3.1 and removed in 3.8.
    # The setting is also stored for ticks matplotlib creates later
    # (e.g. after a subsequent set_xticks call).
    axis.tick_params(axis="both", which="major", labelsize=tick_label_size)

    legend = axis.get_legend()
    if legend is not None:
        for text in legend.get_texts():
            text.set_fontsize(legend_size)


def get_instantaneous_rate(times, max_t=30., dt=1e-4):
    """
        Computes the instantaneous firing rate (inverse inter-spike interval)
        sampled on a regular time axis.

    Between two consecutive spikes t[k] and t[k+1] the rate is
    1 / (t[k+1] - t[k]); before the first and after the last spike it is zero.

    :param times: spike times in seconds (array-like), assumed sorted ascending
    :param max_t: duration of the time axis in seconds
    :param dt: sampling interval of the time axis in seconds
    :return: tuple (time, inst_rate) of equally long arrays
    """
    times = np.atleast_1d(np.asarray(times, dtype=float))
    time = np.arange(0., max_t, dt)
    inst_rate = np.zeros(time.shape)
    # Round rather than truncate so that e.g. 0.3 / 1e-4 = 2999.999... maps to 3000.
    indices = np.round(times / dt).astype(int)
    isis = np.diff(times)

    # Each interval [spike k, spike k+1) gets the rate of that ISI.
    for start, stop, isi in zip(indices[:-1], indices[1:], isis):
        if stop > start:
            # Slicing clips silently at the end of the time axis, so an ISI that
            # crosses max_t is filled up to the last sample.
            inst_rate[max(start, 0):stop] = 1. / isi
    return time, inst_rate


def plot_isi_rate(spike_times, max_t=30, dt=1e-4):
    """
        Plots the instantaneous (inverse-ISI) firing rate.

    The figure has three panels covering the first 50000 samples (5 s at the
    default dt): the spike raster of the first trial, the instantaneous rate
    of the first trial, and the trial-averaged rate with a +/- one standard
    deviation band. It is saved to figures/instantaneous_rate.png.

    :param spike_times: sequence of trials; spike_times[i][0] holds the spike times of trial i (in seconds)
    :param max_t: not used; the analysis window is fixed at 50000 samples. Kept for interface compatibility.
    :param dt: sampling interval of the rate time axis in seconds
    """
    import os

    window = 50000 * dt  # length of the analysed and plotted window in seconds
    trials = [np.atleast_1d(np.squeeze(trial[0])) for trial in spike_times]
    # dt is forwarded so the rate time axis uses the requested sampling interval.
    results = [get_instantaneous_rate(spikes, max_t=window, dt=dt) for spikes in trials]
    time = results[0][0]
    rates = np.column_stack([trial_rate for _, trial_rate in results])
    rate = rates[:, 0]  # first trial
    avg_rate = np.mean(rates, axis=1)
    rate_std = np.std(rates, axis=1)
    times = trials[0]

    fig = plt.figure()
    ax1 = fig.add_subplot(311)
    ax2 = fig.add_subplot(312)
    ax3 = fig.add_subplot(313)

    ax1.vlines(times[times < window], ymin=0, ymax=1, color="dodgerblue", lw=1.5)
    ax1.set_ylabel("spikes", fontsize=12)
    set_axis_fontsize(ax1, 12)
    ax1.set_xlim([0, window])

    ax2.plot(time, rate, label="instantaneous rate, trial 1")
    ax2.set_ylabel("firing rate [Hz]", fontsize=12)
    ax2.legend(fontsize=12)
    set_axis_fontsize(ax2, 12)

    ax3.fill_between(time, avg_rate+rate_std, avg_rate-rate_std, color="dodgerblue",
                     alpha=0.5, label="standard deviation")
    ax3.plot(time, avg_rate, label="average rate")
    ax3.set_xlabel("times [s]", fontsize=12)
    ax3.set_ylabel("firing rate [Hz]", fontsize=12)
    ax3.legend(fontsize=12)
    ax3.set_ylim([0, 450])
    set_axis_fontsize(ax3, 12)

    fig.set_size_inches(15, 10)
    fig.subplots_adjust(left=0.1, bottom=0.125, top=0.95, right=0.95)
    fig.set_facecolor("white")
    # savefig does not create missing directories.
    os.makedirs("figures", exist_ok=True)
    fig.savefig("figures/instantaneous_rate.png")
    # Close this specific figure rather than whatever is "current".
    plt.close(fig)


def get_binned_rate(times, bin_width=0.05, max_t=30., dt=1e-4):
    """
        Computes the binned firing rate (spike count per bin divided by the
        bin width) sampled on a regular time axis.

    Every sample of the time axis takes the rate of the bin it falls into.
    The bin width is rounded to an integer number of samples so that bins
    line up with the time axis. The last bin may extend beyond max_t.

    :param times: spike times in seconds (array-like)
    :param bin_width: width of the counting bins in seconds
    :param max_t: duration of the time axis in seconds
    :param dt: sampling interval of the time axis in seconds
    :return: tuple (time, rate) of equally long arrays
    :raises ValueError: if bin_width is not positive
    """
    if bin_width <= 0:
        raise ValueError("bin_width must be positive")
    times = np.atleast_1d(np.asarray(times, dtype=float))
    time = np.arange(0., max_t, dt)
    if time.size == 0:
        return time, np.zeros(0)

    samples_per_bin = max(int(round(bin_width / dt)), 1)
    width = samples_per_bin * dt  # effective bin width in seconds
    n_bins = -(-time.size // samples_per_bin)  # ceil division: cover the whole axis
    edges = np.arange(n_bins + 1) * width
    counts, _ = np.histogram(times, edges)
    # Integer sample -> bin mapping; numpy does not accept float slice indices.
    sample_bin = np.arange(time.size) // samples_per_bin
    rate = counts[sample_bin] / width
    return time, rate


def plot_bin_rate(spike_times, bin_width, max_t=30, dt=1e-4):
    """
        Plots the binned firing rate.

    The figure has three panels shown over 0-5 s: the spike raster of the
    first trial, the binned rate of the first trial, and the trial-averaged
    binned rate with a +/- one standard deviation band. It is saved to
    figures/binned_rate.png.

    :param spike_times: sequence of trials; spike_times[i][0] holds the spike times of trial i (in seconds)
    :param bin_width: width of the counting bins in seconds
    :param max_t: duration of the rate time axis in seconds
    :param dt: sampling interval of the rate time axis in seconds
    """
    import os

    trials = [np.atleast_1d(np.squeeze(trial[0])) for trial in spike_times]
    # bin_width, max_t and dt are all forwarded to the rate estimator.
    results = [get_binned_rate(spikes, bin_width, max_t=max_t, dt=dt) for spikes in trials]
    time = results[0][0]
    rates = np.column_stack([trial_rate for _, trial_rate in results])
    rate = rates[:, 0]  # first trial
    avg_rate = np.mean(rates, axis=1)
    rate_std = np.std(rates, axis=1)
    times = trials[0]

    fig = plt.figure()
    ax1 = fig.add_subplot(311)
    ax2 = fig.add_subplot(312)
    ax3 = fig.add_subplot(313)

    ax1.vlines(times[times < (100000*dt)], ymin=0, ymax=1, color="dodgerblue", lw=1.5)
    ax1.set_ylabel("spikes", fontsize=12)
    ax1.set_xlim([0, 5])
    set_axis_fontsize(ax1, 12)

    ax2.plot(time, rate, label="binned rate, trial 1")
    ax2.set_ylabel("firing rate [Hz]", fontsize=12)
    ax2.legend(fontsize=12)
    ax2.set_xlim([0, 5])
    set_axis_fontsize(ax2, 12)

    ax3.fill_between(time, avg_rate+rate_std, avg_rate-rate_std, color="dodgerblue",
                     alpha=0.5, label="standard deviation")
    ax3.plot(time, avg_rate, label="average rate")
    ax3.set_xlabel("times [s]", fontsize=12)
    ax3.set_ylabel("firing rate [Hz]", fontsize=12)
    ax3.legend(fontsize=12)
    ax3.set_xlim([0, 5])
    ax3.set_ylim([0, 450])
    set_axis_fontsize(ax3, 12)

    fig.set_size_inches(15, 10)
    fig.subplots_adjust(left=0.1, bottom=0.125, top=0.95, right=0.95)
    fig.set_facecolor("white")
    # savefig does not create missing directories.
    os.makedirs("figures", exist_ok=True)
    fig.savefig("figures/binned_rate.png")
    # Close this specific figure rather than whatever is "current".
    plt.close(fig)
  

def get_convolved_rate(times, sigma, max_t=30., dt=1.e-4):
    """
        Computes the firing rate by convolving the spike train with a Gaussian
        kernel, sampled on a regular time axis.

    Each spike contributes a unit-area Gaussian (truncated at +/- 8 sigma)
    centred on its sample, so the result is in spikes per second. Spikes
    falling outside [0, max_t) are ignored.

    :param times: spike times in seconds (array-like)
    :param sigma: standard deviation of the Gaussian kernel in seconds
    :param max_t: duration of the time axis in seconds
    :param dt: sampling interval of the time axis in seconds
    :return: tuple (time, conv_rate) of equally long arrays
    :raises ValueError: if sigma is not positive
    """
    if sigma <= 0:
        raise ValueError("sigma must be positive")
    times = np.atleast_1d(np.asarray(times, dtype=float))
    time = np.arange(0., max_t, dt)
    if time.size == 0:
        return time, np.zeros(0)

    # Odd-length, symmetric kernel so its peak sits exactly at the centre sample.
    half = int(round(8 * sigma / dt))
    kernel = spst.norm.pdf(np.arange(-half, half + 1) * dt, loc=0, scale=sigma)

    # Round rather than truncate so that e.g. 0.3 / 1e-4 = 2999.999... maps to 3000.
    indices = np.round(times / dt).astype(int)
    indices = indices[(indices >= 0) & (indices < time.size)]
    spikes = np.zeros(time.shape)
    # np.add.at accumulates repeated indices (two spikes in one sample count twice).
    np.add.at(spikes, indices, 1.)

    # Taking the centred slice of the full convolution keeps the output aligned
    # with `time` even when the kernel is longer than the signal.
    conv_rate = np.convolve(spikes, kernel, mode="full")[half:half + spikes.size]
    return time, conv_rate


def plot_conv_rate(spike_times, sigma=0.05, max_t=30, dt=1e-4):
    """
        Plots the kernel-convolved firing rate.

    The figure has three panels shown over 0-5 s: the spike raster of the
    first trial, the convolved rate of the first trial, and the
    trial-averaged convolved rate with a +/- one standard deviation band.
    It is saved to figures/convolved_rate.png.

    :param spike_times: sequence of trials; spike_times[i][0] holds the spike times of trial i (in seconds)
    :param sigma: standard deviation of the Gaussian kernel in seconds
    :param max_t: duration of the rate time axis in seconds
    :param dt: sampling interval of the rate time axis in seconds
    """
    import os

    trials = [np.atleast_1d(np.squeeze(trial[0])) for trial in spike_times]
    # max_t and dt are forwarded so the rate time axis matches the arguments.
    results = [get_convolved_rate(spikes, sigma, max_t=max_t, dt=dt) for spikes in trials]
    time = results[0][0]
    rates = np.column_stack([trial_rate for _, trial_rate in results])
    rate = rates[:, 0]  # first trial
    avg_rate = np.mean(rates, axis=1)
    rate_std = np.std(rates, axis=1)
    times = trials[0]

    fig = plt.figure()
    ax1 = fig.add_subplot(311)
    ax2 = fig.add_subplot(312)
    ax3 = fig.add_subplot(313)

    ax1.vlines(times[times < (100000*dt)], ymin=0, ymax=1, color="dodgerblue", lw=1.5)
    ax1.set_ylabel("spikes", fontsize=12)
    ax1.set_xlim([0, 5])
    set_axis_fontsize(ax1, 12)

    ax2.plot(time, rate, label="convolved rate, trial 1")
    ax2.set_ylabel("firing rate [Hz]", fontsize=12)
    ax2.legend(fontsize=12)
    ax2.set_xlim([0, 5])
    set_axis_fontsize(ax2, 12)

    ax3.fill_between(time, avg_rate+rate_std, avg_rate-rate_std, color="dodgerblue",
                     alpha=0.5, label="standard deviation")
    ax3.plot(time, avg_rate, label="average rate")
    ax3.set_xlabel("times [s]", fontsize=12)
    ax3.set_ylabel("firing rate [Hz]", fontsize=12)
    ax3.legend(fontsize=12)
    ax3.set_xlim([0, 5])
    ax3.set_ylim([0, 450])
    set_axis_fontsize(ax3, 12)

    fig.set_size_inches(15, 10)
    fig.subplots_adjust(left=0.1, bottom=0.125, top=0.95, right=0.95)
    fig.set_facecolor("white")
    # savefig does not create missing directories.
    os.makedirs("figures", exist_ok=True)
    fig.savefig("figures/convolved_rate.png")
    # Close this specific figure rather than whatever is "current".
    plt.close(fig)


def plot_comparison(spike_times, bin_width, sigma, max_t=30., dt=1e-4):
    """
        Plots the first trial's spike raster together with its instantaneous,
        binned and convolved firing rates over the window 2.5-3.5 s, and saves
        the figure to firingrates.pdf.

    :param spike_times: sequence of trials; spike_times[i][0] holds the spike times of trial i (in seconds)
    :param bin_width: bin width in seconds for the binned rate
    :param sigma: currently unused; the Gaussian kernel width is derived from bin_width
    :param max_t: duration of the rate time axis in seconds
    :param dt: sampling interval of the rate time axis in seconds
    """
    times = np.atleast_1d(np.squeeze(spike_times[0][0]))
    # NOTE(review): `sigma` is ignored. The kernel std is bin_width / sqrt(12),
    # which is the standard deviation of a box kernel of width bin_width,
    # presumably so the binned and convolved rates are smoothed comparably.
    # Confirm the intent before wiring `sigma` in.
    time, conv_rate = get_convolved_rate(times, bin_width / np.sqrt(12.), max_t=max_t, dt=dt)
    _, inst_rate = get_instantaneous_rate(times, max_t=max_t, dt=dt)
    _, binn_rate = get_binned_rate(times, bin_width, max_t=max_t, dt=dt)

    x_ticks = [2.5, 2.75, 3.0, 3.25, 3.5]

    fig = plt.figure()
    axes = [fig.add_subplot(411 + k) for k in range(4)]
    ax1, ax2, ax3, ax4 = axes
    # Hide the top/right frame lines and keep ticks only on the left/bottom.
    for ax in axes:
        ax.spines["right"].set_visible(False)
        ax.spines["top"].set_visible(False)
        ax.yaxis.set_ticks_position('left')
        ax.xaxis.set_ticks_position('bottom')

    ax1.vlines(times[times < (100000*dt)], ymin=0, ymax=1, color="dodgerblue", lw=1.5)
    ax1.set_ylabel("spikes", fontsize=10)
    ax1.set_xlim([2.5, 3.5])
    ax1.set_ylim([0, 1])
    ax1.set_yticks([0, 1])
    set_axis_fontsize(ax1, 10)
    ax1.set_xticks(x_ticks)
    ax1.set_xticklabels([])

    # The three rate panels share the same layout.
    rate_panels = [(ax2, inst_rate, "instantaneous rate"),
                   (ax3, binn_rate, "binned rate"),
                   (ax4, conv_rate, "convolved rate")]
    for ax, rate, label in rate_panels:
        ax.plot(time, rate, lw=1.5, label=label)
        ax.set_ylabel("firing rate [Hz]", fontsize=10)
        ax.legend(fontsize=10)
        ax.set_xlim([2.5, 3.5])
        ax.set_ylim([0, 300])
        ax.set_yticks(np.arange(0, 400, 100))
        set_axis_fontsize(ax, 10)
        ax.set_xticks(x_ticks)

    # Only the bottom panel shows x tick labels and the x axis label.
    ax2.set_xticklabels([])
    ax3.set_xticklabels([])
    ax4.set_xlabel("time [s]", fontsize=10)

    fig.set_size_inches(7.5, 5)
    fig.subplots_adjust(left=0.1, bottom=0.125, top=0.95, right=0.95)
    fig.set_facecolor("white")
    fig.savefig("firingrates.pdf")
    # Close this specific figure rather than whatever is "current".
    plt.close(fig)


if __name__ == "__main__":
    # Load the spike data; the "spikes" entry is indexed per trial by the
    # plotting functions (spike_times[i][0]).
    mat_contents = spio.loadmat('lifoustim.mat')
    spike_times = mat_contents["spikes"]
    # Other figures available from the same data:
    #   plot_isi_rate(spike_times)
    #   plot_bin_rate(spike_times, 0.05)
    #   plot_conv_rate(spike_times, 0.025)
    plot_comparison(spike_times, bin_width=0.05, sigma=0.025)