import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from plotstyle import *

def random_walk(n, p, rng):
    """Return a +/-1 random walk that starts at zero.

    Each of the `n` steps is +1 when the corresponding draw from ``rng.rand``
    is at least ``1 - p`` and -1 when it is below. With uniform draws on
    [0, 1) (as ``numpy.random.RandomState.rand`` produces) a step is +1 with
    probability `p`.

    Returns
    -------
    ndarray of float
        The ``n + 1`` positions of the walker; the first one is 0.
    """
    draws = rng.rand(n)
    cutoff = 1.0 - p
    # Decide both directions from the raw draws before overwriting them.
    go_up = draws >= cutoff
    go_down = draws < cutoff
    draws[go_up] = 1
    draws[go_down] = -1
    return np.concatenate(([0.0], np.cumsum(draws)))

def random_walk_neuron(n, p, thresh, rng):
    """Simulate a random-walk neuron with threshold and reset.

    The potential starts as ``random_walk(n, p, rng)``. Scanning forward in
    time, the first sample that reaches the threshold is recorded as a spike.
    The potential is then reset so that the sample right after the spike is 0,
    and the walk continues from there. This repeats until the end of the
    trace.

    Parameters
    ----------
    n : int
        Number of random-walk steps; the returned trace has ``n + 1`` samples.
    p : float
        Probability of an upward step, passed on to `random_walk`.
    thresh : float
        Firing threshold. A sample counts as reaching it when it is
        ``>= thresh - 0.1``; the walk only takes integer values, so the 0.1
        merely absorbs floating-point round-off. With ``thresh <= 0.1`` the
        reset value itself reaches threshold and every sample is a spike.
    rng : numpy.random.RandomState
        Random-number source with a ``rand`` method, passed on to
        `random_walk`.

    Returns
    -------
    x : ndarray
        Potential trace of length ``n + 1`` with the resets applied. At each
        spike index the trace keeps its un-reset (threshold-reaching) value.
    spikes : list of int
        Indices into `x` at which the threshold was reached, increasing.
    """
    x = random_walk(n, p, rng)
    spikes = []
    start = 0  # first index of the current inter-spike interval; x[0] is 0
    while start < len(x):
        # Reset: shift the rest of the trace so this interval starts at zero.
        # (A no-op for the first interval, whose x[0] is already 0, so no
        # random-walk step is discarded.)
        x[start:] -= x[start]
        reached = x[start:] >= thresh - 0.1
        # argmax gives the first True, but also returns 0 when nothing is True,
        # so check the flag explicitly rather than treating 0 as "no crossing".
        offset = int(np.argmax(reached))
        if not reached[offset]:
            break
        spike = start + offset
        spikes.append(spike)
        # Resume after the spike; the loop condition stops cleanly when the
        # spike was the very last sample (previously an IndexError).
        start = spike + 1
    return x, spikes
    

if __name__ == "__main__":
    # Model parameters shared by both panels; the fixed seed fixes the figure.
    rng = np.random.RandomState(52281)
    p = 0.6         # probability of an upward step
    thresh = 10.0   # firing threshold of the potential

    fig = plt.figure()
    spec = gridspec.GridSpec(nrows=1, ncols=2, width_ratios=[2, 1], wspace=0.4,
                             **adjust_fs(fig, left=6.5, right=1.5))

    # Left panel: a short example trace with a tick mark above each spike.
    ax_trace = fig.add_subplot(spec[0, 0])
    n_trace = 500
    time = np.arange(n_trace + 1, dtype=float)
    potential, spike_times = random_walk_neuron(n_trace, p, thresh, rng)
    ax_trace.axhline(0.0, **lsGrid)
    ax_trace.axhline(thresh, **lsAm)
    ax_trace.plot(time, potential, **lsBm)
    for t in spike_times:
        ax_trace.plot([t, t], [12.0, 16.0], **lsC)
    ax_trace.set_xlabel('Time')
    ax_trace.set_ylabel('Potential')
    ax_trace.set_xlim(0, n_trace)
    ax_trace.set_ylim(-10, 17)
    ax_trace.set_yticks(np.arange(-10, 11, 10))

    # Right panel: ISI histogram from a much longer run. It reuses the same
    # rng, so it continues the random sequence consumed by the left panel.
    ax_isi = fig.add_subplot(spec[0, 1])
    n_long = 100000
    _, long_spikes = random_walk_neuron(n_long, p, thresh, rng)
    isi_bins = np.arange(0.0, 151.0, 10.0)  # ISIs beyond 150 are not drawn
    ax_isi.hist(np.diff(long_spikes), isi_bins, **fsAs)
    ax_isi.set_xlabel('ISI')
    ax_isi.set_ylabel('Count')
    ax_isi.set_xticks(np.arange(0, 151, 50))
    ax_isi.set_yticks(np.arange(0, 401, 100))

    fig.savefig("randomwalkneuron.pdf")
    plt.close(fig)