improved plotting, xkcd style methods describing the psth estimation

This commit is contained in:
Jan Grewe 2015-11-02 22:54:07 +01:00
parent cb7924ffd5
commit fffd7226e2
3 changed files with 67 additions and 35 deletions

View File

@ -1,10 +1,9 @@
import numpy as np import numpy as np
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import scipy.io as spio import scipy.io as spio
import scipy.stats as spst
import scipy as sp import scipy as sp
import seaborn as sb
from IPython import embed from IPython import embed
sb.set_context("paper")
def set_axis_fontsize(axis, label_size, tick_label_size=None, legend_size=None): def set_axis_fontsize(axis, label_size, tick_label_size=None, legend_size=None):
@ -140,7 +139,7 @@ def plot_bin_rate(spike_times, bin_width, max_t=30, dt=1e-4):
def get_convolved_rate(times, sigma, max_t=30., dt=1.e-4): def get_convolved_rate(times, sigma, max_t=30., dt=1.e-4):
time = np.arange(0., max_t, dt) time = np.arange(0., max_t, dt)
kernel = sp.stats.norm.pdf(np.arange(-8*sigma, 8*sigma, dt),loc=0,scale=sigma) kernel = spst.norm.pdf(np.arange(-8*sigma, 8*sigma, dt),loc=0,scale=sigma)
indices = np.asarray(times/dt, dtype=int) indices = np.asarray(times/dt, dtype=int)
rate = np.zeros(time.shape) rate = np.zeros(time.shape)
rate[indices] = 1.; rate[indices] = 1.;
@ -202,6 +201,24 @@ def plot_comparison(spike_times, bin_width, sigma, max_t=30., dt=1e-4):
ax2 = fig.add_subplot(412) ax2 = fig.add_subplot(412)
ax3 = fig.add_subplot(413) ax3 = fig.add_subplot(413)
ax4 = fig.add_subplot(414) ax4 = fig.add_subplot(414)
ax1.spines["right"].set_visible(False)
ax1.spines["top"].set_visible(False)
ax1.yaxis.set_ticks_position('left')
ax1.xaxis.set_ticks_position('bottom')
ax2.spines["right"].set_visible(False)
ax2.spines["top"].set_visible(False)
ax2.yaxis.set_ticks_position('left')
ax2.xaxis.set_ticks_position('bottom')
ax3.spines["right"].set_visible(False)
ax3.spines["top"].set_visible(False)
ax3.yaxis.set_ticks_position('left')
ax3.xaxis.set_ticks_position('bottom')
ax4.spines["right"].set_visible(False)
ax4.spines["top"].set_visible(False)
ax4.yaxis.set_ticks_position('left')
ax4.xaxis.set_ticks_position('bottom')
ax1.vlines(times[times < (100000*dt)], ymin=0, ymax=1, color="dodgerblue", lw=1.5) ax1.vlines(times[times < (100000*dt)], ymin=0, ymax=1, color="dodgerblue", lw=1.5)
ax1.set_ylabel("spikes", fontsize=10) ax1.set_ylabel("spikes", fontsize=10)
@ -211,7 +228,7 @@ def plot_comparison(spike_times, bin_width, sigma, max_t=30., dt=1e-4):
set_axis_fontsize(ax1, 10) set_axis_fontsize(ax1, 10)
ax1.set_xticklabels([]) ax1.set_xticklabels([])
ax2.plot(time, inst_rate, label="instantaneous rate") ax2.plot(time, inst_rate, lw=1.5, label="instantaneous rate")
ax2.set_ylabel("firing rate [Hz]", fontsize=10) ax2.set_ylabel("firing rate [Hz]", fontsize=10)
ax2.legend(fontsize=10) ax2.legend(fontsize=10)
ax2.set_xlim([1.5, 3.5]) ax2.set_xlim([1.5, 3.5])
@ -219,7 +236,7 @@ def plot_comparison(spike_times, bin_width, sigma, max_t=30., dt=1e-4):
set_axis_fontsize(ax2, 10) set_axis_fontsize(ax2, 10)
ax2.set_xticklabels([]) ax2.set_xticklabels([])
ax3.plot(time, binn_rate, label="binned rate") ax3.plot(time, binn_rate, lw=1.5, label="binned rate")
ax3.set_ylabel("firing rate [Hz]", fontsize=10) ax3.set_ylabel("firing rate [Hz]", fontsize=10)
ax3.legend(fontsize=10) ax3.legend(fontsize=10)
ax3.set_xlim([1.5, 3.5]) ax3.set_xlim([1.5, 3.5])
@ -227,7 +244,7 @@ def plot_comparison(spike_times, bin_width, sigma, max_t=30., dt=1e-4):
set_axis_fontsize(ax3, 10) set_axis_fontsize(ax3, 10)
ax3.set_xticklabels([]) ax3.set_xticklabels([])
ax4.plot(time, conv_rate, label="convolved rate") ax4.plot(time, conv_rate, lw=1.5, label="convolved rate")
ax4.set_xlabel("time [s]", fontsize=10) ax4.set_xlabel("time [s]", fontsize=10)
ax4.set_ylabel("firing rate [Hz]", fontsize=10) ax4.set_ylabel("firing rate [Hz]", fontsize=10)
ax4.legend(fontsize=10) ax4.legend(fontsize=10)

View File

@ -55,31 +55,36 @@ def plot_bin_method():
dt = 1e-5 dt = 1e-5
duration = 0.5 duration = 0.5
spike_times = create_spikes(0.025, duration) spike_times = create_spikes(0.018, duration)
t = np.arange(0., duration, dt) t = np.arange(0., duration, dt)
bins = np.arange(0, 0.5, 0.05) bins = np.arange(0, 0.55, 0.05)
count, _ = np.histogram(spike_times, bins) count, _ = np.histogram(spike_times, bins)
plt.xkcd() plt.xkcd()
set_rc()
fig = plt.figure() fig = plt.figure()
fig.set_size_inches(5., 2.5) fig.set_size_inches(5., 2.5)
fig.set_facecolor('white') fig.set_facecolor('white')
spikes = plt.subplot2grid((3,1), (0,0), rowspan=1, colspan=1) spikes = plt.subplot2grid((7,1), (0,0), rowspan=3, colspan=1)
rate_ax = plt.subplot2grid((3,1), (1,0), rowspan=2, colspan=1) rate_ax = plt.subplot2grid((7,1), (3,0), rowspan=4, colspan=1)
setup_axis(spikes, rate_ax) setup_axis(spikes, rate_ax)
spikes.set_ylim([0., 1.25])
spikes.vlines(spike_times, 0., 1., color="darkblue", lw=1) spikes.vlines(spike_times, 0., 1., color="darkblue", lw=1.25)
spikes.vlines(bins, 0., 1.05, color="red", lw = 2) spikes.vlines(np.hstack((0,bins)), 0., 1.25, color="red", lw=1.5, linestyles='--')
for c in count: for i,c in enumerate(count):
spikes.text(bins-(bins[1]-bins[0])/2, 0.75, str(c), fontdict={"color":"red"}) spikes.text(bins[i] + bins[1]/2, 1.05, str(c), fontdict={'color':'red'})
spikes.set_xlim([0, duration]) spikes.set_xlim([0, duration])
#rate_ax.step(bins[1:], count/0.05) rate = count / 0.05
rate_ax.step(bins, np.hstack((rate, rate[-1])), where='post')
rate_ax.set_xlim([0., duration])
rate_ax.set_ylim([0., 100.])
rate_ax.set_yticks(np.arange(0,105,25))
fig.tight_layout() fig.tight_layout()
plt.show() fig.savefig("../lecture/images/bin_method.pdf")
#fig.savefig("bin_method.pdf") plt.close()
def plot_conv_method(): def plot_conv_method():
@ -94,11 +99,13 @@ def plot_conv_method():
rate = np.convolve(rate, kernel, mode='same') rate = np.convolve(rate, kernel, mode='same')
rate = np.roll(rate, -1) rate = np.roll(rate, -1)
plt.xkcd()
set_rc()
fig = plt.figure() fig = plt.figure()
fig.set_size_inches(5., 2.5) fig.set_size_inches(5., 2.5)
fig.set_facecolor('white') fig.set_facecolor('white')
spikes = plt.subplot2grid((3,1), (0,0), rowspan=1, colspan=1) spikes = plt.subplot2grid((7,1), (0,0), rowspan=3, colspan=1)
rate_ax = plt.subplot2grid((3,1), (1,0), rowspan=2, colspan=1) rate_ax = plt.subplot2grid((7,1), (3,0), rowspan=4, colspan=1)
setup_axis(spikes, rate_ax) setup_axis(spikes, rate_ax)
spikes.vlines(spike_times, 0., 1., color="darkblue", lw=1.5, zorder=2) spikes.vlines(spike_times, 0., 1., color="darkblue", lw=1.5, zorder=2)
@ -112,11 +119,11 @@ def plot_conv_method():
rate_ax.set_ylim([0, 50]) rate_ax.set_ylim([0, 50])
rate_ax.set_yticks(np.arange(0,75,25)) rate_ax.set_yticks(np.arange(0,75,25))
fig.tight_layout() fig.tight_layout()
fig.savefig("conv_method.pdf") fig.savefig("../lecture/images/conv_method.pdf")
def plot_isi_method(): def plot_isi_method():
spike_times = create_spikes(0.08, 0.5) spike_times = create_spikes(0.09, 0.5)
plt.xkcd() plt.xkcd()
set_rc() set_rc()
@ -124,8 +131,8 @@ def plot_isi_method():
fig.set_size_inches(5., 2.5) fig.set_size_inches(5., 2.5)
fig.set_facecolor('white') fig.set_facecolor('white')
spikes = plt.subplot2grid((3,1), (0,0), rowspan=1, colspan=1) spikes = plt.subplot2grid((7,1), (0,0), rowspan=3, colspan=1)
rate = plt.subplot2grid((3,1), (1,0), rowspan=2, colspan=1) rate = plt.subplot2grid((7,1), (3,0), rowspan=4, colspan=1)
setup_axis(spikes, rate) setup_axis(spikes, rate)
spikes.vlines(spike_times, 0., 1., color="darkblue", lw=1.25) spikes.vlines(spike_times, 0., 1., color="darkblue", lw=1.25)
@ -133,18 +140,19 @@ def plot_isi_method():
for i in range(1, len(spike_times)): for i in range(1, len(spike_times)):
t_start = spike_times[i-1] t_start = spike_times[i-1]
t = spike_times[i] t = spike_times[i]
spikes.annotate(s='', xy=(t_start, 0.5), xytext=(t,0.5), arrowprops=dict(arrowstyle='<->')) spikes.annotate(s='', xy=(t_start, 0.5), xytext=(t,0.5), arrowprops=dict(arrowstyle='<->'), color='red')
i_rate = 1./np.diff(spike_times) i_rate = 1./np.diff(spike_times)
rate.step(spike_times[1:], i_rate,color="darkblue", lw=1.25, where="pre") rate.step(spike_times, np.hstack((i_rate, i_rate[-1])),color="darkblue", lw=1.25, where="post")
rate.set_ylim([0, 75]) rate.set_ylim([0, 75])
rate.set_yticks(np.arange(0,100,25)) rate.set_yticks(np.arange(0,100,25))
fig.tight_layout() fig.tight_layout()
fig.savefig("isi_method.pdf") fig.savefig("../lecture/images/isi_method.pdf")
if __name__ == '__main__': if __name__ == '__main__':
#plot_isi_method() plot_isi_method()
#plot_conv_method() plot_conv_method()
plot_bin_method() plot_bin_method()

View File

@ -1,9 +1,7 @@
import numpy as np import numpy as np
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import scipy.io as scio import scipy.io as scio
import seaborn as sb
from IPython import embed from IPython import embed
sb.set_context("paper")
def plot_sta(times, stim, dt, t_min=-0.1, t_max=.1): def plot_sta(times, stim, dt, t_min=-0.1, t_max=.1):
@ -31,10 +29,15 @@ def plot_sta(times, stim, dt, t_min=-0.1, t_max=.1):
fig.set_facecolor("white") fig.set_facecolor("white")
ax = fig.add_subplot(111) ax = fig.add_subplot(111)
ax.plot(time, sta, color="darkblue") ax.plot(time, sta, color="darkblue", lw=1)
ax.set_xlabel("time [s]") ax.set_xlabel("time [s]")
ax.set_ylabel("stimulus") ax.set_ylabel("stimulus")
ax.xaxis.grid('off') ax.xaxis.grid('off')
ax.spines["right"].set_visible(False)
ax.spines["top"].set_visible(False)
ax.yaxis.set_ticks_position('left')
ax.xaxis.set_ticks_position('bottom')
ylim = ax.get_ylim() ylim = ax.get_ylim()
xlim = ax.get_xlim() xlim = ax.get_xlim()
ax.plot(list(xlim), [0., 0.], zorder=1, color='darkgray', ls='--') ax.plot(list(xlim), [0., 0.], zorder=1, color='darkgray', ls='--')
@ -56,13 +59,17 @@ def reconstruct_stimulus(spike_times, sta, stimulus, t_max=30., dt=1e-4):
s_est[i, :] = np.convolve(y, sta, mode='same') s_est[i, :] = np.convolve(y, sta, mode='same')
plt.plot(np.arange(0, t_max, dt), stimulus[:,1], label='stimulus', color='darkblue', lw=2.) plt.plot(np.arange(0, t_max, dt), stimulus[:,1], label='stimulus', color='darkblue', lw=2.)
plt.plot(np.arange(0, t_max, dt), np.mean(s_est, axis=0), label='reconstruction', color='silver', lw=1.5) plt.plot(np.arange(0, t_max, dt), np.mean(s_est, axis=0), label='reconstruction', color='gray', lw=1.5)
plt.xlabel('time[s]') plt.xlabel('time[s]')
plt.ylabel('stimulus') plt.ylabel('stimulus')
plt.xlim([0.0, 0.25]) plt.xlim([0.0, 0.25])
plt.ylim([-1., 1.]) plt.ylim([-1., 1.])
plt.legend() plt.legend()
plt.plot([0.0, 0.25], [0., 0.], color="darkgray", lw=1, ls='--', zorder=1) plt.plot([0.0, 0.25], [0., 0.], color="darkgray", lw=1.5, ls='--', zorder=1)
plt.gca().spines["right"].set_visible(False)
plt.gca().spines["top"].set_visible(False)
plt.gca().yaxis.set_ticks_position('left')
plt.gca().xaxis.set_ticks_position('bottom')
fig = plt.gcf() fig = plt.gcf()
fig.set_size_inches(7.5, 5) fig.set_size_inches(7.5, 5)