import matplotlib.pyplot as plt
import numpy as np
from IPython import embed

def set_rc():
    """Apply the tick styling shared by every figure in this script.

    Tick labels are 8 pt; major and minor ticks on both axes are 5 pt long,
    2 pt wide and point outward. Modifies the global ``plt.rcParams``.
    """
    plt.rcParams.update({
        'xtick.labelsize': 8,
        'ytick.labelsize': 8,
        'xtick.major.size': 5,
        'xtick.minor.size': 5,
        'xtick.major.width': 2,
        'xtick.minor.width': 2,
        'ytick.major.size': 5,
        'ytick.minor.size': 5,
        'ytick.major.width': 2,
        'ytick.minor.width': 2,
        'xtick.direction': "out",
        'ytick.direction': "out",
    })


def create_spikes(isi=0.08, duration=0.5):
    """Return sorted spike times of a jittered, roughly regular spike train.

    A regular grid ``0, isi, 2*isi, ...`` (below ``duration``) is displaced
    by Gaussian noise with standard deviation ``isi / 2.5``. Times that fall
    outside ``[0, duration]`` are discarded and the rest are sorted.
    Units are presumably seconds (the plots label the axis 'time[s]').
    Draws from the global ``np.random`` state.
    """
    grid = np.arange(0., duration, isi)
    jitter = np.random.randn(len(grid)) * (isi / 2.5)
    candidates = grid + jitter
    # Keep only times inside the closed window [0, duration].
    inside = (candidates >= 0) & (candidates <= duration)
    return np.sort(candidates[inside])


def gaussian(sigma, dt):
    """Sample a unit-area Gaussian kernel with standard deviation ``sigma``.

    Returns ``(x, y)``: ``x`` runs from ``-4*sigma`` in steps of ``dt`` up to
    (but excluding) ``4*sigma``, and ``y`` is the normal density
    ``exp(-x**2 / (2*sigma**2)) / (sqrt(2*pi) * sigma)`` evaluated at ``x``.
    """
    support = np.arange(-4*sigma, 4*sigma, dt)
    exponent = -0.5 * (support / sigma)**2
    density = np.exp(exponent) / np.sqrt(2*np.pi) / sigma
    return support, density

    
def setup_axis(spikes_ax, rate_ax):
    """Style the two stacked panels used by every figure.

    Both axes lose their top and right spines and keep ticks only on the
    left and bottom. The upper (spike) panel gets y-ticks at 0 and 1, a
    y-range of [0, 1.05], the label 'spikes' and the panel letter 'A'. The
    lower (rate) panel gets time / firing-rate axis labels and the letter 'B'.
    """
    for ax in (spikes_ax, rate_ax):
        for side in ("right", "top"):
            ax.spines[side].set_visible(False)
        ax.yaxis.set_ticks_position('left')
        ax.xaxis.set_ticks_position('bottom')

    spikes_ax.set_yticks([0, 1.0])
    spikes_ax.set_ylim([0, 1.05])
    spikes_ax.set_ylabel("spikes", fontsize=9)
    # Panel letters sit outside the axes, in axes-fraction coordinates.
    spikes_ax.text(-0.125, 1.2, "A", transform=spikes_ax.transAxes, size=10)

    rate_ax.set_xlabel('time[s]', fontsize=9)
    rate_ax.set_ylabel('firing rate [Hz]', fontsize=9)
    rate_ax.text(-0.125, 1.15, "B", transform=rate_ax.transAxes, size=10)


def plot_bin_method():
    """Illustrate firing-rate estimation by binning spike counts.

    Panel A shows a jittered spike train with dashed red bin edges and the
    spike count of each bin printed above it; panel B shows the resulting
    piecewise-constant rate (count divided by bin width). The figure is
    written to ``../lecture/images/bin_method.pdf``.
    """
    duration = 0.5
    bin_width = 0.05

    spike_times = create_spikes(0.018, duration)

    # Half a bin of slack on the stop value guarantees the last edge equals
    # `duration` despite floating-point rounding inside np.arange.
    bins = np.arange(0., duration + bin_width / 2, bin_width)
    count, _ = np.histogram(spike_times, bins)
    rate = count / bin_width

    plt.xkcd()
    set_rc()
    fig = plt.figure()
    fig.set_size_inches(5., 2.5)
    fig.set_facecolor('white')
    spikes = plt.subplot2grid((7,1), (0,0), rowspan=3, colspan=1)
    rate_ax = plt.subplot2grid((7,1), (3,0), rowspan=4, colspan=1)
    setup_axis(spikes, rate_ax)
    spikes.set_ylim([0., 1.25])

    spikes.vlines(spike_times, 0., 1., color="darkblue", lw=1.25)
    # `bins` already starts at 0; drawing that edge a second time would show
    # up as a doubled line under the randomised xkcd sketch style.
    spikes.vlines(bins, 0., 1.25, color="red", lw=1.5, linestyles='--')
    for left_edge, c in zip(bins[:-1], count):
        spikes.text(left_edge + bin_width / 2, 1.05, str(c), ha='center',
                    fontdict={'color':'red'})
    spikes.set_xlim([0, duration])

    # step(where='post') needs one value per edge, so repeat the last rate to
    # draw the final bin.
    rate_ax.step(bins, np.hstack((rate, rate[-1])), where='post')
    rate_ax.set_xlim([0., duration])
    rate_ax.set_ylim([0., 100.])
    rate_ax.set_yticks(np.arange(0,105,25))
    fig.tight_layout()
    fig.savefig("../lecture/images/bin_method.pdf")
    plt.close(fig)


def plot_conv_method():
    """Illustrate firing-rate estimation by kernel convolution.

    Panel A shows a jittered spike train with a Gaussian kernel
    (sigma = 0.02, peak scaled to 1) drawn under every spike. Panel B shows
    the rate estimate: the spike train convolved with the unit-area kernel
    from ``gaussian``. The figure is written to
    ``../lecture/images/conv_method.pdf``.
    """
    dt = 1e-5
    duration = 0.5
    spike_times = create_spikes(0.05, duration)
    kernel_time, kernel = gaussian(0.02, dt)

    t = np.arange(0., duration, dt)

    # Discretise the spike train on the grid `t`. A spike within dt/2 of
    # `duration` rounds to len(t), one past the end, so it is dropped instead
    # of raising IndexError. np.add.at counts coincident spikes rather than
    # collapsing them to a single 1.
    spike_idx = np.asarray(np.round(spike_times / dt), dtype=int)
    spike_idx = spike_idx[spike_idx < len(t)]
    spike_train = np.zeros(t.shape)
    np.add.at(spike_train, spike_idx, 1)

    # With a full convolution, rate[n] == full[n + center], where `center` is
    # the kernel sample at kernel_time == 0. Slicing there puts each kernel
    # peak exactly on its spike, whether the kernel length is odd or even.
    # Because the kernel integrates to 1, summing its samples over spikes
    # yields spikes per unit time.
    center = int(np.argmin(np.abs(kernel_time)))
    rate = np.convolve(spike_train, kernel, mode='full')[center:center + len(t)]

    plt.xkcd()
    set_rc()
    fig = plt.figure()
    fig.set_size_inches(5., 2.5)
    fig.set_facecolor('white')
    spikes = plt.subplot2grid((7,1), (0,0), rowspan=3, colspan=1)
    rate_ax = plt.subplot2grid((7,1), (3,0), rowspan=4, colspan=1)
    setup_axis(spikes, rate_ax)

    spikes.vlines(spike_times, 0., 1., color="darkblue", lw=1.5, zorder=2)
    # Kernel shape scaled to a peak of 1 so it fits the 0..1 spike panel.
    scaled_kernel = kernel / np.max(kernel)
    for spike_time in spike_times:
        spikes.plot(kernel_time + spike_time, scaled_kernel, color="orange", lw=0.75, zorder=1)
    spikes.set_xlim([0, duration])

    rate_ax.plot(t, rate, color="darkblue", lw=1, zorder=2)
    rate_ax.fill_between(t, rate, 0., color="red", alpha=0.5)
    rate_ax.set_xlim([0, duration])
    rate_ax.set_ylim([0, 50])
    rate_ax.set_yticks(np.arange(0,75,25))
    fig.tight_layout()
    fig.savefig("../lecture/images/conv_method.pdf")
    plt.close(fig)


def plot_isi_method():
    """Illustrate the instantaneous-rate (inverse inter-spike-interval) method.

    Panel A shows a jittered spike train with red double-headed arrows
    spanning each interval; the first interval is measured from t = 0.
    Panel B shows the rate 1/interval, held constant over each interval.
    The figure is written to ``../lecture/images/isi_method.pdf``.
    """
    duration = 0.5
    spike_times = create_spikes(0.09, duration)

    plt.xkcd()
    set_rc()
    fig = plt.figure()
    fig.set_size_inches(5., 2.5)
    fig.set_facecolor('white')

    spikes = plt.subplot2grid((7,1), (0,0), rowspan=3, colspan=1)
    rate_ax = plt.subplot2grid((7,1), (3,0), rowspan=4, colspan=1)
    setup_axis(spikes, rate_ax)

    spikes.vlines(spike_times, 0., 1., color="darkblue", lw=1.25)
    # Prepend t = 0 so the first interval runs from the start of the window.
    edges = np.hstack((0, spike_times))
    for t_start, t_end in zip(edges[:-1], edges[1:]):
        # The annotation text is passed positionally because the `s=` keyword
        # was deprecated in Matplotlib 3.3 and later removed. The arrow colour
        # must go in arrowprops; a top-level color= only styles the (empty) text.
        spikes.annotate('', xy=(t_start, 0.5), xytext=(t_end, 0.5),
                        arrowprops=dict(arrowstyle='<->', color='red'))
    # Fix the x-range so both panels line up, as in the other figures.
    spikes.set_xlim([0, duration])

    i_rate = 1. / np.diff(edges)

    # step(where='post') needs one value per edge, so repeat the last rate.
    rate_ax.step(edges, np.hstack((i_rate, i_rate[-1])), color="darkblue", lw=1.25, where="post")
    rate_ax.set_xlim([0, duration])
    rate_ax.set_ylim([0, 75])
    rate_ax.set_yticks(np.arange(0,100,25))

    fig.tight_layout()
    fig.savefig("../lecture/images/isi_method.pdf")
    plt.close(fig)
    
    
if __name__ == '__main__':
    # Render the three lecture figures: inverse-ISI, convolution, binning.
    for make_figure in (plot_isi_method, plot_conv_method, plot_bin_method):
        make_figure()