import numpy as np
import scipy.io as spio
import scipy.stats as spst
import scipy as sp
import matplotlib.pyplot as plt
from plotstyle import *


def get_instantaneous_rate(times, max_t=30., dt=1e-4):
    """Estimate the firing rate as the inverse of the inter-spike intervals.

    Every sample between two consecutive spikes is set to one over the
    interval separating those two spikes.

    Parameters
    ----------
    times : array_like
        Spike times, expected in ascending order and in the same unit as
        ``max_t`` and ``dt``.
    max_t : float
        End (exclusive) of the returned time axis.
    dt : float
        Step of the returned time axis.

    Returns
    -------
    time : ndarray
        ``np.arange(0., max_t, dt)``.
    inst_rate : ndarray
        Rate sampled on ``time``. Samples before the first spike and from
        the last spike onwards stay zero because no interval covers them.
    """
    # ravel() lets a single spike (0-d after np.squeeze) or a plain list pass.
    spikes = np.asarray(times, dtype=float).ravel()
    time = np.arange(0., max_t, dt)
    # Truncation selects the sample that contains each spike.
    indices = np.asarray(spikes / dt, dtype=int)
    intervals = np.diff(spikes)
    inst_rate = np.zeros(time.shape)

    # Pair each interval with the sample range between its two spikes.
    # Slices that reach past the end of the axis are clipped by NumPy.
    for start, stop, interval in zip(indices[:-1], indices[1:], intervals):
        inst_rate[start:stop] = 1.0 / interval
    return time, inst_rate


def plot_isi_rate(spike_times, max_t=30, dt=1e-4):
    """Plot the inter-spike-interval rate estimate and save ``isimethod.pdf``.

    The figure has three stacked panels: the spikes of the first trial, the
    instantaneous rate of the first trial, and the across-trial mean rate
    with a band of plus/minus one standard deviation.

    Parameters
    ----------
    spike_times : sequence
        ``spike_times[i][0]`` holds the spike times of trial ``i``.
    max_t : float
        Unused; kept so existing callers keep working. The plotted window is
        fixed to 50000 samples of width ``dt``.
    dt : float
        Sampling step, forwarded to ``get_instantaneous_rate``.

    Notes
    -----
    ``cm_size``, ``figure_width`` and ``figure_height`` are not defined in
    this file; presumably they come from the ``plotstyle`` star import. The
    two-argument label calls (name, unit) presumably rely on that module as
    well - TODO confirm.
    """
    n_samples = 50000
    window = n_samples * dt

    # Spikes of the first trial, used for the raster panel.
    times = np.squeeze(spike_times[0][0])[:n_samples]

    # One rate trace per trial, all sampled on the same time axis.
    trial_rates = []
    for trial in spike_times:
        time, trial_rate = get_instantaneous_rate(
            np.squeeze(trial[0])[:n_samples], max_t=window, dt=dt)
        trial_rates.append(trial_rate)
    rates = np.column_stack(trial_rates)
    rate = rates[:, 0]
    avg_rate = np.mean(rates, axis=1)
    rate_std = np.std(rates, axis=1)

    fig, (ax1, ax2, ax3) = plt.subplots(3, 1, figsize=cm_size(figure_width, 1.2*figure_height))

    ax1.vlines(times[times < window], ymin=0, ymax=1, color="dodgerblue", lw=1.5)
    ax1.set_ylabel('Spikes')
    ax1.set_xlim(0, window)

    ax2.plot(time, rate, label="instantaneous rate, trial 1")
    ax2.set_ylabel('Firing rate', 'Hz')
    ax2.legend()
    # Same x-range on every panel so the traces line up with the spikes.
    ax2.set_xlim(0, window)

    ax3.fill_between(time, avg_rate+rate_std, avg_rate-rate_std, color="dodgerblue",
                     alpha=0.5, label="standard deviation")
    ax3.plot(time, avg_rate, label="average rate")
    ax3.set_xlabel('Time', 's')
    ax3.set_ylabel('Firing rate', 'Hz')
    ax3.legend()
    ax3.set_xlim(0, window)
    ax3.set_ylim(0, 450)

    fig.savefig("isimethod.pdf")
    plt.close()


def get_binned_rate(times, bin_width=0.05, max_t=30., dt=1e-4):
    """Estimate the firing rate from spike counts in fixed-width bins.

    Parameters
    ----------
    times : array_like
        Spike times, in the same unit as ``bin_width``, ``max_t`` and ``dt``.
    bin_width : float
        Width of each counting bin.
    max_t : float
        End (exclusive) of the returned time axis and of the bin edges.
    dt : float
        Step of the returned time axis.

    Returns
    -------
    time : ndarray
        ``np.arange(0., max_t, dt)``.
    rate : ndarray
        Piecewise-constant rate (count divided by ``bin_width``) sampled on
        ``time``. Samples past the last bin edge stay zero.
    """
    time = np.arange(0., max_t, dt)
    bins = np.arange(0., max_t, bin_width)
    # Round before casting: a quotient such as 1499.9999999 must map to
    # sample 1500, not be truncated to 1499.
    bin_indices = np.round(bins / dt).astype(int)
    hist, _ = np.histogram(times, bins)
    rate = np.zeros(time.shape)

    # hist[k] counts the spikes between edges k and k+1, so it fills exactly
    # the samples between those two edges.
    for start, stop, count in zip(bin_indices[:-1], bin_indices[1:], hist):
        rate[start:stop] = count / bin_width
    return time, rate


def plot_bin_rate(spike_times, bin_width, max_t=30, dt=1e-4):
    """Plot the binned rate estimate and save ``binmethod.pdf``.

    The figure has three stacked panels: the spikes of the first trial, the
    binned rate of the first trial, and the across-trial mean rate with a
    band of plus/minus one standard deviation.

    Parameters
    ----------
    spike_times : sequence
        ``spike_times[i][0]`` holds the spike times of trial ``i``.
    bin_width : float
        Bin width forwarded to ``get_binned_rate``.
    max_t : float
        Length of the rate traces, forwarded to ``get_binned_rate``.
    dt : float
        Sampling step, forwarded to ``get_binned_rate``.

    Notes
    -----
    ``cm_size``, ``figure_width`` and ``figure_height`` are not defined in
    this file; presumably they come from the ``plotstyle`` star import. The
    two-argument label calls (name, unit) presumably rely on that module as
    well - TODO confirm.
    """
    # Spikes of the first trial, used for the raster panel.
    times = np.squeeze(spike_times[0][0])

    # One rate trace per trial, all sampled on the same time axis.
    trial_rates = []
    for trial in spike_times:
        time, trial_rate = get_binned_rate(np.squeeze(trial[0]), bin_width, max_t, dt)
        trial_rates.append(trial_rate)
    rates = np.column_stack(trial_rates)
    rate = rates[:, 0]
    avg_rate = np.mean(rates, axis=1)
    rate_std = np.std(rates, axis=1)

    fig, (ax1, ax2, ax3) = plt.subplots(3, 1, figsize=cm_size(figure_width, 1.2*figure_height))

    ax1.vlines(times[times < (100000*dt)], ymin=0, ymax=1, color="dodgerblue", lw=1.5)
    ax1.set_ylabel('Spikes')
    ax1.set_xlim(0, 5)

    ax2.plot(time, rate, label="binned rate, trial 1")
    ax2.set_ylabel('Firing rate', 'Hz')
    ax2.legend()
    ax2.set_xlim(0, 5)

    ax3.fill_between(time, avg_rate+rate_std, avg_rate-rate_std, color="dodgerblue",
                     alpha=0.5, label="standard deviation")
    ax3.plot(time, avg_rate, label="average rate")
    ax3.set_xlabel('Time', 's')
    ax3.set_ylabel('Firing rate', 'Hz')
    ax3.legend()
    ax3.set_xlim(0, 5)
    ax3.set_ylim(0, 450)

    fig.savefig("binmethod.pdf")
    plt.close()
  

def get_convolved_rate(times, sigma, max_t=30., dt=1.e-4):
    """Estimate the firing rate by convolving spikes with a Gaussian kernel.

    Parameters
    ----------
    times : array_like
        Spike times, in the same unit as ``sigma``, ``max_t`` and ``dt``.
        Spikes that are negative or fall at/after the end of the time axis
        are ignored.
    sigma : float
        Standard deviation of the Gaussian kernel.
    max_t : float
        End (exclusive) of the returned time axis.
    dt : float
        Step of the returned time axis.

    Returns
    -------
    time : ndarray
        ``np.arange(0., max_t, dt)``.
    conv_rate : ndarray
        Sum of the kernel densities of all spikes, sampled on ``time``; it
        always has the same length as ``time``.
    """
    time = np.arange(0., max_t, dt)

    # Odd-length kernel spanning +/- 8 sigma so its peak sits exactly on the
    # centre sample and the convolution adds no one-sample delay.
    half_width = int(round(8 * sigma / dt))
    kernel = spst.norm.pdf(np.arange(-half_width, half_width + 1) * dt,
                           loc=0, scale=sigma)

    spikes = np.asarray(times, dtype=float).ravel()
    indices = np.asarray(spikes / dt, dtype=int)
    inside = (spikes >= 0) & (indices < len(time))

    # np.add.at accumulates, so several spikes in one sample all count.
    spike_train = np.zeros(time.shape)
    np.add.at(spike_train, indices[inside], 1.)

    # Cut the centred part out of the full convolution; unlike mode="same"
    # this keeps len(time) samples even when the kernel is the longer input.
    full = np.convolve(spike_train, kernel, mode="full")
    conv_rate = full[half_width:half_width + len(time)]
    return time, conv_rate


def plot_conv_rate(spike_times, sigma=0.05, max_t=30, dt=1e-4):
    """Plot the kernel-convolved rate estimate and save ``convmethod.pdf``.

    The figure has three stacked panels: the spikes of the first trial, the
    convolved rate of the first trial, and the across-trial mean rate with
    a band of plus/minus one standard deviation.

    Parameters
    ----------
    spike_times : sequence
        ``spike_times[i][0]`` holds the spike times of trial ``i``.
    sigma : float
        Kernel standard deviation forwarded to ``get_convolved_rate``.
    max_t : float
        Length of the rate traces, forwarded to ``get_convolved_rate``.
    dt : float
        Sampling step, forwarded to ``get_convolved_rate``.

    Notes
    -----
    ``cm_size``, ``figure_width`` and ``figure_height`` are not defined in
    this file; presumably they come from the ``plotstyle`` star import. The
    two-argument label calls (name, unit) presumably rely on that module as
    well - TODO confirm.
    """
    # Spikes of the first trial, used for the raster panel.
    times = np.squeeze(spike_times[0][0])

    # One rate trace per trial, all sampled on the same time axis.
    trial_rates = []
    for trial in spike_times:
        time, trial_rate = get_convolved_rate(np.squeeze(trial[0]), sigma, max_t, dt)
        trial_rates.append(trial_rate)
    rates = np.column_stack(trial_rates)
    rate = rates[:, 0]
    avg_rate = np.mean(rates, axis=1)
    rate_std = np.std(rates, axis=1)

    fig, (ax1, ax2, ax3) = plt.subplots(3, 1, figsize=cm_size(figure_width, 1.2*figure_height))

    ax1.vlines(times[times < (100000*dt)], ymin=0, ymax=1, color="dodgerblue", lw=1.5)
    ax1.set_ylabel('Spikes')
    ax1.set_xlim(0, 5)

    ax2.plot(time, rate, label="convolved rate, trial 1")
    ax2.set_ylabel('Firing rate', 'Hz')
    ax2.legend()
    ax2.set_xlim(0, 5)

    ax3.fill_between(time, avg_rate+rate_std, avg_rate-rate_std, color="dodgerblue",
                     alpha=0.5, label="standard deviation")
    ax3.plot(time, avg_rate, label="average rate")
    ax3.set_xlabel('Time', 's')
    ax3.set_ylabel('Firing rate', 'Hz')
    ax3.legend()
    ax3.set_xlim(0, 5)
    ax3.set_ylim(0, 450)

    fig.savefig("convmethod.pdf")
    plt.close()


def plot_comparison(spike_times, bin_width, sigma, max_t=30., dt=1e-4):
    """Draw the three rate estimates below a spike raster; save ``firingrates.pdf``.

    Only the first trial (``spike_times[0][0]``) is shown. The panels are,
    top to bottom: spikes, instantaneous rate, binned rate, convolved rate.

    Parameters
    ----------
    spike_times : sequence
        ``spike_times[0][0]`` holds the spike times of the plotted trial.
    bin_width : float
        Bin width of the binned estimate. The Gaussian kernel of the
        convolved estimate uses ``bin_width / sqrt(12)`` as its standard
        deviation.
    sigma : float
        Not used by this function.
    max_t : float
        Not used by this function; the rate helpers run with their defaults.
    dt : float
        Only used to select the spikes drawn in the raster panel
        (``times < 100000 * dt``).

    Notes
    -----
    ``cm_size``, ``figure_width``, ``figure_height``, ``adjust_fs``,
    ``colors``, ``axis_label`` and the ``show_spines`` axes method are not
    defined in this file; presumably they come from the ``plotstyle`` star
    import - TODO confirm.
    """
    x_ticks = [2.5, 2.75, 3.0, 3.25, 3.5]
    label_x, label_y = -0.105, 0.5

    first_trial = np.squeeze(spike_times[0][0])
    _, conv_rate = get_convolved_rate(first_trial, bin_width/np.sqrt(12.))
    _, inst_rate = get_instantaneous_rate(first_trial)
    time, binn_rate = get_binned_rate(first_trial, bin_width)

    fig, axes = plt.subplots(4, 1, figsize=cm_size(figure_width, 1.8*figure_height))
    raster_ax, isi_ax, bin_ax, conv_ax = axes
    fig.subplots_adjust(**adjust_fs(fig, left=6.0, right=1.5, bottom=3.0, top=1.0))

    # Raster panel. The y-label is placed by hand with text() rather than
    # through set_ylabel().
    raster_ax.show_spines('b')
    shown_spikes = first_trial[first_trial < (100000*dt)]
    raster_ax.vlines(shown_spikes, ymin=0, ymax=1, color=colors['blue'], lw=1.5)
    raster_ax.text(label_x, label_y, 'Spikes', transform=raster_ax.transAxes,
                   rotation='vertical', va='center')
    raster_ax.set_xlim(2.5, 3.5)
    raster_ax.set_ylim(-0.2, 1)
    raster_ax.set_xticks(x_ticks)
    raster_ax.set_xticklabels([])

    # The three rate panels share everything except trace and caption; only
    # the bottom one carries the x-axis label and tick labels.
    rate_panels = (
        (isi_ax, inst_rate, 'instantaneous rate'),
        (bin_ax, binn_rate, 'binned rate'),
        (conv_ax, conv_rate, 'convolved rate'),
    )
    for panel, trace, caption in rate_panels:
        is_bottom = panel is conv_ax
        panel.plot(time, trace, color=colors['orange'], lw=2)
        panel.text(1.0, 1.0, caption, transform=panel.transAxes, ha='right')
        if is_bottom:
            panel.set_xlabel('Time', 's')
        panel.text(label_x, label_y, axis_label('Rate', 'Hz'),
                   transform=panel.transAxes, rotation='vertical', va='center')
        panel.set_xlim(2.5, 3.5)
        panel.set_ylim(0, 300)
        panel.set_yticks(np.arange(0, 400, 100))
        panel.set_xticks(x_ticks)
        if not is_bottom:
            panel.set_xticklabels([])

    fig.savefig("firingrates.pdf")
    plt.close()


if __name__ == "__main__":
    # Read the "spikes" variable from the MATLAB file and hand it unchanged
    # to the plotting routines.
    mat_contents = spio.loadmat('lifoustim.mat')
    spike_times = mat_contents["spikes"]

    # Only the comparison figure is produced; the single-method figures are
    # switched off and can be re-enabled by uncommenting the calls below.
    #plot_isi_rate(spike_times)
    #plot_bin_rate(spike_times, 0.05)
    #plot_conv_rate(spike_times, 0.025)
    plot_comparison(spike_times, bin_width=0.05, sigma=0.025)