import numpy as np
import scipy.io as scio
import matplotlib.pyplot as plt
from plotstyle import *


def plot_sta(times, stim, dt, t_min=-0.1, t_max=.1):
    """Compute the spike-triggered average (STA) of a stimulus.

    Despite its name this function does not plot anything; it returns the
    STA together with its time axis.

    Parameters
    ----------
    times : array-like
        Spike times in seconds, indexed as ``times[0][i]`` for trial ``i``
        (the layout ``scipy.io.loadmat`` produces for a 1 x n cell array).
        Each trial's entry is squeezed and flattened to 1-D.
    stim : ndarray
        Stimulus sampled with interval ``dt``. If it is 2-D with more than
        one column, column 1 is used.
    dt : float
        Sampling interval of ``stim`` in seconds.
    t_min, t_max : float
        Analysis window relative to each spike in seconds (``t_min < t_max``).

    Returns
    -------
    time : ndarray
        Time axis of the STA relative to the spike, ``t_min + k * dt``.
    sta : ndarray
        Mean stimulus snippet around the spikes. All zeros if no spike has a
        complete window inside the stimulus.
    """
    n_samples = int(round((t_max - t_min) / dt))
    # Derive the time axis from the sample count so it always has the same
    # length as ``sta``; np.arange with a float step can yield one extra value.
    time = t_min + np.arange(n_samples) * dt
    sta = np.zeros(n_samples)
    if len(stim.shape) > 1 and stim.shape[1] > 1:
        stim = stim[:, 1]

    count = 0
    for i in range(len(times[0])):
        # ravel: a trial with a single spike squeezes to a 0-d array,
        # which is not iterable.
        trial_times = np.ravel(np.squeeze(times[0][i]))
        for t in trial_times:
            min_index = int(np.round((t + t_min) / dt))
            # Fixed-length window, so every snippet matches ``sta``.
            max_index = min_index + n_samples
            # Skip spikes whose window is not completely inside the stimulus
            # (a negative start index would silently wrap around).
            if min_index < 0 or max_index > len(stim):
                continue
            sta += np.squeeze(stim[min_index:max_index])
            count += 1

    # Avoid 0/0 (NaN plus RuntimeWarning) when no spike was usable.
    if count > 0:
        sta /= count
    return time, sta


def reconstruct_stimulus(spike_times, sta, stimulus, t_max=30., dt=1e-4):
    """Reconstruct a stimulus by convolving each trial's spike train with the STA.

    Parameters
    ----------
    spike_times : ndarray
        Spike times in seconds with shape (1, n_trials), indexed as
        ``spike_times[0][i]`` (the layout ``scipy.io.loadmat`` produces for
        a 1 x n cell array). Every trial is used.
    sta : ndarray
        Spike-triggered average used as the convolution kernel.
    stimulus : ndarray
        Original stimulus; only its length (number of rows) is used.
    t_max : float
        Kept for backward compatibility. The time axis is derived from
        ``len(stimulus)`` so that it always matches the reconstruction.
    dt : float
        Sampling interval in seconds.

    Returns
    -------
    time : ndarray
        ``k * dt`` for ``k in range(len(stimulus))``.
    s_est : ndarray
        Trial-averaged reconstruction, length ``len(stimulus)``. All zeros
        if there are no trials.
    """
    n_samples = len(stimulus)
    n_trials = spike_times.shape[1]
    # Same length as the reconstruction, unlike np.arange(0, t_max, dt),
    # which depends on t_max and float rounding.
    time = np.arange(n_samples) * dt
    if n_trials == 0:
        return time, np.zeros(n_samples)

    s_est = np.zeros((n_trials, n_samples))
    for i in range(n_trials):
        # ravel: a single-spike trial squeezes to a 0-d array.
        times = np.ravel(np.squeeze(spike_times[0][i]))
        indices = np.round(times / dt).astype(int)
        # Drop spikes outside the stimulus instead of raising IndexError.
        indices = indices[(indices >= 0) & (indices < n_samples)]
        y = np.zeros(n_samples)
        # Binary spike train: several spikes in one bin still count as 1.
        y[indices] = 1
        s_est[i, :] = np.convolve(y, sta, mode='same')
    return time, np.mean(s_est, axis=0)


def plot_results(sta_time, st_average, stim_time, s_est, stimulus, duration, dt):
    """Plot the STA and the stimulus reconstruction side by side, save to "sta.pdf".

    The left panel (one third of the width) shows the spike-triggered average
    against time in ms. The right panel (two thirds of the width) overlays the
    stimulus (column 1 of ``stimulus``) and its reconstruction over 0-200 ms.

    Parameters
    ----------
    sta_time : ndarray
        Time axis of the STA in seconds (multiplied by 1000 for plotting).
    st_average : ndarray
        STA values, same length as ``sta_time``.
    stim_time : ndarray
        Time axis of the stimulus in seconds (multiplied by 1000 for plotting).
    s_est : ndarray
        Reconstructed stimulus, same length as ``stim_time``.
    stimulus : ndarray
        2-D stimulus array. Column 1 is plotted, so it needs at least two
        columns and ``len(stim_time)`` rows.
    duration, dt
        Not used by this function.

    Side effects: draws into the current matplotlib figure, writes "sta.pdf"
    to the working directory and closes the current figure.
    """
    # 1 x 3 grid on the current figure: STA in column 0, stimulus spans 1-2.
    sta_ax = plt.subplot2grid((1, 3), (0, 0), rowspan=1, colspan=1)
    stim_ax = plt.subplot2grid((1, 3), (0, 1), rowspan=1, colspan=2)
    
    fig = plt.gcf()
    # adjust_fs, colors and lsGrid are not defined in this file; presumably
    # they come from `from plotstyle import *` — confirm.
    fig.subplots_adjust(**adjust_fs(fig, left=6.5, right=2.0, bottom=3.0, top=2.0))

    sta_ax.plot(sta_time * 1000, st_average, color=colors['orange'], lw=2.)
    # NOTE(review): in plain matplotlib the 2nd positional argument of
    # set_xlabel is `fontdict`; passing 'ms' presumably relies on plotstyle
    # adding a unit argument — confirm before removing the plotstyle import.
    sta_ax.set_xlabel('Time', 'ms')
    sta_ax.set_ylabel('Stimulus')
    sta_ax.set_xlim(-40, 20)
    sta_ax.set_xticks(np.arange(-40, 21, 20))
    sta_ax.set_ylim(-0.1, 0.2)
    sta_ax.set_yticks(np.arange(-0.1, 0.21, 0.1))
    # Draw zero lines across the full axes, then re-apply the limits so the
    # extra lines do not change the autoscaled view.
    ylim = sta_ax.get_ylim()
    xlim = sta_ax.get_xlim()
    sta_ax.plot(list(xlim), [0., 0.], zorder=1, **lsGrid)
    sta_ax.plot([0., 0.], list(ylim), zorder=1, **lsGrid)
    sta_ax.set_xlim(list(xlim))
    sta_ax.set_ylim(list(ylim))
    # Arrow pointing at the vertical line x = 0 ms (the spike time).
    sta_ax.annotate('Time of\nspike',
                    xy=(0, 0.18), xycoords='data',
                    xytext=(-35, 0.19), textcoords='data', ha='left',
                    arrowprops=dict(arrowstyle="->", relpos=(1.0,0.5),
                    connectionstyle="angle3,angleA=0,angleB=-70") )
    # Arrow pointing at the STA curve; the target coordinates are hard-coded
    # in data units (ms, stimulus).
    sta_ax.annotate('STA',
                    xy=(-10, 0.05), xycoords='data',
                    xytext=(-33, 0.09), textcoords='data', ha='left',
                    arrowprops=dict(arrowstyle="->", relpos=(1.0,0.0),
                    connectionstyle="angle3,angleA=60,angleB=-40") )
    #sta_ax.text(-0.25, 1.04, "A", transform=sta_ax.transAxes, size=24)

    # Column 1 of the stimulus array is plotted (requires a 2-D stimulus).
    stim_ax.plot(stim_time * 1000, stimulus[:,1], label='stimulus', color=colors['blue'], lw=2.)
    stim_ax.plot(stim_time * 1000, s_est, label='reconstruction', color=colors['orange'], lw=2)
    stim_ax.set_xlabel('Time', 'ms')
    stim_ax.set_xlim(0.0, 200)
    stim_ax.set_ylim([-1., 1.])
    stim_ax.legend(loc=(0.3, 0.85), frameon=False)
    # Horizontal zero line; it extends to 250 ms, past the 200 ms x-limit.
    stim_ax.plot([0.0, 250], [0., 0.], zorder=1, **lsGrid)
    #stim_ax.text(-0.1, 1.04, "B", transform=stim_ax.transAxes, size=24)

    fig.savefig("sta.pdf")
    plt.close()


if __name__ == "__main__":
    # Sampling interval (s) and stimulus duration (s) shared by all steps.
    dt = 5e-5
    duration = 10
    # Load spike times first, then the stimulus. The global name
    # `spike_times` is kept on purpose: it is visible to module-level code.
    spike_times = scio.loadmat('p-unit_spike_times.mat')["spike_times"]
    stimulus = scio.loadmat('p-unit_stimulus.mat')["stimulus"]

    # STA over a +/- 50 ms window around each spike.
    sta_time, sta = plot_sta(spike_times, stimulus, dt, -0.05, 0.05)
    stim_time, s_est = reconstruct_stimulus(spike_times, sta, stimulus,
                                            duration, dt)
    plot_results(sta_time, sta, stim_time, s_est, stimulus, duration, dt)