112 lines
4.5 KiB
Python
112 lines
4.5 KiB
Python
import numpy as np
|
|
import matplotlib.pyplot as plt
|
|
import scipy.io as scio
|
|
from IPython import embed
|
|
|
|
|
|
def plot_sta(times, stim, dt, t_min=-0.1, t_max=.1):
    """Compute the spike-triggered average (STA) of a stimulus.

    Despite its name this function does not plot anything; it returns the
    STA and its time axis.

    Parameters
    ----------
    times : numpy.ndarray
        Object array whose first row ``times[0]`` holds one array of spike
        times (seconds) per trial -- e.g. a MATLAB cell array loaded with
        ``scipy.io.loadmat``. Each trial array is flattened, so row vectors,
        column vectors and single-spike trials are all accepted.
    stim : numpy.ndarray
        Stimulus sampled at ``dt``. If it is 2-D with more than one column,
        column 1 is used as the stimulus trace; a single-column 2-D array is
        flattened.
    dt : float
        Sampling interval of ``stim`` in seconds.
    t_min, t_max : float
        Window around each spike, relative to the spike time, in seconds.

    Returns
    -------
    time : numpy.ndarray
        Window time axis ``t_min + k * dt`` for ``k = 0 .. n - 1``.
    sta : numpy.ndarray
        Average stimulus snippet, same length as ``time``. Spikes whose
        window does not lie entirely inside ``stim`` are skipped; if no
        window fits, ``sta`` is all NaN.
    """
    # Round (not truncate) so float error such as 0.1 / 5e-5 == 1999.999...
    # does not lose a sample.
    n_samples = int(round((t_max - t_min) / dt))
    # Build the axis from integer sample counts so it always has exactly
    # n_samples entries (np.arange with a float step may not).
    time = t_min + np.arange(n_samples) * dt
    sta = np.zeros(n_samples)

    stim = np.asarray(stim)
    if stim.ndim > 1:
        stim = stim[:, 1] if stim.shape[1] > 1 else stim[:, 0]

    count = 0
    for trial in times[0]:
        # ravel instead of squeeze: a single-spike trial would otherwise
        # become a 0-d array, which cannot be iterated.
        for t in np.ravel(trial):
            start = int(np.round((t + t_min) / dt))
            stop = start + n_samples
            # Skip windows that run off either end of the stimulus; a
            # negative start would otherwise wrap around silently.
            if start < 0 or stop > len(stim):
                continue
            sta += stim[start:stop]
            count += 1

    if count == 0:
        # No usable spike: the average is undefined.
        sta.fill(np.nan)
    else:
        sta /= count
    return time, sta
|
|
|
|
|
|
def reconstruct_stimulus(spike_times, sta, stimulus, t_max=30., dt=1e-4):
    """Estimate the stimulus by convolving each trial's spike train with the STA.

    Parameters
    ----------
    spike_times : numpy.ndarray
        Object array of shape (1, n_trials); ``spike_times[0][i]`` holds the
        spike times (seconds) of trial ``i`` in any array shape.
    sta : numpy.ndarray
        1-D spike-triggered average used as the convolution kernel.
    stimulus : numpy.ndarray
        Original stimulus; only its length (number of samples) is used.
    t_max : float
        Nominal stimulus duration in seconds. Kept for backward
        compatibility; the returned time axis is derived from
        ``len(stimulus)`` so that it always matches the estimate.
    dt : float
        Sampling interval in seconds.

    Returns
    -------
    time : numpy.ndarray
        ``k * dt`` for every stimulus sample.
    s_est : numpy.ndarray
        Reconstruction averaged over all trials, same length as ``time``.
    """
    n_samples = len(stimulus)
    n_trials = spike_times.shape[1]
    s_est = np.zeros((n_trials, n_samples))

    # Iterate over every trial actually present (not a fixed count).
    for i in range(n_trials):
        spikes = np.ravel(spike_times[0][i])
        indices = np.round(spikes / dt).astype(int)
        # Drop spikes outside the stimulus: too-large indices would raise,
        # negative ones would silently wrap around to the end.
        indices = indices[(indices >= 0) & (indices < n_samples)]
        # Binary spike train: a bin holding several spikes is still 1.
        y = np.zeros(n_samples)
        y[indices] = 1
        s_est[i, :] = np.convolve(y, sta, mode='same')

    time = np.arange(n_samples) * dt
    return time, np.mean(s_est, axis=0)
|
|
|
|
|
|
def _style_axis(ax):
    """Apply the shared open-frame style: no top/right spines, thick outward ticks."""
    ax.spines["right"].set_visible(False)
    ax.spines["top"].set_visible(False)
    ax.yaxis.set_ticks_position('left')
    ax.xaxis.set_ticks_position('bottom')
    ax.spines["bottom"].set_linewidth(2.0)
    ax.spines["left"].set_linewidth(2.0)
    ax.tick_params(direction="out", width=2.0)


def _plot_sta_panel(ax, sta_time, st_average):
    """Draw the STA (time in ms) with dashed zero lines and explanatory arrows."""
    ax.plot(sta_time * 1000, st_average, color="#FF9900", lw=2.)
    ax.set_xlabel("Time (ms)")
    ax.set_ylabel("Stimulus")
    ax.set_xlim(-40, 20)
    ax.set_xticks(np.arange(-40, 21, 20))
    ax.set_ylim(-0.1, 0.2)
    ax.set_yticks(np.arange(-0.1, 0.21, 0.1))
    _style_axis(ax)

    # Dashed zero lines spanning the current limits; re-apply the limits
    # afterwards so the extra lines cannot change them.
    ylim = ax.get_ylim()
    xlim = ax.get_xlim()
    ax.plot(list(xlim), [0., 0.], zorder=1, color='darkgray', ls='--', lw=1)
    ax.plot([0., 0.], list(ylim), zorder=1, color='darkgray', ls='--', lw=1)
    ax.set_xlim(list(xlim))
    ax.set_ylim(list(ylim))

    ax.annotate('Time of\nspike',
                xy=(0, 0.18), xycoords='data',
                xytext=(-35, 0.19), textcoords='data', ha='left',
                arrowprops=dict(arrowstyle="->", relpos=(1.0, 0.5),
                                connectionstyle="angle3,angleA=0,angleB=-70"))
    ax.annotate('STA',
                xy=(-10, 0.05), xycoords='data',
                xytext=(-33, 0.09), textcoords='data', ha='left',
                arrowprops=dict(arrowstyle="->", relpos=(1.0, 0.0),
                                connectionstyle="angle3,angleA=60,angleB=-40"))


def _plot_stimulus_panel(ax, stim_time, stim_trace, s_est):
    """Overlay the stimulus and its reconstruction over the first 200 ms."""
    ax.plot(stim_time * 1000, stim_trace, label='stimulus', color='#0000FF', lw=2.)
    ax.plot(stim_time * 1000, s_est, label='reconstruction', color='#FF9900', lw=2)
    ax.set_xlabel('Time (ms)')
    ax.set_xlim(0.0, 200)
    ax.set_ylim([-1., 1.])
    ax.legend(loc=(0.3, 0.85), frameon=False, fontsize=12)
    ax.plot([0.0, 250], [0., 0.], color="darkgray", lw=1, ls='--', zorder=1)
    _style_axis(ax)


def plot_results(sta_time, st_average, stim_time, s_est, stimulus, duration, dt,
                 filename="sta.pdf"):
    """Plot the STA next to the stimulus reconstruction and save the figure.

    Parameters
    ----------
    sta_time, st_average : numpy.ndarray
        STA time axis (seconds) and STA values; plotted in the left panel.
    stim_time, s_est : numpy.ndarray
        Time axis (seconds) and reconstructed stimulus; plotted in the right
        panel together with the original stimulus.
    stimulus : numpy.ndarray
        Original stimulus with one sample per ``stim_time`` entry. If 2-D
        with more than one column, column 1 is plotted; a single-column or
        1-D array is used as is.
    duration, dt :
        Not used; kept so existing callers keep working.
    filename : str, optional
        Output path passed to ``Figure.savefig`` (default ``"sta.pdf"``).
    """
    stimulus = np.asarray(stimulus)
    if stimulus.ndim > 1:
        stim_trace = stimulus[:, 1] if stimulus.shape[1] > 1 else stimulus[:, 0]
    else:
        stim_trace = stimulus

    # As a context manager plt.xkcd() restores rcParams on exit; called bare
    # it would leave the xkcd style active for every later figure.
    with plt.xkcd():
        # Dedicated figure, so an unrelated "current" figure is never reused.
        fig = plt.figure(figsize=(8, 3))
        fig.set_facecolor("white")
        grid = fig.add_gridspec(1, 3)
        sta_ax = fig.add_subplot(grid[0, 0])
        stim_ax = fig.add_subplot(grid[0, 1:])
        fig.subplots_adjust(left=0.08, bottom=0.15, top=0.9, right=0.975)

        _plot_sta_panel(sta_ax, sta_time, st_average)
        _plot_stimulus_panel(stim_ax, stim_time, stim_trace, s_est)

        fig.tight_layout()
        fig.savefig(filename)
        plt.close(fig)
|
|
|
|
|
|
if __name__ == "__main__":
    # Sampling interval (s) and stimulus duration (s); every analysis stage
    # receives the same values.
    dt = 5e-5
    duration = 10
    # STA window relative to each spike, in seconds.
    sta_window = (-0.05, 0.05)

    punit_data = scio.loadmat('p-unit_spike_times.mat')
    punit_stim = scio.loadmat('p-unit_stimulus.mat')
    # Module-level name on purpose: kept visible at top level for interactive use.
    spike_times = punit_data["spike_times"]
    stimulus = punit_stim["stimulus"]

    sta_time, sta = plot_sta(spike_times, stimulus, dt, *sta_window)
    stim_time, s_est = reconstruct_stimulus(spike_times, sta, stimulus, duration, dt)
    plot_results(sta_time, sta, stim_time, s_est, stimulus, duration, dt)
|