import os
import json
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.ticker as mpt
from plotstyle import *


# Simulation parameters. Times are in seconds and rates in Hz: the
# counting windows are multiplied by 1000 for display in ms when plotted.
rate = 20.0       # Poisson rate, also the mean of the OU-driven input
trials = 20       # number of simulated spike trains per condition
duration = 500.0  # duration of each trial
dt = 0.001        # integration time step of the PIF / OU simulation
drate = 50.0      # amplitude of the OU rate fluctuations
tau = 0.1         # time constant of the OU rate fluctuations

    
def hompoisson(rate, trials, duration):
    """Simulate trials of a homogeneous Poisson spike train.

    Inter-spike intervals are drawn from an exponential distribution with
    mean ``1/rate``.

    Parameters
    ----------
    rate: float
        Firing rate (spikes per unit time).
    trials: int
        Number of independent spike trains to generate.
    duration: float
        Length of each trial. All returned spike times are < duration.

    Returns
    -------
    spikes: list of list of float
        One list of ascending spike times per trial.
    """
    spikes = []
    for _ in range(trials):
        times = []
        t = 0.0
        while t < duration:
            t += np.random.exponential(1/rate)
            # The final draw overshoots the end of the trial; it is still
            # drawn (keeping the random stream unchanged) but not kept.
            if t < duration:
                times.append(t)
        spikes.append(times)
    return spikes


def inhompoisson(rate, trials, dt):
    """Simulate trials of an inhomogeneous Poisson spike train.

    Every time bin ``i`` of width ``dt`` independently holds a spike with
    probability ``rate[i]*dt``.

    Parameters
    ----------
    rate: ndarray
        Firing rate for each time bin.
    trials: int
        Number of spike trains to generate.
    dt: float
        Width of a time bin.

    Returns
    -------
    spikes: list of ndarray
        For each trial, the start times (bin index times dt) of the bins
        that contain a spike.
    """
    spike_prob = rate*dt
    return [dt*np.nonzero(np.random.rand(len(rate)) < spike_prob)[0]
            for _ in range(trials)]


def pifspikes(input, trials, dt, D=0.1, vreset=0.0, vthresh=1.0, tau=1.0):
    """Simulate a noisy perfect integrate-and-fire (PIF) neuron.

    The membrane voltage integrates ``dv = (input + noise)*dt/tau`` with
    no leak term. Whenever v reaches ``vthresh`` a spike is recorded and
    v is set back to ``vreset``.

    Parameters
    ----------
    input: array-like
        Input for each time step. The name shadows the builtin but is kept
        for backward compatibility with keyword callers.
    trials: int
        Number of trials. Each trial gets fresh noise from ``np.random``.
    dt: float
        Integration time step.
    D: float
        Noise intensity. The noise is ``sqrt(2*D)*N(0,1)/sqrt(dt)`` per step.
    vreset: float
        Reset and initial voltage.
    vthresh: float
        Spike threshold.
    tau: float
        Membrane time constant.

    Returns
    -------
    spikes: list of list of float
        Spike times (step index times dt) for each trial.
    """
    noise_scale = np.sqrt(2.0*D)
    spikes = []
    for _ in range(trials):
        times = []
        v = vreset
        noise = noise_scale*np.random.randn(len(input))/np.sqrt(dt)
        # Use a distinct index name; the original reused the outer loop variable.
        for i, (drive, xi) in enumerate(zip(input, noise)):
            v += (drive+xi)*dt/tau
            if v >= vthresh:
                v = vreset
                times.append(i*dt)
        spikes.append(times)
    return spikes


def oupifspikes(rate, trials, duration, dt, D, drate, tau, seed=54637281):
    """Simulate a PIF neuron driven by an Ornstein-Uhlenbeck (OU) rate.

    The input x relaxes with time constant ``tau`` towards the noisy target
    ``rate + drate*tau/sqrt(dt)*N(0,1)`` (Euler integration), so it
    fluctuates around ``rate``. Negative values of x are clipped to zero.
    The OU trace comes from its own seeded generator. It is therefore
    identical on every call with the same seed and shared by all trials.
    Trial-to-trial variability comes only from the PIF noise ``D``.

    Parameters
    ----------
    rate: float
        Mean of the OU input.
    trials: int
        Number of trials passed on to `pifspikes`.
    duration: float
        Duration of the simulated input.
    dt: float
        Integration time step.
    D: float
        Noise intensity of the PIF neuron.
    drate: float
        Amplitude of the OU fluctuations.
    tau: float
        Time constant of the OU process.
    seed: int or None
        Seed for the OU noise generator. The default reproduces the original
        frozen input. None draws a fresh trace each call.

    Returns
    -------
    spikes: list of list of float
        Spike times for each trial, as returned by `pifspikes`.
    """
    rng = np.random.RandomState(seed)
    time = np.arange(0.0, duration, dt)
    x = np.zeros(time.shape)+rate
    # Noisy target value the OU process relaxes to at each time step.
    n = rng.randn(len(time))*drate*tau/np.sqrt(dt) + rate
    for k in range(1, len(x)):
        x[k] = x[k-1] + (n[k]-x[k-1])*dt/tau
    x[x<0.0] = 0.0  # a rate input cannot be negative
    return pifspikes(x, trials, dt, D)


def count_stats(spikes, wins):
    mean_counts = np.zeros(len(wins))
    var_counts = np.zeros(len(wins))
    for k, win in enumerate(wins):
        counts = []
        for times in spikes:
            c, _ = np.histogram(times, np.arange(0.0, duration, win))
            counts.extend(c)
        mean_counts[k] = np.mean(counts)
        var_counts[k] = np.var(counts)
    return mean_counts, var_counts


def plot_count_fano(ax1, ax2, wins, mean_counts, var_counts):
    """Plot count variance over mean and the Fano factor over window size.

    ax1 shows the count variance against the mean count on 0-20 axes.
    ax2 shows ``var_counts/mean_counts`` against the window sizes scaled
    by 1000 (labelled in ms) on a logarithmic axis.
    """
    # Left panel: variance vs. mean of the spike counts.
    ax1.plot(mean_counts, var_counts, zorder=100, **lsA)
    ax1.set_xlabel('Mean count')
    ax1.set_xlim(0.0, 20.0)
    ax1.set_ylim(0.0, 20.0)
    count_ticks = np.arange(0.0, 21.0, 10.0)
    ax1.set_xticks(count_ticks)
    ax1.set_yticks(count_ticks)
    # Right panel: Fano factor vs. window size.
    window_ms = 1000.0*wins
    fano = var_counts/mean_counts
    ax2.plot(window_ms, fano, **lsB)
    # NOTE(review): stock matplotlib treats a second positional argument as
    # fontdict; 'ms' presumably relies on plotstyle's set_xlabel - confirm.
    ax2.set_xlabel('Window', 'ms')
    ax2.set_ylim(0.0, 1.2)
    ax2.set_xscale('log')
    decades = [10, 100, 1000]
    ax2.set_xticks(decades)
    ax2.set_xticklabels([str(d) for d in decades])
    ax2.xaxis.set_minor_locator(mpt.NullLocator())
    ax2.set_yticks(np.arange(0.0, 1.2, 0.5))


def plot_fano(ax, wins, mean_counts, var_counts):
    """Plot the Fano factor ``var_counts/mean_counts`` against window size.

    Window sizes are multiplied by 1000 and labelled in ms. The x-axis is
    logarithmic with ticks at 1, 10, 100 and 1000 and no minor ticks.
    """
    window_ms = 1000.0*wins
    fano = var_counts/mean_counts
    ax.plot(window_ms, fano, **lsB)
    # NOTE(review): stock matplotlib treats a second positional argument as
    # fontdict; 'ms' presumably relies on plotstyle's set_xlabel - confirm.
    ax.set_xlabel('Window', 'ms')
    ax.set_ylim(0.0, 1.2)
    ax.set_xscale('log')
    decades = [1, 10, 100, 1000]
    ax.set_xticks(decades)
    ax.set_xticklabels([str(d) for d in decades])
    ax.xaxis.set_minor_locator(mpt.NullLocator())
    ax.set_yticks(np.arange(0.0, 1.2, 0.5))

    
if __name__ == "__main__":
    # The simulations loop in pure Python, so their count statistics are
    # cached in a JSON file. Delete the file to force a re-simulation.
    cache_keys = ('wins', 'hom_mean_counts', 'hom_var_counts',
                  'inh_mean_counts', 'inh_var_counts')
    if os.path.exists('fanoexamples.json'):
        with open('fanoexamples.json', 'r') as sf:
            data = json.load(sf)
        (wins, hom_mean_counts, hom_var_counts,
         inh_mean_counts, inh_var_counts) = [np.array(data[key]) for key in cache_keys]
    else:
        # Homogeneous Poisson trains vs. a PIF neuron driven by an OU rate.
        homspikes = hompoisson(rate, trials, duration)
        inhspikes = oupifspikes(rate, trials, duration, dt, 0.3, drate, tau)
        wins = np.logspace(-3, 0.0, 100)
        hom_mean_counts, hom_var_counts = count_stats(homspikes, wins)
        inh_mean_counts, inh_var_counts = count_stats(inhspikes, wins)
        results = (wins, hom_mean_counts, hom_var_counts,
                   inh_mean_counts, inh_var_counts)
        with open('fanoexamples.json', 'w') as df:
            json.dump({key: values.tolist() for key, values in zip(cache_keys, results)},
                      df, indent=4)
    fig, axs = plt.subplots(1, 2)
    ax1, ax2 = axs
    fig.subplots_adjust(**adjust_fs(fig, top=0.5, right=2.0))
    # Left panel: Poisson process.
    plot_fano(ax1, wins, hom_mean_counts, hom_var_counts)
    ax1.set_ylabel('Fano factor')
    ax1.text(0.1, 0.95, 'Poisson', transform=ax1.transAxes)
    # Right panel: OU-driven PIF neuron, with a reference line at Fano factor 1.
    ax2.axhline(1.0, **lsGrid)
    plot_fano(ax2, wins, inh_mean_counts, inh_var_counts)
    # Arrow pointing at the window (x = 45 on the ms axis) labelled most reliable.
    ax2.annotate('', xy=(45.0, 0.0), xytext=(45.0, 0.4), arrowprops={'arrowstyle': "->"})
    ax2.text(60.0, 0.25, 'most\nreliable', va='center')
    ax2.text(0.1, 0.95, 'OU noise', transform=ax2.transAxes)
    plt.savefig('fanoexamples.pdf')
    plt.close()