Added missing files for spike_trains
This commit is contained in:
@@ -1,268 +0,0 @@
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
import scipy.io as spio
|
||||
import scipy.stats as spst
|
||||
import scipy as sp
|
||||
from IPython import embed
|
||||
|
||||
|
||||
def set_axis_fontsize(axis, label_size, tick_label_size=None, legend_size=None):
|
||||
"""
|
||||
Sets axis, tick label and legend font sizes to the desired size.
|
||||
|
||||
:param axis: the axes object
|
||||
:param label_size: the size of the axis label
|
||||
:param tick_label_size: the size of the tick labels. If None, lable_size is used
|
||||
:param legend_size: the size of the font used in the legend.If None, the label_size is used.
|
||||
|
||||
"""
|
||||
if not tick_label_size:
|
||||
tick_label_size = label_size
|
||||
if not legend_size:
|
||||
legend_size = label_size
|
||||
axis.xaxis.get_label().set_fontsize(label_size)
|
||||
axis.yaxis.get_label().set_fontsize(label_size)
|
||||
for tick in axis.xaxis.get_major_ticks() + axis.yaxis.get_major_ticks():
|
||||
tick.label.set_fontsize(tick_label_size)
|
||||
|
||||
l = axis.get_legend()
|
||||
if l:
|
||||
for t in l.get_texts():
|
||||
t.set_fontsize(legend_size)
|
||||
|
||||
|
||||
def get_instantaneous_rate(times, max_t=30., dt=1e-4):
|
||||
time = np.arange(0., max_t, dt)
|
||||
indices = np.asarray(times / dt, dtype=int)
|
||||
intervals = np.diff(np.hstack(([0], times)))
|
||||
inst_rate = np.zeros(time.shape)
|
||||
|
||||
for i, index in enumerate(indices[1:]):
|
||||
inst_rate[indices[i-1]:indices[i]] = 1/intervals[i]
|
||||
return time, inst_rate
|
||||
|
||||
|
||||
def plot_isi_rate(spike_times, max_t=30, dt=1e-4):
|
||||
times = np.squeeze(spike_times[0][0])[:50000]
|
||||
time, rate = get_instantaneous_rate(times, max_t=50000*dt)
|
||||
|
||||
rates = np.zeros((len(rate), len(spike_times)))
|
||||
for i in range(len(spike_times)):
|
||||
_, rates[:, i] = get_instantaneous_rate(np.squeeze(spike_times[i][0])[:50000],
|
||||
max_t=50000*dt)
|
||||
avg_rate = np.mean(rates, axis=1)
|
||||
rate_std = np.std(rates, axis=1)
|
||||
|
||||
fig = plt.figure()
|
||||
ax1 = fig.add_subplot(311)
|
||||
ax2 = fig.add_subplot(312)
|
||||
ax3 = fig.add_subplot(313)
|
||||
|
||||
ax1.vlines(times[times < (50000*dt)], ymin=0, ymax=1, color="dodgerblue", lw=1.5)
|
||||
ax1.set_ylabel("skpikes", fontsize=12)
|
||||
set_axis_fontsize(ax1, 12)
|
||||
ax1.set_xlim([0, 5])
|
||||
ax2.plot(time, rate, label="instantaneous rate, trial 1")
|
||||
ax2.set_ylabel("firing rate [Hz]", fontsize=12)
|
||||
ax2.legend(fontsize=12)
|
||||
set_axis_fontsize(ax2, 12)
|
||||
|
||||
ax3.fill_between(time, avg_rate+rate_std, avg_rate-rate_std, color="dodgerblue",
|
||||
alpha=0.5, label="standard deviation")
|
||||
ax3.plot(time, avg_rate, label="average rate")
|
||||
ax3.set_xlabel("times [s]", fontsize=12)
|
||||
ax3.set_ylabel("firing rate [Hz]", fontsize=12)
|
||||
ax3.legend(fontsize=12)
|
||||
ax3.set_ylim([0, 450])
|
||||
set_axis_fontsize(ax3, 12)
|
||||
|
||||
fig.set_size_inches(15, 10)
|
||||
fig.subplots_adjust(left=0.1, bottom=0.125, top=0.95, right=0.95)
|
||||
fig.set_facecolor("white")
|
||||
fig.savefig("../lectures/images/instantaneous_rate.png")
|
||||
plt.close()
|
||||
|
||||
|
||||
def get_binned_rate(times, bin_width=0.05, max_t=30., dt=1e-4):
|
||||
time = np.arange(0., max_t, dt)
|
||||
bins = np.arange(0., max_t, bin_width)
|
||||
bin_indices = bins / dt
|
||||
hist, _ = sp.histogram(times, bins)
|
||||
rate = np.zeros(time.shape)
|
||||
|
||||
for i, b in enumerate(bin_indices[1:]):
|
||||
rate[bin_indices[i-1]:b] = hist[i-1]/bin_width
|
||||
return time, rate
|
||||
|
||||
|
||||
def plot_bin_rate(spike_times, bin_width, max_t=30, dt=1e-4):
|
||||
times = np.squeeze(spike_times[0][0])
|
||||
time, rate = get_binned_rate(times)
|
||||
rates = np.zeros((len(rate), len(spike_times)))
|
||||
for i in range(len(spike_times)):
|
||||
_, rates[:, i] = get_binned_rate(np.squeeze(spike_times[i][0]))
|
||||
avg_rate = np.mean(rates, axis=1)
|
||||
rate_std = np.std(rates, axis=1)
|
||||
|
||||
fig = plt.figure()
|
||||
ax1 = fig.add_subplot(311)
|
||||
ax2 = fig.add_subplot(312)
|
||||
ax3 = fig.add_subplot(313)
|
||||
|
||||
ax1.vlines(times[times < (100000*dt)], ymin=0, ymax=1, color="dodgerblue", lw=1.5)
|
||||
ax1.set_ylabel("skpikes", fontsize=12)
|
||||
ax1.set_xlim([0, 5])
|
||||
set_axis_fontsize(ax1, 12)
|
||||
|
||||
ax2.plot(time, rate, label="binned rate, trial 1")
|
||||
ax2.set_ylabel("firing rate [Hz]", fontsize=12)
|
||||
ax2.legend(fontsize=12)
|
||||
ax2.set_xlim([0, 5])
|
||||
set_axis_fontsize(ax2, 12)
|
||||
|
||||
ax3.fill_between(time, avg_rate+rate_std, avg_rate-rate_std, color="dodgerblue",
|
||||
alpha=0.5, label="standard deviation")
|
||||
ax3.plot(time, avg_rate, label="average rate")
|
||||
ax3.set_xlabel("times [s]", fontsize=12)
|
||||
ax3.set_ylabel("firing rate [Hz]", fontsize=12)
|
||||
ax3.legend(fontsize=12)
|
||||
ax3.set_xlim([0, 5])
|
||||
ax3.set_ylim([0, 450])
|
||||
set_axis_fontsize(ax3, 12)
|
||||
|
||||
fig.set_size_inches(15, 10)
|
||||
fig.subplots_adjust(left=0.1, bottom=0.125, top=0.95, right=0.95)
|
||||
fig.set_facecolor("white")
|
||||
fig.savefig("../lectures/images/binned_rate.png")
|
||||
plt.close()
|
||||
|
||||
|
||||
def get_convolved_rate(times, sigma, max_t=30., dt=1.e-4):
|
||||
time = np.arange(0., max_t, dt)
|
||||
kernel = spst.norm.pdf(np.arange(-8*sigma, 8*sigma, dt),loc=0,scale=sigma)
|
||||
indices = np.asarray(times/dt, dtype=int)
|
||||
rate = np.zeros(time.shape)
|
||||
rate[indices] = 1.;
|
||||
conv_rate = np.convolve(rate, kernel, mode="same")
|
||||
return time, conv_rate
|
||||
|
||||
|
||||
def plot_conv_rate(spike_times, sigma=0.05, max_t=30, dt=1e-4):
|
||||
times = np.squeeze(spike_times[0][0])
|
||||
time, rate = get_convolved_rate(times, sigma)
|
||||
|
||||
rates = np.zeros((len(rate), len(spike_times)))
|
||||
for i in range(len(spike_times)):
|
||||
_, rates[:, i] = get_convolved_rate(np.squeeze(spike_times[i][0]), sigma)
|
||||
avg_rate = np.mean(rates, axis=1)
|
||||
rate_std = np.std(rates, axis=1)
|
||||
|
||||
fig = plt.figure()
|
||||
ax1 = fig.add_subplot(311)
|
||||
ax2 = fig.add_subplot(312)
|
||||
ax3 = fig.add_subplot(313)
|
||||
|
||||
ax1.vlines(times[times < (100000*dt)], ymin=0, ymax=1, color="dodgerblue", lw=1.5)
|
||||
ax1.set_ylabel("skpikes", fontsize=12)
|
||||
ax1.set_xlim([0, 5])
|
||||
set_axis_fontsize(ax1, 12)
|
||||
|
||||
ax2.plot(time, rate, label="convolved rate, trial 1")
|
||||
ax2.set_ylabel("firing rate [Hz]", fontsize=12)
|
||||
ax2.legend(fontsize=12)
|
||||
ax2.set_xlim([0, 5])
|
||||
set_axis_fontsize(ax2, 12)
|
||||
|
||||
ax3.fill_between(time, avg_rate+rate_std, avg_rate-rate_std, color="dodgerblue",
|
||||
alpha=0.5, label="standard deviation")
|
||||
ax3.plot(time, avg_rate, label="average rate")
|
||||
ax3.set_xlabel("times [s]", fontsize=12)
|
||||
ax3.set_ylabel("firing rate [Hz]", fontsize=12)
|
||||
ax3.legend(fontsize=12)
|
||||
ax3.set_xlim([0, 5])
|
||||
ax3.set_ylim([0, 450])
|
||||
set_axis_fontsize(ax3, 12)
|
||||
|
||||
fig.set_size_inches(15, 10)
|
||||
fig.subplots_adjust(left=0.1, bottom=0.125, top=0.95, right=0.95)
|
||||
fig.set_facecolor("white")
|
||||
fig.savefig("../lectures/images/convolved_rate.png")
|
||||
plt.close()
|
||||
|
||||
|
||||
def plot_comparison(spike_times, bin_width, sigma, max_t=30., dt=1e-4):
|
||||
times = np.squeeze(spike_times[0][0])
|
||||
time, conv_rate = get_convolved_rate(times, bin_width/np.sqrt(12.))
|
||||
time, inst_rate = get_instantaneous_rate(times)
|
||||
time, binn_rate = get_binned_rate(times, bin_width)
|
||||
|
||||
fig = plt.figure()
|
||||
ax1 = fig.add_subplot(411)
|
||||
ax2 = fig.add_subplot(412)
|
||||
ax3 = fig.add_subplot(413)
|
||||
ax4 = fig.add_subplot(414)
|
||||
ax1.spines["right"].set_visible(False)
|
||||
ax1.spines["top"].set_visible(False)
|
||||
ax1.yaxis.set_ticks_position('left')
|
||||
ax1.xaxis.set_ticks_position('bottom')
|
||||
ax2.spines["right"].set_visible(False)
|
||||
ax2.spines["top"].set_visible(False)
|
||||
ax2.yaxis.set_ticks_position('left')
|
||||
ax2.xaxis.set_ticks_position('bottom')
|
||||
ax3.spines["right"].set_visible(False)
|
||||
ax3.spines["top"].set_visible(False)
|
||||
ax3.yaxis.set_ticks_position('left')
|
||||
ax3.xaxis.set_ticks_position('bottom')
|
||||
ax4.spines["right"].set_visible(False)
|
||||
ax4.spines["top"].set_visible(False)
|
||||
ax4.yaxis.set_ticks_position('left')
|
||||
ax4.xaxis.set_ticks_position('bottom')
|
||||
|
||||
|
||||
|
||||
ax1.vlines(times[times < (100000*dt)], ymin=0, ymax=1, color="dodgerblue", lw=1.5)
|
||||
ax1.set_ylabel("spikes", fontsize=10)
|
||||
ax1.set_xlim([1.5, 3.5])
|
||||
ax1.set_ylim([0, 1])
|
||||
ax1.set_yticks([0, 1])
|
||||
set_axis_fontsize(ax1, 10)
|
||||
ax1.set_xticklabels([])
|
||||
|
||||
ax2.plot(time, inst_rate, lw=1.5, label="instantaneous rate")
|
||||
ax2.set_ylabel("firing rate [Hz]", fontsize=10)
|
||||
ax2.legend(fontsize=10)
|
||||
ax2.set_xlim([1.5, 3.5])
|
||||
ax2.set_ylim([0, 300])
|
||||
set_axis_fontsize(ax2, 10)
|
||||
ax2.set_xticklabels([])
|
||||
|
||||
ax3.plot(time, binn_rate, lw=1.5, label="binned rate")
|
||||
ax3.set_ylabel("firing rate [Hz]", fontsize=10)
|
||||
ax3.legend(fontsize=10)
|
||||
ax3.set_xlim([1.5, 3.5])
|
||||
ax3.set_ylim([0, 300])
|
||||
set_axis_fontsize(ax3, 10)
|
||||
ax3.set_xticklabels([])
|
||||
|
||||
ax4.plot(time, conv_rate, lw=1.5, label="convolved rate")
|
||||
ax4.set_xlabel("time [s]", fontsize=10)
|
||||
ax4.set_ylabel("firing rate [Hz]", fontsize=10)
|
||||
ax4.legend(fontsize=10)
|
||||
ax4.set_xlim([1.5, 3.5])
|
||||
ax4.set_ylim([0, 300])
|
||||
set_axis_fontsize(ax4, 10)
|
||||
|
||||
fig.set_size_inches(7.5, 5)
|
||||
fig.subplots_adjust(left=0.1, bottom=0.125, top=0.95, right=0.95, )
|
||||
fig.set_facecolor("white")
|
||||
fig.savefig("../lecture/images/psth_comparison.pdf")
|
||||
plt.close()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
spike_times = spio.loadmat('lifoustim.mat')["spikes"]
|
||||
# plot_isi_rate(spike_times)
|
||||
# plot_bin_rate(spike_times, 0.05)
|
||||
# plot_conv_rate(spike_times, 0.025)
|
||||
plot_comparison(spike_times, 0.05, 0.025)
|
||||
|
||||
@@ -1,158 +0,0 @@
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
from IPython import embed
|
||||
|
||||
def set_rc():
|
||||
plt.rcParams['xtick.labelsize'] = 8
|
||||
plt.rcParams['ytick.labelsize'] = 8
|
||||
plt.rcParams['xtick.major.size'] = 5
|
||||
plt.rcParams['xtick.minor.size'] = 5
|
||||
plt.rcParams['xtick.major.width'] = 2
|
||||
plt.rcParams['xtick.minor.width'] = 2
|
||||
plt.rcParams['ytick.major.size'] = 5
|
||||
plt.rcParams['ytick.minor.size'] = 5
|
||||
plt.rcParams['ytick.major.width'] = 2
|
||||
plt.rcParams['ytick.minor.width'] = 2
|
||||
plt.rcParams['xtick.direction'] = "out"
|
||||
plt.rcParams['ytick.direction'] = "out"
|
||||
|
||||
|
||||
def create_spikes(isi=0.08, duration=0.5):
|
||||
times = np.arange(0., duration, isi)
|
||||
times += np.random.randn(len(times)) * (isi / 2.5)
|
||||
times = np.delete(times, np.nonzero(times < 0))
|
||||
times = np.delete(times, np.nonzero(times > duration))
|
||||
times = np.sort(times)
|
||||
return times
|
||||
|
||||
|
||||
def gaussian(sigma, dt):
|
||||
x = np.arange(-4*sigma, 4*sigma, dt)
|
||||
y = np.exp(-0.5 * (x / sigma)**2)/np.sqrt(2*np.pi)/sigma;
|
||||
return x, y
|
||||
|
||||
|
||||
def setup_axis(spikes_ax, rate_ax):
|
||||
spikes_ax.spines["right"].set_visible(False)
|
||||
spikes_ax.spines["top"].set_visible(False)
|
||||
spikes_ax.yaxis.set_ticks_position('left')
|
||||
spikes_ax.xaxis.set_ticks_position('bottom')
|
||||
spikes_ax.set_yticks([0, 1.0])
|
||||
spikes_ax.set_ylim([0, 1.05])
|
||||
spikes_ax.set_ylabel("spikes", fontsize=9)
|
||||
spikes_ax.text(-0.125, 1.2, "A", transform=spikes_ax.transAxes, size=10)
|
||||
|
||||
rate_ax.spines["right"].set_visible(False)
|
||||
rate_ax.spines["top"].set_visible(False)
|
||||
rate_ax.yaxis.set_ticks_position('left')
|
||||
rate_ax.xaxis.set_ticks_position('bottom')
|
||||
rate_ax.set_xlabel('time[s]', fontsize=9)
|
||||
rate_ax.set_ylabel('firing rate [Hz]', fontsize=9)
|
||||
rate_ax.text(-0.125, 1.15, "B", transform=rate_ax.transAxes, size=10)
|
||||
|
||||
|
||||
def plot_bin_method():
|
||||
dt = 1e-5
|
||||
duration = 0.5
|
||||
|
||||
spike_times = create_spikes(0.018, duration)
|
||||
t = np.arange(0., duration, dt)
|
||||
|
||||
bins = np.arange(0, 0.55, 0.05)
|
||||
count, _ = np.histogram(spike_times, bins)
|
||||
|
||||
plt.xkcd()
|
||||
set_rc()
|
||||
fig = plt.figure()
|
||||
fig.set_size_inches(5., 2.5)
|
||||
fig.set_facecolor('white')
|
||||
spikes = plt.subplot2grid((7,1), (0,0), rowspan=3, colspan=1)
|
||||
rate_ax = plt.subplot2grid((7,1), (3,0), rowspan=4, colspan=1)
|
||||
setup_axis(spikes, rate_ax)
|
||||
spikes.set_ylim([0., 1.25])
|
||||
|
||||
spikes.vlines(spike_times, 0., 1., color="darkblue", lw=1.25)
|
||||
spikes.vlines(np.hstack((0,bins)), 0., 1.25, color="red", lw=1.5, linestyles='--')
|
||||
for i,c in enumerate(count):
|
||||
spikes.text(bins[i] + bins[1]/2, 1.05, str(c), fontdict={'color':'red'})
|
||||
spikes.set_xlim([0, duration])
|
||||
|
||||
rate = count / 0.05
|
||||
rate_ax.step(bins, np.hstack((rate, rate[-1])), where='post')
|
||||
rate_ax.set_xlim([0., duration])
|
||||
rate_ax.set_ylim([0., 100.])
|
||||
rate_ax.set_yticks(np.arange(0,105,25))
|
||||
fig.tight_layout()
|
||||
fig.savefig("../lecture/images/bin_method.pdf")
|
||||
plt.close()
|
||||
|
||||
|
||||
def plot_conv_method():
|
||||
dt = 1e-5
|
||||
duration = 0.5
|
||||
spike_times = create_spikes(0.05, duration)
|
||||
kernel_time, kernel = gaussian(0.02, dt)
|
||||
|
||||
t = np.arange(0., duration, dt)
|
||||
rate = np.zeros(t.shape)
|
||||
rate[np.asarray(np.round(spike_times/dt), dtype=int)] = 1
|
||||
rate = np.convolve(rate, kernel, mode='same')
|
||||
rate = np.roll(rate, -1)
|
||||
|
||||
plt.xkcd()
|
||||
set_rc()
|
||||
fig = plt.figure()
|
||||
fig.set_size_inches(5., 2.5)
|
||||
fig.set_facecolor('white')
|
||||
spikes = plt.subplot2grid((7,1), (0,0), rowspan=3, colspan=1)
|
||||
rate_ax = plt.subplot2grid((7,1), (3,0), rowspan=4, colspan=1)
|
||||
setup_axis(spikes, rate_ax)
|
||||
|
||||
spikes.vlines(spike_times, 0., 1., color="darkblue", lw=1.5, zorder=2)
|
||||
for i in spike_times:
|
||||
spikes.plot(kernel_time + i, kernel/np.max(kernel), color="orange", lw=0.75, zorder=1)
|
||||
spikes.set_xlim([0, duration])
|
||||
|
||||
rate_ax.plot(t, rate, color="darkblue", lw=1, zorder=2)
|
||||
rate_ax.fill_between(t, rate, np.zeros(len(rate)), color="red", alpha=0.5)
|
||||
rate_ax.set_xlim([0, duration])
|
||||
rate_ax.set_ylim([0, 50])
|
||||
rate_ax.set_yticks(np.arange(0,75,25))
|
||||
fig.tight_layout()
|
||||
fig.savefig("../lecture/images/conv_method.pdf")
|
||||
|
||||
|
||||
def plot_isi_method():
|
||||
spike_times = create_spikes(0.09, 0.5)
|
||||
|
||||
plt.xkcd()
|
||||
set_rc()
|
||||
fig = plt.figure()
|
||||
fig.set_size_inches(5., 2.5)
|
||||
fig.set_facecolor('white')
|
||||
|
||||
spikes = plt.subplot2grid((7,1), (0,0), rowspan=3, colspan=1)
|
||||
rate = plt.subplot2grid((7,1), (3,0), rowspan=4, colspan=1)
|
||||
setup_axis(spikes, rate)
|
||||
|
||||
spikes.vlines(spike_times, 0., 1., color="darkblue", lw=1.25)
|
||||
spike_times = np.hstack((0, spike_times))
|
||||
for i in range(1, len(spike_times)):
|
||||
t_start = spike_times[i-1]
|
||||
t = spike_times[i]
|
||||
spikes.annotate(s='', xy=(t_start, 0.5), xytext=(t,0.5), arrowprops=dict(arrowstyle='<->'), color='red')
|
||||
|
||||
i_rate = 1./np.diff(spike_times)
|
||||
|
||||
rate.step(spike_times, np.hstack((i_rate, i_rate[-1])),color="darkblue", lw=1.25, where="post")
|
||||
rate.set_ylim([0, 75])
|
||||
rate.set_yticks(np.arange(0,100,25))
|
||||
|
||||
fig.tight_layout()
|
||||
fig.savefig("../lecture/images/isi_method.pdf")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
plot_isi_method()
|
||||
plot_conv_method()
|
||||
plot_bin_method()
|
||||
@@ -1,88 +0,0 @@
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
import scipy.io as scio
|
||||
from IPython import embed
|
||||
|
||||
|
||||
def plot_sta(times, stim, dt, t_min=-0.1, t_max=.1):
|
||||
count = 0
|
||||
sta = np.zeros((abs(t_min) + abs(t_max))/dt)
|
||||
time = np.arange(t_min, t_max, dt)
|
||||
if len(stim.shape) > 1 and stim.shape[1] > 1:
|
||||
stim = stim[:,1]
|
||||
for i in range(len(times[0])):
|
||||
times = np.squeeze(spike_times[0][i])
|
||||
for t in times:
|
||||
if (int((t + t_min)/dt) < 0) or ((t + t_max)/dt > len(stim)):
|
||||
continue;
|
||||
|
||||
min_index = int(np.round((t+t_min)/dt))
|
||||
max_index = int(np.round((t+t_max)/dt))
|
||||
snippet = np.squeeze(stim[ min_index : max_index])
|
||||
sta += snippet
|
||||
count += 1
|
||||
sta /= count
|
||||
|
||||
fig = plt.figure()
|
||||
fig.set_size_inches(5, 5)
|
||||
fig.subplots_adjust(left=0.15, bottom=0.125, top=0.95, right=0.95, )
|
||||
fig.set_facecolor("white")
|
||||
|
||||
ax = fig.add_subplot(111)
|
||||
ax.plot(time, sta, color="darkblue", lw=1)
|
||||
ax.set_xlabel("time [s]")
|
||||
ax.set_ylabel("stimulus")
|
||||
ax.xaxis.grid('off')
|
||||
ax.spines["right"].set_visible(False)
|
||||
ax.spines["top"].set_visible(False)
|
||||
ax.yaxis.set_ticks_position('left')
|
||||
ax.xaxis.set_ticks_position('bottom')
|
||||
|
||||
ylim = ax.get_ylim()
|
||||
xlim = ax.get_xlim()
|
||||
ax.plot(list(xlim), [0., 0.], zorder=1, color='darkgray', ls='--')
|
||||
ax.plot([0., 0.], list(ylim), zorder=1, color='darkgray', ls='--')
|
||||
ax.set_xlim(list(xlim))
|
||||
ax.set_ylim(list(ylim))
|
||||
fig.savefig("../lecture/images/sta.pdf")
|
||||
plt.close()
|
||||
return sta
|
||||
|
||||
|
||||
def reconstruct_stimulus(spike_times, sta, stimulus, t_max=30., dt=1e-4):
|
||||
s_est = np.zeros((spike_times.shape[1], len(stimulus)))
|
||||
for i in range(10):
|
||||
times = np.squeeze(spike_times[0][i])
|
||||
indices = np.asarray((np.round(times/dt)), dtype=int)
|
||||
y = np.zeros(len(stimulus))
|
||||
y[indices] = 1
|
||||
s_est[i, :] = np.convolve(y, sta, mode='same')
|
||||
|
||||
plt.plot(np.arange(0, t_max, dt), stimulus[:,1], label='stimulus', color='darkblue', lw=2.)
|
||||
plt.plot(np.arange(0, t_max, dt), np.mean(s_est, axis=0), label='reconstruction', color='gray', lw=1.5)
|
||||
plt.xlabel('time[s]')
|
||||
plt.ylabel('stimulus')
|
||||
plt.xlim([0.0, 0.25])
|
||||
plt.ylim([-1., 1.])
|
||||
plt.legend()
|
||||
plt.plot([0.0, 0.25], [0., 0.], color="darkgray", lw=1.5, ls='--', zorder=1)
|
||||
plt.gca().spines["right"].set_visible(False)
|
||||
plt.gca().spines["top"].set_visible(False)
|
||||
plt.gca().yaxis.set_ticks_position('left')
|
||||
plt.gca().xaxis.set_ticks_position('bottom')
|
||||
|
||||
fig = plt.gcf()
|
||||
fig.set_size_inches(7.5, 5)
|
||||
fig.subplots_adjust(left=0.15, bottom=0.125, top=0.95, right=0.95, )
|
||||
fig.set_facecolor("white")
|
||||
fig.savefig('../lecture/images/reconstruction.pdf')
|
||||
plt.close()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
punit_data = scio.loadmat('../../programming/exercises/p-unit_spike_times.mat')
|
||||
punit_stim = scio.loadmat('../../programming/exercises/p-unit_stimulus.mat')
|
||||
spike_times = punit_data["spike_times"]
|
||||
stimulus = punit_stim["stimulus"]
|
||||
sta = plot_sta(spike_times, stimulus, 5e-5, -0.05, 0.05)
|
||||
reconstruct_stimulus(spike_times, sta, stimulus, 10, 5e-5)
|
||||
Reference in New Issue
Block a user