import numpy as np
import matplotlib.pyplot as plt
import scipy.io as scio
from IPython import embed


def plot_sta(times, stim, dt, t_min=-0.1, t_max=.1):
    """Compute the spike-triggered average (STA) of a stimulus.

    Despite its name this function does not plot; it returns the window time
    axis and the averaged stimulus snippet.

    Parameters
    ----------
    times : array-like
        Spike times per trial, indexed as ``times[0][i]`` (the layout of a
        MATLAB cell array loaded with ``scipy.io.loadmat``). Each entry holds
        spike times in the same unit as ``dt``, measured from the first
        stimulus sample.
    stim : numpy.ndarray
        Stimulus sampled every ``dt``. If 2-D with more than one column,
        column 1 is used as the stimulus values.
    dt : float
        Sampling interval of ``stim``.
    t_min, t_max : float
        Window edges relative to each spike (``t_min < t_max``).

    Returns
    -------
    time : numpy.ndarray
        Window time axis ``t_min + k * dt``; always the same length as ``sta``.
    sta : numpy.ndarray
        Mean stimulus snippet over all spikes whose full window lies inside
        the stimulus. All-NaN if no spike qualifies.

    Raises
    ------
    ValueError
        If the window ``[t_min, t_max)`` contains no samples.
    """
    n_samples = int(np.round((t_max - t_min) / dt))
    if n_samples <= 0:
        raise ValueError("t_max must be larger than t_min by at least one dt")
    # Integer-based axis: np.arange with a float step can be off by one.
    time = t_min + np.arange(n_samples) * dt
    sta = np.zeros(n_samples)

    stim = np.asarray(stim)
    if stim.ndim > 1 and stim.shape[1] > 1:
        stim = stim[:, 1]
    stim = np.ravel(stim)

    count = 0
    for trial in times[0]:
        # ravel (not squeeze) so a trial holding a single spike still iterates
        for t in np.ravel(trial):
            start = int(np.round((t + t_min) / dt))
            # Derive stop from start so every snippet matches len(sta) exactly.
            stop = start + n_samples
            if start < 0 or stop > len(stim):
                continue
            sta += stim[start:stop]
            count += 1

    if count == 0:
        sta[:] = np.nan
        return time, sta
    sta /= count
    return time, sta


def reconstruct_stimulus(spike_times, sta, stimulus, t_max=30., dt=1e-4):
    """Estimate the stimulus by convolving each trial's spike train with the STA.

    Parameters
    ----------
    spike_times : numpy.ndarray
        Object array of shape ``(1, n_trials)``; ``spike_times[0][i]`` holds
        the spike times of trial ``i`` in the same unit as ``dt``.
    sta : numpy.ndarray
        Spike-triggered average used as the convolution kernel.
    stimulus : numpy.ndarray
        Original stimulus; only ``len(stimulus)`` is used, to size the
        binned spike trains.
    t_max : float
        Duration covered by the returned time axis.
    dt : float
        Bin width used to discretize spike times.

    Returns
    -------
    time : numpy.ndarray
        Time axis ``0, dt, 2*dt, ...`` up to (excluding) ``t_max``.
    s_est : numpy.ndarray
        Reconstruction of length ``len(stimulus)``, averaged over all trials.
    """
    n_trials = spike_times.shape[1]
    n_samples = len(stimulus)
    s_est = np.zeros((n_trials, n_samples))
    for i in range(n_trials):
        times = np.ravel(spike_times[0][i])
        indices = np.round(times / dt).astype(int)
        # Drop spikes outside the stimulus: too-large indices would raise,
        # negative ones would silently wrap around to the end of the array.
        indices = indices[(indices >= 0) & (indices < n_samples)]
        y = np.zeros(n_samples)
        # add.at accumulates repeated indices, so coincident spikes all count.
        np.add.at(y, indices, 1)
        s_est[i, :] = np.convolve(y, sta, mode='same')
    # Same length as np.arange(0, t_max, dt), but without float-step jitter.
    n_time = int(np.ceil(np.round(t_max / dt, 6)))
    time = np.arange(n_time) * dt
    return time, np.mean(s_est, axis=0)


def _style_axis(ax):
    """Hide the top/right spines and thicken the remaining spines and ticks."""
    ax.spines["right"].set_visible(False)
    ax.spines["top"].set_visible(False)
    ax.yaxis.set_ticks_position('left')
    ax.xaxis.set_ticks_position('bottom')
    ax.spines["bottom"].set_linewidth(2.0)
    ax.spines["left"].set_linewidth(2.0)
    ax.tick_params(direction="out", width=2.0)


def plot_results(sta_time, st_average, stim_time, s_est, stimulus, duration, dt):
    """Plot the STA (panel A) and the stimulus vs. its reconstruction (panel B).

    The figure is written to ``sta.pdf`` in the current working directory
    and then closed.

    Parameters
    ----------
    sta_time : numpy.ndarray
        STA time axis in seconds (plotted in ms).
    st_average : numpy.ndarray
        Spike-triggered average, same length as ``sta_time``.
    stim_time : numpy.ndarray
        Stimulus time axis in seconds (plotted in ms).
    s_est : numpy.ndarray
        Reconstructed stimulus, same length as ``stim_time``.
    stimulus : numpy.ndarray
        Original stimulus. If 2-D with more than one column, column 1 holds
        the values, matching ``plot_sta``.
    duration, dt : float
        Unused; kept so existing callers keep working.
    """
    # Build a dedicated figure instead of drawing onto whatever is current.
    fig = plt.figure(figsize=(15, 5))
    sta_ax = plt.subplot2grid((1, 3), (0, 0), rowspan=1, colspan=1, fig=fig)
    stim_ax = plt.subplot2grid((1, 3), (0, 1), rowspan=1, colspan=2, fig=fig)
    fig.subplots_adjust(left=0.075, bottom=0.12, top=0.92, right=0.975)
    fig.set_facecolor("white")

    # Panel A: spike-triggered average.
    sta_ax.plot(sta_time * 1000, st_average, color="dodgerblue", lw=2.)
    sta_ax.set_xlabel("time [ms]", fontsize=12)
    sta_ax.set_ylabel("stimulus", fontsize=12)
    sta_ax.set_xlim([-50, 50])
    _style_axis(sta_ax)

    # Zero reference lines spanning the current limits; the limits are
    # re-applied afterwards so these lines do not enlarge the view.
    ylim = sta_ax.get_ylim()
    xlim = sta_ax.get_xlim()
    sta_ax.plot(list(xlim), [0., 0.], zorder=1, color='darkgray', ls='--', lw=0.75)
    sta_ax.plot([0., 0.], list(ylim), zorder=1, color='darkgray', ls='--', lw=0.75)
    sta_ax.set_xlim(list(xlim))
    sta_ax.set_ylim(list(ylim))
    sta_ax.text(-0.225, 1.05, "A", transform=sta_ax.transAxes, size=14)

    # Panel B: original stimulus and its reconstruction.
    stimulus = np.asarray(stimulus)
    if stimulus.ndim > 1 and stimulus.shape[1] > 1:
        stim_values = stimulus[:, 1]
    else:
        stim_values = np.ravel(stimulus)
    stim_ax.plot(stim_time * 1000, stim_values, label='stimulus', color='dodgerblue', lw=2.)
    stim_ax.plot(stim_time * 1000, s_est, label='reconstruction', color='red', lw=2)
    stim_ax.set_xlabel('time[ms]', fontsize=12)
    stim_ax.set_xlim([0.0, 250])
    stim_ax.set_ylim([-1., 1.])
    stim_ax.legend()
    stim_ax.plot([0.0, 250], [0., 0.], color="darkgray", lw=0.75, ls='--', zorder=1)
    _style_axis(stim_ax)
    stim_ax.text(-0.075, 1.05, "B", transform=stim_ax.transAxes, size=14)

    fig.savefig("sta.pdf")
    plt.close(fig)


if __name__ == "__main__":
    # Shared analysis settings, in seconds: sampling interval, recording
    # duration, and the STA window relative to each spike.
    dt = 5e-5
    duration = 10
    sta_window = (-0.05, 0.05)

    # NOTE: `spike_times` stays a module-level name on purpose; keep it.
    spike_times = scio.loadmat('p-unit_spike_times.mat')["spike_times"]
    stimulus = scio.loadmat('p-unit_stimulus.mat')["stimulus"]

    sta_time, sta = plot_sta(spike_times, stimulus, dt, *sta_window)
    stim_time, s_est = reconstruct_stimulus(
        spike_times, sta, stimulus, duration, dt
    )
    plot_results(sta_time, sta, stim_time, s_est, stimulus, duration, dt)