import numpy as np
from scipy.stats import poisson
import matplotlib.pyplot as plt
from plotstyle import *


# Simulation parameters: firing rate (Hz), number of trials, and trial
# duration (s) of the homogeneous Poisson spike trains.
rate, trials, duration = 20.0, 20, 100.0


def hompoisson(rate, trials, duration):
    """Simulate spike trains of a homogeneous Poisson process.

    Inter-spike intervals are drawn from an exponential distribution
    with mean ``1/rate`` using NumPy's global random state.

    Parameters
    ----------
    rate : float
        Firing rate (spikes per unit time; Hz when times are in seconds).
    trials : int
        Number of independent spike trains to generate.
    duration : float
        Length of each trial, in the same time unit as ``1/rate``.

    Returns
    -------
    list of list of float
        One list of increasing spike times per trial, all within
        ``[0, duration)``.
    """
    spikes = []
    for _ in range(trials):
        times = []
        t = np.random.exponential(1.0/rate)
        # Test before appending: the original appended the first time
        # that crossed `duration`, leaving one spike outside the trial.
        while t < duration:
            times.append(t)
            t += np.random.exponential(1.0/rate)
        spikes.append(times)
    return spikes

    
def plot_count_hist(ax, spikes, win, pmax):
    """Plot the spike-count distribution for counting window `win`.

    Spikes of each trial are counted in consecutive, non-overlapping
    windows of width `win` covering ``[0, duration)``. The counts of all
    trials are pooled and drawn as a normalized histogram, together with
    the Poisson probability mass function for mean count ``rate*win``.

    Parameters
    ----------
    ax : matplotlib axes
        Axes to draw into.
    spikes : list of sequences of float
        Spike times, one sequence per trial, in the same unit as `win`.
    win : float
        Width of the counting window in seconds (the label shows it in ms).
    pmax : float
        Upper limit of the y-axis.

    Notes
    -----
    Reads the module-level `duration` and `rate`. The Poisson reference
    assumes `spikes` was generated with that `rate`. Also uses the style
    names `bar_fac`, `fsA`, `lsBm`, `psBm`, presumably provided by
    ``from plotstyle import *``.
    """
    # Build the window edges from an integer window count. The former
    # np.arange(0.0, duration, win) stopped before `duration` and so
    # dropped the last window of every trial. A float arange step also
    # makes the edge count fragile. The tiny tolerance absorbs rounding
    # in duration/win.
    nwin = int(np.floor(duration/win + 1e-9))
    edges = win*np.arange(nwin + 1)
    counts = []
    for times in spikes:
        c, _ = np.histogram(times, edges)
        counts.extend(c)
    # Unit-width count bins [k, k+1) for k = 0..14 (the last bin also
    # holds 15). With density=True the bar heights are relative frequencies.
    cb = np.arange(0.0, 15.5, 1.0)
    h, b = np.histogram(counts, cb, density=True)
    ax.bar(b[:-1], h, bar_fac, align='center', **fsA)
    # Poisson pmf with mean count rate*win, drawn as line plus markers.
    ax.plot(cb, poisson.pmf(cb, rate*win), **lsBm)
    ax.plot(cb, poisson.pmf(cb, rate*win), **psBm)
    ax.text(0.9, 0.9, 'T=%.0fms' % (1000.0*win), transform=ax.transAxes, ha='right')
    ax.set_xlim(-0.5, 10.5)
    ax.set_ylim(0.0, pmax)
    ax.set_xticks(np.arange(0.0, 11.0, 5.0))
    ax.set_xlabel('Counts k')
    ax.set_ylabel('P(k)')

    
if __name__ == "__main__":
    # One shared set of simulated trials feeds both panels.
    spikes = hompoisson(rate, trials, duration)
    fig, axes = plt.subplots(1, 2)
    fig.subplots_adjust(**adjust_fs(fig, top=0.5, right=1.5))
    # Per panel: counting window (s), y-axis maximum, y-tick positions.
    panels = [
        (0.02, 0.7, np.arange(0.0, 0.7, 0.2)),
        (0.2, 0.22, np.arange(0.0, 0.25, 0.1)),
    ]
    for ax, (win, pmax, yticks) in zip(axes, panels):
        plot_count_hist(ax, spikes, win, pmax)
        ax.set_yticks(yticks)
    plt.savefig('countexamples.pdf')
    plt.close()