import numpy as np
import matplotlib.pyplot as plt
from plotstyle import *

def hompoisson(rate, trials, duration):
    """Generate spike trains of a homogeneous Poisson process.

    Inter-spike intervals are drawn from an exponential distribution with
    mean 1/rate, using NumPy's global random number generator.

    Parameters
    ----------
    rate: float
        Firing rate of the Poisson process, in spikes per time unit of
        `duration` (e.g. Hz for a duration in seconds). Must be positive.
    trials: int
        Number of spike trains to generate.
    duration: float
        Duration of each spike train.

    Returns
    -------
    spikes: list of list of float
        For each trial the sorted spike times, all within [0, duration).
    """
    scale = 1.0/rate   # mean inter-spike interval
    spikes = []
    for _ in range(trials):
        times = []
        t = 0.0
        while t < duration:
            t += np.random.exponential(scale)
            # The last draw overshoots the end of the trial; it is still
            # drawn (same number of random numbers as before) but discarded.
            if t < duration:
                times.append(t)
        spikes.append(times)
    return spikes

def inhompoisson(rate, trials, dt):
    """Generate spike trains of an inhomogeneous Poisson process.

    Time is discretized into bins of width `dt`. A spike occurs in a bin
    whenever a uniform random number from NumPy's global generator falls
    below rate*dt for that bin. Bins with rate*dt >= 1 therefore always
    contain a spike.

    Parameters
    ----------
    rate: 1-D array of float
        Firing rate for each time bin.
    trials: int
        Number of spike trains to generate.
    dt: float
        Width of a time bin.

    Returns
    -------
    spikes: list of ndarray
        For each trial the spike times, given as the start times of the
        bins that contain a spike.
    """
    spike_prob = rate*dt   # spike probability per bin
    return [dt*np.nonzero(np.random.rand(len(rate)) < spike_prob)[0]
            for _ in range(trials)]


def pifspikes(input, trials, dt, D=0.1, vreset=0.0, vthresh=1.0, tau=1.0):
    """Simulate spike trains of a noisy perfect integrate-and-fire neuron.

    The voltage obeys tau dv/dt = input(t) + sqrt(2D) xi(t), with Gaussian
    white noise xi drawn from NumPy's global generator. It is integrated
    with the Euler-Maruyama scheme. Whenever v reaches `vthresh`, a spike
    is recorded and v is reset to `vreset`.

    Parameters
    ----------
    input: 1-D array of float
        Input drive for each time step.
    trials: int
        Number of spike trains to simulate (fresh noise for each trial).
    dt: float
        Integration time step.
    D: float
        Intensity of the white noise.
    vreset: float
        Reset voltage, also the initial voltage of each trial.
    vthresh: float
        Firing threshold.
    tau: float
        Membrane time constant.

    Returns
    -------
    spikes: list of list of float
        For each trial the spike times, given as i*dt for the step i in
        which the threshold was reached.
    """
    spikes = []
    for _ in range(trials):
        times = []
        v = vreset
        # Scaling by 1/sqrt(dt) turns the multiplication by dt below into
        # the sqrt(dt) scaling of the Euler-Maruyama noise increment.
        noise = np.sqrt(2.0*D)*np.random.randn(len(input))/np.sqrt(dt)
        for i, (drive, xi) in enumerate(zip(input, noise)):
            v += (drive + xi)*dt/tau
            if v >= vthresh:
                v = vreset
                times.append(i*dt)
        spikes.append(times)
    return spikes

def isis(spikes):
    """Pool the inter-spike intervals of several spike trains.

    Parameters
    ----------
    spikes: list of array-like of float
        Sorted spike times of each trial.

    Returns
    -------
    isi: list of float
        Intervals between consecutive spikes within each trial, trial after
        trial. Trials with fewer than two spikes contribute nothing.
    """
    return [interval for train in spikes for interval in np.diff(train)]

def plotisih(ax, isis, binwidth=None):
    """Plot a normalized ISI histogram annotated with rate, mean ISI and CV.

    Parameters
    ----------
    ax: matplotlib axes
        Axes to draw into. NOTE(review): `set_xlabel`/`set_ylabel` are called
        with a label and a unit string, which presumably relies on the axes
        extensions from the `plotstyle` module -- confirm.
    isis: array-like of float
        Inter-spike intervals in seconds (shown in ms on the x-axis).
    binwidth: float or None
        Bin width in seconds. If None, it is chosen so that on average 200
        intervals fall into each bin, but it is never narrower than 0.5 ms.

    Raises
    ------
    ValueError
        If `isis` is empty.
    """
    isis = np.asarray(isis)
    if len(isis) == 0:
        raise ValueError('no inter-spike intervals to plot')
    if binwidth is None:
        nperbin = 200.0           # average number of ISIs per bin
        nbins = len(isis)/nperbin
        binwidth = np.max(isis)/nbins
        if binwidth < 5e-4:       # half a millisecond
            binwidth = 5e-4
    mean_isi = np.mean(isis)
    # density=True normalizes the histogram to a probability density in 1/s:
    h, b = np.histogram(isis, np.arange(0.0, np.max(isis)+binwidth, binwidth), density=True)
    ax.text(0.9, 0.85, 'rate={:.0f}Hz'.format(1.0/mean_isi), ha='right', transform=ax.transAxes)
    ax.text(0.9, 0.75, 'mean={:.0f}ms'.format(1000.0*mean_isi), ha='right', transform=ax.transAxes)
    ax.text(0.9, 0.65, 'CV={:.2f}'.format(np.std(isis)/mean_isi), ha='right', transform=ax.transAxes)
    ax.set_xlabel('ISI', 'ms')
    ax.set_ylabel('p(ISI)', '1/s')
    # bin edges are converted from seconds to milliseconds for plotting:
    ax.bar( 1000.0*b[:-1], h, bar_fac*1000.0*np.diff(b), facecolor=colors['blue'])

# parameter:
rate = 20.0       # mean firing rate in Hz
drate = 50.0      # scales the fluctuations of the OU input
trials = 10       # number of spike trains per condition
duration = 100.0  # duration of each spike train in seconds
dt = 0.001        # integration time step in seconds
tau = 0.1         # correlation time of the OU input in seconds

# hompoisson() and pifspikes() draw from NumPy's global generator. Seed it
# so that the spike trains, like the OU input below, and hence the figure
# are reproducible from run to run:
np.random.seed(71927)

# homogeneous spike trains:
homspikes = hompoisson(rate, trials, duration)

# OU noise: x relaxes with time constant tau toward a noisy target n whose
# mean is `rate` (Euler scheme). For dt << tau the standard deviation of x
# is approximately drate*sqrt(tau/2).
rng = np.random.RandomState(54637281)
time = np.arange(0.0, duration, dt)
x = np.zeros(time.shape)+rate
n = rng.randn(len(time))*drate*tau/np.sqrt(dt)+rate
for k in range(1, len(x)):
    x[k] = x[k-1] + (n[k]-x[k-1])*dt/tau
x[x<0.0] = 0.0   # clip negative input values

# pif spike trains: pifspikes() integrates from 0 to a threshold of 1 with a
# time constant of 1, so its mean firing rate roughly equals the mean of x,
# i.e. about `rate`:
inhspikes = pifspikes(x, trials, dt, D=0.3)

fig, (ax1, ax2) = plt.subplots(1, 2)
fig.subplots_adjust(**adjust_fs(fig, top=1.5))
for ax, title, spikes in ((ax1, 'stationary', homspikes),
                          (ax2, 'non-stationary', inhspikes)):
    ax.set_title(title)
    ax.set_xlim(0.0, 200.0)   # ISI in ms
    ax.set_ylim(0.0, 40.0)    # ISI density in 1/s
    plotisih(ax, isis(spikes))

plt.savefig('isihexamples.pdf')
plt.close()