import numpy as np
import matplotlib.pyplot as plt
from plotstyle import *

    
# Simulation parameters (times in seconds -- the plots convert them to ms):
rate = 20.0       # mean rate of the Poisson process and mean of the OU input
trials = 10       # number of spike trains per model
duration = 500.0  # duration of each spike train
dt = 0.001        # integration time step of the PIF/OU simulations
drate = 50.0      # noise amplitude of the OU input
tau = 0.1         # time constant of the OU input


def hompoisson(rate, trials, duration) :
    """Generate homogeneous Poisson spike trains.

    Interspike intervals are drawn from an exponential distribution with
    mean ``1/rate`` (global numpy random state).  Only spikes strictly inside
    ``(0, duration)`` are kept; the original implementation appended one
    extra spike beyond ``duration`` to every trial.

    Parameters
    ----------
    rate : float
        Firing rate; must be positive.
    trials : int
        Number of independent spike trains.
    duration : float
        Length of each spike train.

    Returns
    -------
    list of list of float
        One increasing list of spike times per trial.
    """
    spikes = []
    for _ in range(trials) :
        times = []
        t = np.random.exponential(1.0/rate)
        # Check before appending so no spike beyond `duration` is recorded.
        # The number of random draws per trial is unchanged.
        while t < duration :
            times.append( t )
            t += np.random.exponential(1.0/rate)
        spikes.append( times )
    return spikes


def inhompoisson(rate, trials, dt) :
    """Draw inhomogeneous Poisson spike trains on a time grid.

    In every bin of width ``dt`` a spike occurs with probability
    ``rate*dt`` (uniform random numbers from the global numpy random
    state). Probabilities of one or more always produce a spike.

    Parameters
    ----------
    rate : ndarray
        Rate for each time bin (must support ``len`` and multiplication
        by ``dt``).
    trials : int
        Number of spike trains.
    dt : float
        Bin width.

    Returns
    -------
    list of ndarray
        For each trial the times ``dt*i`` of all bins ``i`` containing a spike.
    """
    prob = rate*dt
    return [dt*np.flatnonzero(np.random.rand(len(rate)) < prob)
            for _ in range(trials)]


def pifspikes(input, trials, dt, D=0.1, tau=1.0, vreset=0.0, vthresh=1.0) :
    """Simulate spike trains of a noisy perfect integrate-and-fire neuron.

    In each time step the membrane variable is advanced as
    ``v += (input[k] + noise[k])*dt/tau``, where ``noise`` is Gaussian white
    noise of intensity ``D`` (``sqrt(2*D)*randn/sqrt(dt)``, drawn from the
    global numpy random state). When ``v`` reaches ``vthresh``, the spike
    time ``k*dt`` is recorded and ``v`` is set back to ``vreset``.

    Parameters
    ----------
    input : array-like
        Input drive for every time step. The name shadows the builtin but is
        kept for keyword compatibility.
    trials : int
        Number of independent spike trains.
    dt : float
        Integration time step.
    D : float
        Noise intensity.
    tau, vreset, vthresh : float
        Membrane time constant, reset and threshold values. The defaults are
        the values that were previously hard-coded.

    Returns
    -------
    list of list of float
        Spike times for each trial.
    """
    drive = np.asarray(input, dtype=float)
    nsteps = len(drive)
    spikes = []
    for _ in range(trials) :
        noise = np.sqrt(2.0*D)*np.random.randn(nsteps)/np.sqrt(dt)
        # Plain Python floats make the step-by-step loop much faster than
        # indexing numpy scalars, with bit-identical arithmetic.
        total = (drive + noise).tolist()
        times = []
        v = vreset
        for k, x in enumerate(total) :
            v += x*dt/tau
            if v >= vthresh :
                v = vreset
                times.append(k*dt)
        spikes.append( times )
    return spikes


def oupifspikes(rate, trials, duration, dt, D, drate, tau, seed=54637281):
    """Simulate a noisy PIF neuron driven by a rectified Ornstein-Uhlenbeck input.

    The input ``x`` is integrated with the Euler scheme
    ``x[k] = x[k-1] + (n[k] - x[k-1])*dt/tau`` starting at ``x[0] = rate``.
    Here ``n[k] = xi[k]*drate*tau/sqrt(dt) + rate``, with standard normal
    ``xi`` drawn from a private ``RandomState(seed)``. Negative values of
    ``x`` are set to zero, and ``x`` is then passed to ``pifspikes``.

    Because the input uses its own seeded generator, the same input realization
    is produced on every call with the same arguments. The neuron's intrinsic
    noise is drawn inside ``pifspikes`` from the global numpy random state.

    Parameters
    ----------
    rate : float
        Mean the OU input relaxes to.
    trials : int
        Number of spike trains.
    duration : float
        Duration of the simulation; the time grid is ``np.arange(0, duration, dt)``.
    dt : float
        Integration time step.
    D : float
        Intrinsic noise intensity passed to ``pifspikes``.
    drate : float
        Noise amplitude of the OU input.
    tau : float
        Time constant of the OU input.
    seed : int, optional
        Seed of the input generator. The default is the previously
        hard-coded value.

    Returns
    -------
    list of list of float
        Spike times for each trial.
    """
    rng = np.random.RandomState(seed)
    nsteps = len(np.arange(0.0, duration, dt))
    targets = (rng.randn(nsteps)*drate*tau/np.sqrt(dt) + rate).tolist()
    # Run the recursion on Python floats (same arithmetic, much faster loop).
    ou = [float(rate)] if nsteps > 0 else []
    for target in targets[1:]:
        prev = ou[-1]
        ou.append(prev + (target - prev)*dt/tau)
    x = np.array(ou, dtype=float)
    x[x < 0.0] = 0.0  # rectify: the input is clipped at zero
    return pifspikes(x, trials, dt, D)


def isis( spikes ) :
    """Pool the interspike intervals of several spike trains.

    Parameters
    ----------
    spikes : sequence of sequences
        Spike times of each trial.

    Returns
    -------
    ndarray
        Concatenation of ``np.diff`` of every trial, in trial order.
    """
    intervals = [d for trial in spikes for d in np.diff(trial)]
    return np.array(intervals)


def plotreturnmap(ax, isis, lag=1, max=1.0) :
    """Scatter each interspike interval against the interval ``lag`` steps later.

    Intervals are paired first and only pairs where both intervals are
    shorter than ``max`` are plotted. Filtering before pairing (as done
    previously) made non-adjacent intervals neighbours and distorted the map.

    Parameters
    ----------
    ax : axes
        Target axes. Note: ``set_xlabel(label, 'ms')`` passes a unit as the
        second argument. This presumably relies on the plotstyle module;
        confirm before using plain matplotlib axes.
    isis : 1-D array
        Interspike intervals in seconds (plotted in ms).
    lag : int
        Non-negative distance between the paired intervals.
    max : float
        Upper limit in seconds; it also sets both axis limits to ``1000*max`` ms.
    """
    ax.set_xlabel(r'ISI$_i$', 'ms')
    ax.set_ylabel(r'ISI$_{i+%d}$' % lag, 'ms')
    ax.set_xlim(0.0, 1000.0*max)
    ax.set_ylim(0.0, 1000.0*max)
    intervals = np.asarray(isis)
    # Number of (ISI_i, ISI_{i+lag}) pairs; explicit length avoids the
    # empty-slice pitfall of `[:-0]` for lag == 0.
    npairs = len(intervals) - lag if lag < len(intervals) else 0
    first = intervals[:npairs]
    later = intervals[lag:lag+npairs]
    keep = (first < max) & (later < max)
    ax.plot(1000.0*first[keep], 1000.0*later[keep], clip_on=False, **psAm)


def plotserialcorr(ax, isis, maxlag=10) :
    """Plot serial correlation coefficients of interspike intervals.

    ``rho_0`` is 1 by definition. For ``k = 1..maxlag``, ``rho_k`` is the
    Pearson correlation (``np.corrcoef``) between ``isis[:-k]`` and
    ``isis[k:]``. A zero line spanning the full lag range is drawn as a
    reference.

    Parameters
    ----------
    ax : axes
        Target axes.
    isis : 1-D array
        Interspike intervals in their original order.
    maxlag : int
        Largest lag shown.
    """
    lags = np.arange(maxlag+1)
    corr = [1.0]
    for lag in lags[1:] :
        corr.append(np.corrcoef(isis[:-lag], isis[lag:])[0,1])
    ax.set_xlabel(r'lag $k$')
    ax.set_ylabel(r'ISI correlation $\rho_k$')
    ax.set_xlim(0.0, maxlag)
    ax.set_ylim(-1.0, 1.0)
    # The reference line must follow maxlag, not a fixed lag of 10.
    ax.plot([0, maxlag], [0.0, 0.0], **lsGrid)
    ax.plot(lags, corr, clip_on=False, zorder=100, **lpsAm)


def plot_hom_returnmap(ax, spikes):
    """Return map of the first 200 ISIs of the homogeneous Poisson trains.

    Uses lag 1 and max=0.3 in plotreturnmap; ticks are placed at
    0, 100, 200 and 300 (ms) on both axes.
    """
    first_isis = isis(spikes)[:200]
    plotreturnmap(ax, first_isis, 1, 0.3)
    for set_ticks in (ax.set_xticks, ax.set_yticks):
        set_ticks(np.arange(0.0, 301.0, 100.0))


def plot_inhom_returnmap(ax, spikes):
    """Return map of the first 200 ISIs of the OU-driven spike trains.

    Same layout as the homogeneous panel (lag 1, max=0.3, ticks every
    100 ms), but the y-axis label is cleared.
    """
    first_isis = isis(spikes)[:200]
    plotreturnmap(ax, first_isis, 1, 0.3)
    ax.set_ylabel('')
    for set_ticks in (ax.set_xticks, ax.set_yticks):
        set_ticks(np.arange(0.0, 301.0, 100.0))


def plot_hom_serialcorr(ax, spikes):
    """Serial ISI correlations of the homogeneous Poisson spike trains."""
    intervals = isis(spikes)
    plotserialcorr(ax, intervals)
    # Narrow the [-1, 1] range that plotserialcorr sets by default.
    ax.set_ylim(bottom=-0.2, top=1.0)


def plot_inhom_serialcorr(ax, spikes):
    """Serial ISI correlations of the OU-driven spike trains, without y label."""
    intervals = isis(spikes)
    plotserialcorr(ax, intervals)
    ax.set_ylabel('')
    # Narrow the [-1, 1] range that plotserialcorr sets by default.
    ax.set_ylim(bottom=-0.2, top=1.0)


if __name__ == "__main__":
    # Simulate both models first (same order of random draws as before).
    pif_noise = 0.3
    homspikes = hompoisson(rate, trials, duration)
    inhomspikes = oupifspikes(rate, trials, duration, dt, pif_noise, drate, tau)
    fig, axs = plt.subplots(2, 2, figsize=cm_size(figure_width, 1.8*figure_height))
    fig.subplots_adjust(**adjust_fs(fig, left=6.5, right=1.5))
    # Top row: return maps; bottom row: serial correlations.
    panels = [
        (plot_hom_returnmap, axs[0, 0], homspikes),
        (plot_inhom_returnmap, axs[0, 1], inhomspikes),
        (plot_hom_serialcorr, axs[1, 0], homspikes),
        (plot_inhom_serialcorr, axs[1, 1], inhomspikes),
    ]
    for draw, panel_ax, panel_spikes in panels:
        draw(panel_ax, panel_spikes)
    fig.savefig('serialcorrexamples.pdf')
    plt.close(fig)