import numpy as np
import scipy.io as scio
import matplotlib.pyplot as plt
from plotstyle import *


def plot_sta(times, stim, dt, t_min=-0.1, t_max=.1):
    """Compute the spike-triggered average (STA) of a stimulus.

    Despite its name this function does not plot anything; it only
    computes the average stimulus snippet around each spike.

    Parameters
    ----------
    times : sequence
        ``times[0][i]`` holds the spike times of trial ``i`` (the layout
        used for the ``spike_times`` variable loaded in the script below).
        Spike times are in the same time unit as ``dt``.
    stim : ndarray
        Stimulus samples. If it is 2-D with more than one column,
        column 1 is used as the stimulus trace
        (presumably column 0 is the time axis -- TODO confirm).
    dt : float
        Sampling interval of the stimulus.
    t_min, t_max : float
        Start and end of the averaging window relative to each spike.

    Returns
    -------
    time : ndarray
        Time axis of the STA, ``t_min + k * dt``; same length as ``sta``.
    sta : ndarray
        Mean stimulus snippet over all spikes whose window lies fully
        inside the stimulus. All NaN if no spike qualified.

    Raises
    ------
    ValueError
        If the window ``[t_min, t_max)`` contains no samples.
    """
    # Round instead of truncating so that floating point noise in
    # (t_max - t_min) / dt cannot shorten the window by one sample.
    n_samples = int(np.round((t_max - t_min) / dt))
    if n_samples <= 0:
        raise ValueError("t_max must be larger than t_min by at least dt")

    # Build the time axis from the sample count so it always matches sta.
    time = t_min + np.arange(n_samples) * dt

    stim = np.asarray(stim)
    if stim.ndim > 1 and stim.shape[1] > 1:
        stim = stim[:, 1]
    stim = np.ravel(stim)
    n_stim = len(stim)

    sta = np.zeros(n_samples)
    count = 0
    for trial_spikes in times[0]:
        # ravel() also turns a single spike (0-d after loading) into 1-D.
        spikes = np.ravel(trial_spikes)
        starts = np.round((spikes + t_min) / dt).astype(int)
        # Keep only spikes whose complete window fits into the stimulus.
        # Checking the rounded index itself prevents negative start
        # indices from slipping through and wrapping around.
        valid = (starts >= 0) & (starts + n_samples <= n_stim)
        for start in starts[valid]:
            sta += stim[start:start + n_samples]
        count += int(np.count_nonzero(valid))

    if count == 0:
        # No usable spike: the average is undefined.
        return time, np.full(n_samples, np.nan)

    sta /= count
    return time, sta


def reconstruct_stimulus(spike_times, sta, stimulus, t_max=30., dt=1e-4):
    """Estimate the stimulus by convolving spike trains with the STA.

    Each trial is turned into a binary spike train sampled at ``dt``,
    convolved with ``sta`` (``mode='same'``), and the results are
    averaged over all trials.

    Parameters
    ----------
    spike_times : ndarray
        Array of shape ``(1, n_trials)``; ``spike_times[0][i]`` holds the
        spike times of trial ``i`` in the same time unit as ``dt``.
    sta : ndarray
        Spike-triggered average used as the convolution kernel.
    stimulus : ndarray
        Original stimulus; only its length is used to size the estimate.
    t_max : float
        End of the returned time axis.
    dt : float
        Sampling interval.

    Returns
    -------
    time : ndarray
        ``np.arange(0, t_max, dt)``.
    s_est : ndarray
        Trial-averaged stimulus estimate with ``len(stimulus)`` samples.
    """
    n_trials = spike_times.shape[1]
    n_samples = len(stimulus)
    s_est = np.zeros((n_trials, n_samples))
    # Use every recorded trial: a fixed trial count either fails for
    # shorter recordings or averages in all-zero rows for longer ones.
    for i in range(n_trials):
        times = np.ravel(spike_times[0][i])
        indices = np.round(times / dt).astype(int)
        # Ignore spikes that fall outside of the stimulus.
        indices = indices[(indices >= 0) & (indices < n_samples)]
        y = np.zeros(n_samples)
        y[indices] = 1
        s_est[i, :] = np.convolve(y, sta, mode='same')
    time = np.arange(0, t_max, dt)
    return time, np.mean(s_est, axis=0)


def plot_results(sta_time, st_average, stim_time, s_est, stimulus, duration, dt):
    """Plot the STA next to the stimulus and its reconstruction.

    The figure is saved as ``sta.pdf`` and closed afterwards.

    Parameters
    ----------
    sta_time, st_average : ndarray
        Time axis and values of the spike-triggered average.
        ``sta_time`` is multiplied by 1000 for the millisecond axis.
    stim_time : ndarray
        Time axis of the stimulus, multiplied by 1000 for plotting.
    s_est : ndarray
        Reconstructed stimulus.
    stimulus : ndarray
        2-D array; column 1 is plotted as the stimulus trace.
    duration, dt :
        Not used by this function.
    """
    sta_ax = plt.subplot2grid((1, 3), (0, 0), rowspan=1, colspan=1)
    stim_ax = plt.subplot2grid((1, 3), (0, 1), rowspan=1, colspan=2)

    fig = plt.gcf()
    fig.subplots_adjust(**adjust_fs(fig, left=6.5, right=2.0, bottom=3.0, top=2.0))

    # Left panel: the spike-triggered average.
    sta_ax.plot(sta_time * 1000, st_average, color=colors['orange'], lw=2.)
    # NOTE(review): the second positional argument is presumably a unit
    # handled by a plotstyle extension of set_xlabel -- confirm.
    sta_ax.set_xlabel('Time', 'ms')
    sta_ax.set_ylabel('Stimulus')
    sta_ax.set_xlim(-40, 20)
    sta_ax.set_xticks(np.arange(-40, 21, 20))
    sta_ax.set_ylim(-0.1, 0.2)
    sta_ax.set_yticks(np.arange(-0.1, 0.21, 0.1))
    ylim = sta_ax.get_ylim()
    xlim = sta_ax.get_xlim()
    # Draw the zero lines, then restore the limits they may have changed.
    sta_ax.plot(list(xlim), [0., 0.], zorder=1, **lsGrid)
    sta_ax.plot([0., 0.], list(ylim), zorder=1, **lsGrid)
    sta_ax.set_xlim(list(xlim))
    sta_ax.set_ylim(list(ylim))
    sta_ax.annotate('Time of\nspike',
                    xy=(0, 0.18), xycoords='data',
                    xytext=(-35, 0.19), textcoords='data', ha='left',
                    arrowprops=dict(arrowstyle="->", relpos=(1.0,0.5),
                                    connectionstyle="angle3,angleA=0,angleB=-70") )
    sta_ax.annotate('STA',
                    xy=(-10, 0.05), xycoords='data',
                    xytext=(-33, 0.09), textcoords='data', ha='left',
                    arrowprops=dict(arrowstyle="->", relpos=(1.0,0.0),
                                    connectionstyle="angle3,angleA=60,angleB=-40") )
    #sta_ax.text(-0.25, 1.04, "A", transform=sta_ax.transAxes, size=24)

    # Right panel: stimulus and its reconstruction.
    stim_ax.plot(stim_time * 1000, stimulus[:,1], label='stimulus', color=colors['blue'], lw=2.)
    stim_ax.plot(stim_time * 1000, s_est, label='reconstruction', color=colors['orange'], lw=2)
    stim_ax.set_xlabel('Time', 'ms')
    stim_ax.set_xlim(0.0, 200)
    stim_ax.set_ylim([-1., 1.])
    stim_ax.legend(loc=(0.3, 0.85), frameon=False)
    stim_ax.plot([0.0, 250], [0., 0.], zorder=1, **lsGrid)
    #stim_ax.text(-0.1, 1.04, "B", transform=stim_ax.transAxes, size=24)

    fig.savefig("sta.pdf")
    plt.close()


if __name__ == "__main__":
    punit_data = scio.loadmat('p-unit_spike_times.mat')
    punit_stim = scio.loadmat('p-unit_stimulus.mat')
    spike_times = punit_data["spike_times"]
    stimulus = punit_stim["stimulus"]

    dt = 5e-5       # sampling interval passed to all three steps
    duration = 10   # end of the reconstruction time axis

    sta_time, sta = plot_sta(spike_times, stimulus, dt, -0.05, 0.05)
    stim_time, s_est = reconstruct_stimulus(spike_times, sta, stimulus, duration, dt)
    plot_results(sta_time, sta, stim_time, s_est, stimulus, duration, dt)