From fffd7226e2545b70a56e7e0e7c9a794f5232a2d8 Mon Sep 17 00:00:00 2001
From: Jan Grewe <jan.grewe@g-node.org>
Date: Mon, 2 Nov 2015 22:54:07 +0100
Subject: [PATCH] improved plotting, xkcd style methods describing the psth
 estimation

---
 spike_trains/code/firing_rates.py | 29 ++++++++++++----
 spike_trains/code/isi_method.py   | 56 ++++++++++++++++++-------------
 spike_trains/code/sta.py          | 17 +++++++---
 3 files changed, 67 insertions(+), 35 deletions(-)

diff --git a/spike_trains/code/firing_rates.py b/spike_trains/code/firing_rates.py
index 097250e..faaeb29 100644
--- a/spike_trains/code/firing_rates.py
+++ b/spike_trains/code/firing_rates.py
@@ -1,10 +1,9 @@
 import numpy as np
 import matplotlib.pyplot as plt
 import scipy.io as spio
+import scipy.stats as spst
 import scipy as sp
-import seaborn as sb
 from IPython import embed
-sb.set_context("paper")
 
 
 def set_axis_fontsize(axis, label_size, tick_label_size=None, legend_size=None):
@@ -140,7 +139,7 @@ def plot_bin_rate(spike_times, bin_width, max_t=30, dt=1e-4):
 
 def get_convolved_rate(times, sigma, max_t=30., dt=1.e-4):
     time = np.arange(0., max_t, dt)
-    kernel = sp.stats.norm.pdf(np.arange(-8*sigma, 8*sigma, dt),loc=0,scale=sigma)
+    kernel = spst.norm.pdf(np.arange(-8*sigma, 8*sigma, dt),loc=0,scale=sigma)
     indices = np.asarray(times/dt, dtype=int)
     rate = np.zeros(time.shape)
     rate[indices] = 1.;
@@ -202,6 +201,24 @@ def plot_comparison(spike_times, bin_width, sigma, max_t=30., dt=1e-4):
     ax2 = fig.add_subplot(412)
     ax3 = fig.add_subplot(413)
     ax4 = fig.add_subplot(414)
+    ax1.spines["right"].set_visible(False)
+    ax1.spines["top"].set_visible(False)
+    ax1.yaxis.set_ticks_position('left')
+    ax1.xaxis.set_ticks_position('bottom')
+    ax2.spines["right"].set_visible(False)
+    ax2.spines["top"].set_visible(False)
+    ax2.yaxis.set_ticks_position('left')
+    ax2.xaxis.set_ticks_position('bottom')
+    ax3.spines["right"].set_visible(False)
+    ax3.spines["top"].set_visible(False)
+    ax3.yaxis.set_ticks_position('left')
+    ax3.xaxis.set_ticks_position('bottom')
+    ax4.spines["right"].set_visible(False)
+    ax4.spines["top"].set_visible(False)
+    ax4.yaxis.set_ticks_position('left')
+    ax4.xaxis.set_ticks_position('bottom')
+
+
 
     ax1.vlines(times[times < (100000*dt)], ymin=0, ymax=1, color="dodgerblue", lw=1.5)
     ax1.set_ylabel("spikes", fontsize=10)
@@ -211,7 +228,7 @@ def plot_comparison(spike_times, bin_width, sigma, max_t=30., dt=1e-4):
     set_axis_fontsize(ax1, 10)
     ax1.set_xticklabels([])
 
-    ax2.plot(time, inst_rate, label="instantaneous rate")
+    ax2.plot(time, inst_rate, lw=1.5, label="instantaneous rate")
     ax2.set_ylabel("firing rate [Hz]", fontsize=10)
     ax2.legend(fontsize=10)
     ax2.set_xlim([1.5, 3.5])
@@ -219,7 +236,7 @@ def plot_comparison(spike_times, bin_width, sigma, max_t=30., dt=1e-4):
     set_axis_fontsize(ax2, 10)
     ax2.set_xticklabels([])
 
-    ax3.plot(time, binn_rate, label="binned rate")
+    ax3.plot(time, binn_rate, lw=1.5, label="binned rate")
     ax3.set_ylabel("firing rate [Hz]", fontsize=10)
     ax3.legend(fontsize=10)
     ax3.set_xlim([1.5, 3.5])
@@ -227,7 +244,7 @@ def plot_comparison(spike_times, bin_width, sigma, max_t=30., dt=1e-4):
     set_axis_fontsize(ax3, 10)
     ax3.set_xticklabels([])
 
-    ax4.plot(time, conv_rate, label="convolved rate")
+    ax4.plot(time, conv_rate, lw=1.5, label="convolved rate")
     ax4.set_xlabel("time [s]", fontsize=10)
     ax4.set_ylabel("firing rate [Hz]", fontsize=10)
     ax4.legend(fontsize=10)
diff --git a/spike_trains/code/isi_method.py b/spike_trains/code/isi_method.py
index 776c68e..e4b95e3 100644
--- a/spike_trains/code/isi_method.py
+++ b/spike_trains/code/isi_method.py
@@ -55,31 +55,36 @@ def plot_bin_method():
     dt = 1e-5
     duration = 0.5
     
-    spike_times = create_spikes(0.025, duration)
+    spike_times = create_spikes(0.018, duration)
     t = np.arange(0., duration, dt)
     
-    bins = np.arange(0, 0.5, 0.05)
+    bins = np.arange(0, 0.55, 0.05)
     count, _ = np.histogram(spike_times, bins)
     
     plt.xkcd()
+    set_rc()
     fig = plt.figure()
     fig.set_size_inches(5., 2.5)
     fig.set_facecolor('white')
-    spikes = plt.subplot2grid((3,1), (0,0), rowspan=1, colspan=1)
-    rate_ax = plt.subplot2grid((3,1), (1,0), rowspan=2, colspan=1)
+    spikes = plt.subplot2grid((7,1), (0,0), rowspan=3, colspan=1)
+    rate_ax = plt.subplot2grid((7,1), (3,0), rowspan=4, colspan=1)
     setup_axis(spikes, rate_ax)
+    spikes.set_ylim([0., 1.25])
 
-    spikes.vlines(spike_times, 0., 1., color="darkblue", lw=1)    
-    spikes.vlines(bins, 0., 1.05, color="red", lw = 2)
-    for c in count:
-        spikes.text(bins-(bins[1]-bins[0])/2, 0.75, str(c), fontdict={"color":"red"})
+    spikes.vlines(spike_times, 0., 1., color="darkblue", lw=1.25)    
+    spikes.vlines(np.hstack((0,bins)), 0., 1.25, color="red", lw=1.5, linestyles='--')
+    for i,c in enumerate(count):
+        spikes.text(bins[i] + bins[1]/2, 1.05, str(c), fontdict={'color':'red'})
     spikes.set_xlim([0, duration])
-    
-    #rate_ax.step(bins[1:], count/0.05)
-    
+
+    rate = count / 0.05
+    rate_ax.step(bins, np.hstack((rate, rate[-1])), where='post')
+    rate_ax.set_xlim([0., duration])
+    rate_ax.set_ylim([0., 100.])
+    rate_ax.set_yticks(np.arange(0,105,25))
     fig.tight_layout()
-    plt.show()
-    #fig.savefig("bin_method.pdf")
+    fig.savefig("../lecture/images/bin_method.pdf")
+    plt.close()
 
 
 def plot_conv_method():
@@ -94,11 +99,13 @@ def plot_conv_method():
     rate = np.convolve(rate, kernel, mode='same')
     rate = np.roll(rate, -1)
 
+    plt.xkcd()
+    set_rc()
     fig = plt.figure()
     fig.set_size_inches(5., 2.5)
     fig.set_facecolor('white')
-    spikes = plt.subplot2grid((3,1), (0,0), rowspan=1, colspan=1)
-    rate_ax = plt.subplot2grid((3,1), (1,0), rowspan=2, colspan=1)
+    spikes = plt.subplot2grid((7,1), (0,0), rowspan=3, colspan=1)
+    rate_ax = plt.subplot2grid((7,1), (3,0), rowspan=4, colspan=1)
     setup_axis(spikes, rate_ax)
     
     spikes.vlines(spike_times, 0., 1., color="darkblue", lw=1.5, zorder=2)
@@ -112,11 +119,11 @@ def plot_conv_method():
     rate_ax.set_ylim([0, 50])
     rate_ax.set_yticks(np.arange(0,75,25))
     fig.tight_layout()
-    fig.savefig("conv_method.pdf")
+    fig.savefig("../lecture/images/conv_method.pdf")
 
 
 def plot_isi_method():
-    spike_times = create_spikes(0.08, 0.5)
+    spike_times = create_spikes(0.09, 0.5)
     
     plt.xkcd()
     set_rc()
@@ -124,8 +131,8 @@ def plot_isi_method():
     fig.set_size_inches(5., 2.5)
     fig.set_facecolor('white')
     
-    spikes = plt.subplot2grid((3,1), (0,0), rowspan=1, colspan=1)
-    rate = plt.subplot2grid((3,1), (1,0), rowspan=2, colspan=1)
+    spikes = plt.subplot2grid((7,1), (0,0), rowspan=3, colspan=1)
+    rate = plt.subplot2grid((7,1), (3,0), rowspan=4, colspan=1)
     setup_axis(spikes, rate)
     
     spikes.vlines(spike_times, 0., 1., color="darkblue", lw=1.25)
@@ -133,18 +140,19 @@ def plot_isi_method():
     for i in range(1, len(spike_times)):
         t_start = spike_times[i-1]
         t = spike_times[i]
-        spikes.annotate(s='', xy=(t_start, 0.5), xytext=(t,0.5), arrowprops=dict(arrowstyle='<->'))
+        spikes.annotate(s='', xy=(t_start, 0.5), xytext=(t,0.5), arrowprops=dict(arrowstyle='<->'), color='red')
+
     i_rate = 1./np.diff(spike_times)
 
-    rate.step(spike_times[1:], i_rate,color="darkblue", lw=1.25, where="pre")
+    rate.step(spike_times, np.hstack((i_rate, i_rate[-1])),color="darkblue", lw=1.25, where="post")
     rate.set_ylim([0, 75])
     rate.set_yticks(np.arange(0,100,25))
     
     fig.tight_layout()
-    fig.savefig("isi_method.pdf")
+    fig.savefig("../lecture/images/isi_method.pdf")
     
     
 if __name__ == '__main__':
-    #plot_isi_method()
-    #plot_conv_method()
+    plot_isi_method()
+    plot_conv_method()
     plot_bin_method()
diff --git a/spike_trains/code/sta.py b/spike_trains/code/sta.py
index 5b14083..d373820 100644
--- a/spike_trains/code/sta.py
+++ b/spike_trains/code/sta.py
@@ -1,9 +1,7 @@
 import numpy as np
 import matplotlib.pyplot as plt
 import scipy.io as scio
-import seaborn as sb
 from IPython import embed
-sb.set_context("paper")
 
 
 def plot_sta(times, stim, dt, t_min=-0.1, t_max=.1):
@@ -31,10 +29,15 @@ def plot_sta(times, stim, dt, t_min=-0.1, t_max=.1):
     fig.set_facecolor("white")
     
     ax = fig.add_subplot(111)
-    ax.plot(time, sta, color="darkblue")
+    ax.plot(time, sta, color="darkblue", lw=1)
     ax.set_xlabel("time [s]")
     ax.set_ylabel("stimulus")
     ax.xaxis.grid('off')
+    ax.spines["right"].set_visible(False)
+    ax.spines["top"].set_visible(False)
+    ax.yaxis.set_ticks_position('left')
+    ax.xaxis.set_ticks_position('bottom')
+
     ylim = ax.get_ylim()
     xlim = ax.get_xlim()
     ax.plot(list(xlim), [0., 0.], zorder=1, color='darkgray', ls='--')
@@ -56,13 +59,17 @@ def reconstruct_stimulus(spike_times, sta, stimulus, t_max=30., dt=1e-4):
         s_est[i, :] = np.convolve(y, sta, mode='same')
 
     plt.plot(np.arange(0, t_max, dt), stimulus[:,1], label='stimulus', color='darkblue', lw=2.)
-    plt.plot(np.arange(0, t_max, dt), np.mean(s_est, axis=0), label='reconstruction', color='silver', lw=1.5)
+    plt.plot(np.arange(0, t_max, dt), np.mean(s_est, axis=0), label='reconstruction', color='gray', lw=1.5)
     plt.xlabel('time[s]')
     plt.ylabel('stimulus')
     plt.xlim([0.0, 0.25])
     plt.ylim([-1., 1.])
     plt.legend()
-    plt.plot([0.0, 0.25], [0., 0.], color="darkgray", lw=1, ls='--', zorder=1)
+    plt.plot([0.0, 0.25], [0., 0.], color="darkgray", lw=1.5, ls='--', zorder=1)
+    plt.gca().spines["right"].set_visible(False)
+    plt.gca().spines["top"].set_visible(False)
+    plt.gca().yaxis.set_ticks_position('left')
+    plt.gca().xaxis.set_ticks_position('bottom')
 
     fig = plt.gcf()
     fig.set_size_inches(7.5, 5)