import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from plotstyle import *


# Figure size shared by all three figures.  cm_size, figure_width and
# figure_height are not defined in this file; presumably they come from the
# star-imported plotstyle module (likely a cm -> inches conversion -- TODO confirm).
fig_size = cm_size(figure_width, figure_height)


def create_spikes(nspikes=11, duration=0.5, seed=1000):
    """Return a reproducible, irregular spike train.

    Interspike intervals are the inverse of a two-bump rate profile sampled
    at ``nspikes`` evenly spaced points of [0, 1].  The profile has a broad
    bump centred at 0.35 and a narrow burst centred at 0.9.  Gaussian noise
    (sd 0.2) is added to the intervals before they are accumulated.  The
    cumulative times are scaled so that the last spike lands at
    ``1.05*duration``, which is just past the window, and then shifted
    by 0.01.

    Parameters
    ----------
    nspikes : int
        Number of spikes to generate.
    duration : float
        Nominal window length; the final spike is placed at 1.05*duration + 0.01.
    seed : int
        Seed for ``np.random.RandomState`` so the figure is reproducible.

    Returns
    -------
    numpy.ndarray
        Spike times of length ``nspikes``.
    """
    generator = np.random.RandomState(seed)
    position = np.linspace(0.0, 1.0, nspikes)
    broad_bump = np.exp(-0.5*((position - 0.35)/0.25)**2.0)
    narrow_burst = np.exp(-0.5*((position - 0.9)/0.05)**2.0)
    intervals = 1.0/(broad_bump + narrow_burst) + generator.randn(nspikes)*0.2
    spikes = np.cumsum(intervals)
    # Stretch so the final spike sits at 1.05*duration, then add a small onset offset.
    return spikes*(1.05*duration/spikes[-1]) + 0.01


def gaussian(sigma, dt):
    """Sample a unit-area Gaussian kernel.

    The kernel is evaluated on the half-open grid ``[-4*sigma, 4*sigma)`` with
    step ``dt``.  ``np.arange`` excludes the stop value, so the grid is
    generally not symmetric about zero.

    Parameters
    ----------
    sigma : float
        Standard deviation of the Gaussian.
    dt : float
        Sampling step of the grid.

    Returns
    -------
    (numpy.ndarray, numpy.ndarray)
        Sample positions and the normal density values at those positions.
    """
    support = np.arange(-4*sigma, 4*sigma, dt)
    density = np.exp(-0.5 * (support / sigma)**2)
    density = density/np.sqrt(2*np.pi)/sigma
    return support, density

    
def setup_axis(spikes_ax, rate_ax):
    """Apply the layout shared by all rate-estimation figures.

    The upper axes shows the spike raster: no spines, no y ticks, no x tick
    labels, and a vertical 'Spikes' label.  The lower axes shows the rate
    estimate: left and bottom spines only, with time on x (0-500) and rate on
    y (0-60, ticks every 20).

    ``show_spines``, the two-argument ``set_xlabel(label, unit)`` and
    ``axis_label`` are not standard matplotlib; they are presumably provided
    by the star-imported plotstyle module (TODO confirm).

    Parameters
    ----------
    spikes_ax : axes
        Axes for the spike raster (x in ms, judging by the callers).
    rate_ax : axes
        Axes for the firing-rate trace (x in ms, y in Hz).
    """
    # Spike raster panel.
    spikes_ax.show_spines('')
    spikes_ax.set_yticks([])
    spikes_ax.set_ylim(-0.2, 1.0)
    spikes_ax.text(-0.11, 0.5, 'Spikes', transform=spikes_ax.transAxes,
                   rotation='vertical', va='center')
    spikes_ax.set_xlim(-1, 500)
    spikes_ax.set_xticklabels([])

    # Rate panel.  The spines belong on rate_ax; applying 'lb' to spikes_ax
    # would undo the show_spines('') above.
    rate_ax.show_spines('lb')
    rate_ax.set_xlabel('Time', 'ms')
    rate_ax.text(-0.11, 0.5, axis_label('Rate', 'Hz'), transform=rate_ax.transAxes,
                 rotation='vertical', va='center')
    rate_ax.set_xlim(0, 500)
    rate_ax.set_ylim(0, 60)
    rate_ax.set_yticks(np.arange(0, 65, 20))


def plot_bin_method():
    """Plot the spike train with its binned (histogram) firing-rate estimate.

    Spikes from ``create_spikes()`` are counted in 50 ms bins covering
    0-500 ms.  Spikes past the last edge are not counted.  Each bin's count
    is written above the raster, the bin edges are drawn as grey lines, and
    the rate (count / bin width, in Hz) is drawn as a step function.  The
    figure is saved to ``binmethod.pdf`` and then closed.
    """
    duration = 0.5    # s, analysed window
    bin_width = 0.05  # s

    spike_times = create_spikes()

    # Edges 0, 0.05, ..., 0.5 s.  The half-bin margin on the stop value keeps
    # the final edge from depending on float rounding inside np.arange.
    bins = np.arange(0.0, duration + 0.5*bin_width, bin_width)
    count, _ = np.histogram(spike_times, bins)
    rate = count / bin_width  # spikes per second

    fig = plt.figure(figsize=fig_size)
    spec = gridspec.GridSpec(nrows=2, ncols=1, height_ratios=[3, 4], hspace=0.2,
                             **adjust_fs(fig, left=5.5, right=1.5, top=1.5))
    spikes_ax = fig.add_subplot(spec[0, 0])
    rate_ax = fig.add_subplot(spec[1, 0])
    setup_axis(spikes_ax, rate_ax)

    for ti in 1000.0*spike_times:
        spikes_ax.plot([ti, ti], [0., 1.], color=colors['blue'], lw=2)

    # Bin edges, drawn below the raster's y-limits (clip_on=False).
    for tb in 1000.0*bins:
        spikes_ax.plot([tb, tb], [-2.0, 0.75], '-', color="#777777", lw=1, clip_on=False)
    # Spike count of each bin, centred above the bin.
    for left_edge, c in zip(bins[:-1], count):
        spikes_ax.text(1000.0*(left_edge + 0.5*bin_width), 1.1, str(c), color=colors['red'],
                       ha='center')

    # Repeat the last value so where='post' also draws the final bin.
    rate_ax.step(1000.0*bins, np.append(rate, rate[-1]), color=colors['orange'], lw=2,
                 where='post')
    fig.savefig("binmethod.pdf")
    plt.close(fig)


def plot_conv_method():
    """Plot the spike train with its kernel (Gaussian convolution) rate estimate.

    Each spike is placed on a time grid with step ``dt``.  The resulting
    train is convolved with a unit-area Gaussian (sigma = 15 ms) from
    ``gaussian()``, so the trace is in spikes per second.  The raster shows
    each spike with its kernel scaled to unit height.  The figure is saved to
    ``convmethod.pdf`` and then closed.
    """
    dt = 1e-5        # s, sampling step of the rate trace
    duration = 0.5   # s, analysed window
    sigma = 0.015    # s, kernel standard deviation

    spike_times = create_spikes()
    kernel_time, kernel = gaussian(sigma, dt)

    t = np.arange(0., duration, dt)
    # Spike train on the grid.  Spikes outside [0, duration) are dropped:
    # create_spikes() puts its last spike past the window, and indexing it
    # would be out of range.  np.add.at counts spikes sharing a sample
    # instead of overwriting them.
    idx = np.asarray(np.round(spike_times/dt), dtype=int)
    idx = idx[(idx >= 0) & (idx < t.size)]
    rate = np.zeros(t.shape)
    np.add.at(rate, idx, 1.0)
    rate = np.convolve(rate, kernel, mode='same')
    # The kernel from gaussian() has even length here, and its peak lies one
    # sample past the index mode='same' aligns with each spike.  This roll
    # presumably compensates for that one-sample offset.
    rate = np.roll(rate, -1)

    fig = plt.figure(figsize=fig_size)
    spec = gridspec.GridSpec(nrows=2, ncols=1, height_ratios=[3, 4], hspace=0.2,
                             **adjust_fs(fig, left=5.5, right=1.5, top=1.5))
    spikes_ax = fig.add_subplot(spec[0, 0])
    rate_ax = fig.add_subplot(spec[1, 0])
    setup_axis(spikes_ax, rate_ax)

    # Loop-invariant kernel drawing data: time axis in ms, height scaled to 1.
    kernel_ms = 1000*kernel_time
    kernel_shape = kernel/np.max(kernel)
    for ti in 1000.0*spike_times:
        spikes_ax.plot([ti, ti], [0., 1.], color=colors['blue'], lw=2)
        spikes_ax.plot(kernel_ms + ti, kernel_shape, color=colors['red'],
                       lw=1, zorder=1)

    rate_ax.plot(1000.0*t, rate, color=colors['orange'], lw=2, zorder=2)
    rate_ax.fill_between(1000.0*t, rate, 0.0, color=colors['yellow'])

    fig.savefig("convmethod.pdf")
    plt.close(fig)


def plot_isi_method():
    """Plot the spike train with its instantaneous (inverse-ISI) firing rate.

    A reference event at 5 ms is prepended to the spikes from
    ``create_spikes()``, and it is drawn like a spike.  This means the
    interval before the first real spike is shown as well.  Each interval is
    marked with a red double arrow and labelled with its length in ms.  The
    rate panel shows 1/ISI (Hz), held constant from the start of each
    interval.  The figure is saved to ``isimethod.pdf`` and then closed.
    """
    spike_times = create_spikes()

    fig = plt.figure(figsize=fig_size)
    spec = gridspec.GridSpec(nrows=2, ncols=1, height_ratios=[3, 4], hspace=0.2,
                             **adjust_fs(fig, left=5.5, right=1.5, top=1.5))
    spikes_ax = fig.add_subplot(spec[0, 0])
    rate_ax = fig.add_subplot(spec[1, 0])
    setup_axis(spikes_ax, rate_ax)

    spike_times = np.hstack((0.005, spike_times))
    times_ms = 1000*spike_times
    for t_start, t_end in zip(times_ms[:-1], times_ms[1:]):
        spikes_ax.plot([t_start, t_start], [0., 1.], color=colors['blue'], lw=2)
        # The arrow colour must go in arrowprops: annotate()'s own keyword
        # arguments only style the (empty) annotation text.
        spikes_ax.annotate('', xy=(t_start, 0.5), xytext=(t_end, 0.5),
                           arrowprops=dict(arrowstyle='<->', color=colors['red']))
        spikes_ax.text(0.5*(t_start + t_end), 1.05,
                       "{0:.0f}".format(t_end - t_start),
                       color=colors['red'], ha='center')

    i_rate = 1./np.diff(spike_times)
    # Repeat the last value so where='post' also draws the final interval.
    rate_ax.step(times_ms, np.append(i_rate, i_rate[-1]), color=colors['orange'], lw=2,
                 where="post")

    fig.savefig("isimethod.pdf")
    plt.close(fig)
    
    
if __name__ == '__main__':
    # Render the ISI, convolution and binning figures in turn.
    for make_figure in (plot_isi_method, plot_conv_method, plot_bin_method):
        make_figure()