import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from plotstyle import *


# Example event times for the sketch, in the x-axis data units used by the
# plot functions below (their x-range is fixed to 0..11).
spikes = [
    0.724649, 1.67586, 3.02389,
    3.57466, 4.15121, 5.00412,
    6.64549, 7.81657, 9.77964,
]


def plot_spikes(ax, spikes):
    """Sketch a point process as a sequence of event times.

    A horizontal arrow along y=0.5 marks the time axis, every event is drawn
    as a vertical tick from y=0 to y=1, and the ticks are labelled
    ``t_0, t_1, ...`` below the axis.

    Parameters
    ----------
    ax : matplotlib axes
        Axes to draw into. Must provide ``show_spines`` (presumably added by
        the ``plotstyle`` module -- confirm).
    spikes : sequence of float
        Event times in x data units; the x-range is fixed to 0..11.
    """
    ax.show_spines('')
    ax.text(-0.06, 1.1, r'Event times $\{t_i\}$', transform=ax.transAxes)
    # Arrow head sits at the right end (xytext) and points towards later times.
    time_arrow = dict(arrowstyle='<-')
    ax.annotate('', xy=(0.0, 0.5), xytext=(11.0, 0.5), zorder=-10,
                arrowprops=time_arrow)
    ax.text(10.5, -0.2, 'Time')
    for index, time in enumerate(spikes):
        ax.plot([time] * 2, [0.0, 1.0], **lsA)
        ax.text(time, -0.6, '$t_{%d}$' % index, ha='center')
    ax.set_xlim(0.0, 11.0)
    ax.set_ylim(-0.7, 1.2)
    

def plot_isis(ax, spikes):
    """Sketch the intervals between consecutive events of a point process.

    Every event is drawn as a vertical tick from y=0 to y=1 on a time arrow
    along y=0.5. Each pair of consecutive events is connected by a
    double-headed arrow at y=0.2, labelled below the axis.

    The labels follow the definition in the panel title,
    ``T_i = t_{i+1} - t_i``, with events numbered from ``t_0``. The interval
    from ``t_0`` to ``t_1`` is therefore ``T_0``, and ``n`` events yield
    ``T_0 ... T_{n-2}``.

    Parameters
    ----------
    ax : matplotlib axes
        Axes to draw into. Must provide ``show_spines`` (presumably added by
        the ``plotstyle`` module -- confirm).
    spikes : sequence of float
        Event times in x data units, expected in increasing order; the
        x-range is fixed to 0..11. Must support slicing (list, tuple, array).
    """
    ax.show_spines('')
    ax.text(-0.06, 1.1, r'Intervals $\{T_i\}, \; T_i = t_{i+1} - t_i$', transform=ax.transAxes)
    ax.annotate('', xy=(0.0, 0.5), xytext=(11.0, 0.5), zorder=-10,
                arrowprops=dict(arrowstyle='<-'))
    ax.text(10.5, -0.2, 'Time')
    for t in spikes:
        ax.plot([t, t], [0., 1.], **lsA)
    # Pair each event with its successor; index i matches T_i = t_{i+1} - t_i.
    for i, (t0, t1) in enumerate(zip(spikes[:-1], spikes[1:])):
        ax.annotate('', xy=(t0, 0.2), xytext=(t1, 0.2), arrowprops=dict(arrowstyle='<->'))
        ax.text(0.5*(t0+t1), -0.6, '$T_{%d}$' % i, ha='center')
    ax.set_xlim(0.0, 11.0)
    ax.set_ylim(-0.7, 1.2)

        
def plot_counts(ax, spikes):
    """Sketch the cumulative event count of a point process as a step function.

    The count starts at 0 at time 0 and increases by one at every event.
    Before each step, a horizontal segment is drawn at the current count.
    At each event two markers are placed: one at the old count (``psBo``)
    and one at the new count (``psB``). Which of them is open or filled is
    defined in ``plotstyle`` -- not visible here. After the last event, the
    final count is extended to x=10.5.

    Parameters
    ----------
    ax : matplotlib axes
        Axes to draw into. Must provide ``show_spines`` (presumably added by
        the ``plotstyle`` module -- confirm).
    spikes : iterable of float
        Event times in x data units, expected in increasing order. The
        y-range is fixed to 0..11, so more than ~10 events run off the top.
    """
    ax.show_spines('l')
    ax.text(-0.06, 1.1, r'Event counts $\{ n_i \}$', transform=ax.transAxes)
    ax.annotate('', xy=(0.0, 0.0), xytext=(11.0, 0.0), zorder=-10,
                arrowprops=dict(arrowstyle='<-'))
    ax.text(10.5, -2.0, 'Time')
    c = 0
    t0 = 0.0
    for t in spikes:
        # Plateau at the current count up to this event, then step up by one.
        ax.plot([t0, t], [c, c], zorder=10, clip_on=False, **lsBm)
        ax.plot([t], [c], zorder=20, clip_on=False, **psBo)
        c += 1
        ax.plot([t], [c], zorder=30, clip_on=False, **psB)
        t0 = t
    # Final plateau after the last event; unclipped like the other count artists.
    ax.plot([t0, 10.5], [c, c], zorder=10, clip_on=False, **lsBm)
    ax.set_xlim(0.0, 11.0)
    ax.set_ylim(0, 11)
    ax.set_yticks(np.arange(0.0, 11.0, 2.0))


if __name__ == "__main__":
    # Three stacked panels: event times, intervals, and (taller) counts.
    fig = plt.figure(figsize=cm_size(figure_width, 1.7*figure_height))
    grid = gridspec.GridSpec(3, 1, height_ratios=[2, 2, 4])
    margins = adjust_fs(fig, bottom=1.5, top=1.5, left=3.0, right=1.5)
    grid.update(hspace=0.6, **margins)
    panels = [fig.add_subplot(grid[row, 0]) for row in range(3)]
    for panel, draw in zip(panels, (plot_spikes, plot_isis, plot_counts)):
        draw(panel, spikes)
    fig.savefig('pointprocesssketch.pdf')
    plt.close(fig)