From 1cc4407edee414b169d06cf15ad0e8e29823c098 Mon Sep 17 00:00:00 2001
From: Jan Benda <jan.benda@uni-tuebingen.de>
Date: Mon, 6 Jan 2020 22:38:38 +0100
Subject: [PATCH] [pointprocesses] adapted plot scripts

---
 plotstyle.py                                 |  14 +-
 pointprocesses/lecture/firingrates.py        | 200 ++++++-------------
 pointprocesses/lecture/isihexamples.py       |  29 ++-
 pointprocesses/lecture/isimethod.py          | 112 ++++-------
 pointprocesses/lecture/returnmapexamples.py  |  38 +---
 pointprocesses/lecture/serialcorrexamples.py |  42 +---
 pointprocesses/lecture/sta.py                |  39 +---
 7 files changed, 148 insertions(+), 326 deletions(-)

diff --git a/plotstyle.py b/plotstyle.py
index b31894a..0bfe3ee 100644
--- a/plotstyle.py
+++ b/plotstyle.py
@@ -265,7 +265,7 @@ def __axes__init__(ax, *args, **kwargs):
     ax.show_spines('lb')
 
 
-def __axis_label(label, unit=None):
+def axis_label(label, unit=None):
     """ Format an axis label from a label and a unit
     
     Parameters
@@ -291,7 +291,7 @@ def __axis_label(label, unit=None):
 def set_xlabel(ax, label, unit=None, **kwargs):
     """ Format the xlabel from a label and an unit.
 
-    Uses the __axis_label() function to format the axis label.
+    Uses the axis_label() function to format the axis label.
     
     Parameters
     ----------
@@ -302,13 +302,13 @@ def set_xlabel(ax, label, unit=None, **kwargs):
     kwargs: key-word arguments
         Further arguments passed on to the set_xlabel() function.
     """
-    ax.set_xlabel_orig(__axis_label(label, unit), **kwargs)
+    ax.set_xlabel_orig(axis_label(label, unit), **kwargs)
 
         
 def set_ylabel(ax, label, unit=None, **kwargs):
     """ Format the ylabel from a label and an unit.
 
-    Uses the __axis_label() function to format the axis label.
+    Uses the axis_label() function to format the axis label.
     
     Parameters
     ----------
@@ -319,13 +319,13 @@ def set_ylabel(ax, label, unit=None, **kwargs):
     kwargs: key-word arguments
         Further arguments passed on to the set_ylabel() function.
     """
-    ax.set_ylabel_orig(__axis_label(label, unit), **kwargs)
+    ax.set_ylabel_orig(axis_label(label, unit), **kwargs)
 
         
 def set_zlabel(ax, label, unit=None, **kwargs):
     """ Format the zlabel from a label and an unit.
 
-    Uses the __axis_label() function to format the axis label.
+    Uses the axis_label() function to format the axis label.
     
     Parameters
     ----------
@@ -336,7 +336,7 @@ def set_zlabel(ax, label, unit=None, **kwargs):
     kwargs: key-word arguments
         Further arguments passed on to the set_zlabel() function.
     """
-    ax.set_zlabel_orig(__axis_label(label, unit), **kwargs)
+    ax.set_zlabel_orig(axis_label(label, unit), **kwargs)
 
 
 def common_format():
diff --git a/pointprocesses/lecture/firingrates.py b/pointprocesses/lecture/firingrates.py
index d6d1666..2ee4c17 100644
--- a/pointprocesses/lecture/firingrates.py
+++ b/pointprocesses/lecture/firingrates.py
@@ -1,34 +1,9 @@
 import numpy as np
-import matplotlib.pyplot as plt
 import scipy.io as spio
 import scipy.stats as spst
 import scipy as sp
-from IPython import embed
-
-
-def set_axis_fontsize(axis, label_size, tick_label_size=None, legend_size=None):
-    """
-        Sets axis, tick label and legend font sizes to the desired size.
-
-    :param axis: the axes object
-    :param label_size: the size of the axis label
-    :param tick_label_size: the size of the tick labels. If None, lable_size is used
-    :param legend_size: the size of the font used in the legend.If None, the label_size is used.
-
-    """
-    if not tick_label_size:
-        tick_label_size = label_size
-    if not legend_size:
-        legend_size = label_size
-    axis.xaxis.get_label().set_fontsize(label_size)
-    axis.yaxis.get_label().set_fontsize(label_size)
-    for tick in axis.xaxis.get_major_ticks() + axis.yaxis.get_major_ticks():
-        tick.label.set_fontsize(tick_label_size)
-
-    l = axis.get_legend()
-    if l:
-        for t in l.get_texts():
-            t.set_fontsize(legend_size)
+import matplotlib.pyplot as plt
+from plotstyle import *
 
 
 def get_instantaneous_rate(times, max_t=30., dt=1e-4):
@@ -53,33 +28,24 @@ def plot_isi_rate(spike_times, max_t=30, dt=1e-4):
     avg_rate = np.mean(rates, axis=1)
     rate_std = np.std(rates, axis=1)
     
-    fig = plt.figure()
-    ax1 = fig.add_subplot(311)
-    ax2 = fig.add_subplot(312)
-    ax3 = fig.add_subplot(313)
+    fig, (ax1, ax2, ax3) = plt.subplots(3, 1, figsize=cm_size(figure_width, 1.2*figure_height))
     
     ax1.vlines(times[times < (50000*dt)], ymin=0, ymax=1, color="dodgerblue", lw=1.5)
-    ax1.set_ylabel("skpikes", fontsize=12)
-    set_axis_fontsize(ax1, 12)
-    ax1.set_xlim([0, 5])
+    ax1.set_ylabel('Spikes')
+    ax1.set_xlim(0, 5)
     ax2.plot(time, rate, label="instantaneous rate, trial 1")
-    ax2.set_ylabel("firing rate [Hz]", fontsize=12)
-    ax2.legend(fontsize=12)
-    set_axis_fontsize(ax2, 12)
+    ax2.set_ylabel('Firing rate', 'Hz')
+    ax2.legend()
 
     ax3.fill_between(time, avg_rate+rate_std, avg_rate-rate_std, color="dodgerblue",
                      alpha=0.5, label="standard deviation")
     ax3.plot(time, avg_rate, label="average rate")
-    ax3.set_xlabel("times [s]", fontsize=12)
-    ax3.set_ylabel("firing rate [Hz]", fontsize=12)
-    ax3.legend(fontsize=12)
-    ax3.set_ylim([0, 450])
-    set_axis_fontsize(ax3, 12)
+    ax3.set_xlabel('Time', 's')
+    ax3.set_ylabel('Firing rate', 'Hz')
+    ax3.legend()
+    ax3.set_ylim(0, 450)
     
-    fig.set_size_inches(15, 10)
-    fig.subplots_adjust(left=0.1, bottom=0.125, top=0.95, right=0.95)
-    fig.set_facecolor("white")
-    fig.savefig("figures/instantaneous_rate.png")
+    fig.savefig("isimethod.pdf")
     plt.close()
 
 
@@ -104,36 +70,27 @@ def plot_bin_rate(spike_times, bin_width, max_t=30, dt=1e-4):
     avg_rate = np.mean(rates, axis=1)
     rate_std = np.std(rates, axis=1)
     
-    fig = plt.figure()
-    ax1 = fig.add_subplot(311)
-    ax2 = fig.add_subplot(312)
-    ax3 = fig.add_subplot(313)
+    fig, (ax1, ax2, ax3) = plt.subplots(3, 1, figsize=cm_size(figure_width, 1.2*figure_height))
 
     ax1.vlines(times[times < (100000*dt)], ymin=0, ymax=1, color="dodgerblue", lw=1.5)
-    ax1.set_ylabel("skpikes", fontsize=12)
-    ax1.set_xlim([0, 5])
-    set_axis_fontsize(ax1, 12)
+    ax1.set_ylabel('Skpikes')
+    ax1.set_xlim(0, 5)
 
     ax2.plot(time, rate, label="binned rate, trial 1")
-    ax2.set_ylabel("firing rate [Hz]", fontsize=12)
-    ax2.legend(fontsize=12)
-    ax2.set_xlim([0, 5])
-    set_axis_fontsize(ax2, 12)
+    ax2.set_ylabel('Firing rate', 'Hz')
+    ax2.legend()
+    ax2.set_xlim(0, 5)
 
     ax3.fill_between(time, avg_rate+rate_std, avg_rate-rate_std, color="dodgerblue",
                      alpha=0.5, label="standard deviation")
     ax3.plot(time, avg_rate, label="average rate")
-    ax3.set_xlabel("times [s]", fontsize=12)
-    ax3.set_ylabel("firing rate [Hz]", fontsize=12)
-    ax3.legend(fontsize=12)
-    ax3.set_xlim([0, 5])
-    ax3.set_ylim([0, 450])
-    set_axis_fontsize(ax3, 12)
+    ax3.set_xlabel('Times', 's')
+    ax3.set_ylabel('Firing rate', 'Hz')
+    ax3.legend()
+    ax3.set_xlim(0, 5)
+    ax3.set_ylim(0, 450)
 
-    fig.set_size_inches(15, 10)
-    fig.subplots_adjust(left=0.1, bottom=0.125, top=0.95, right=0.95)
-    fig.set_facecolor("white")
-    fig.savefig("figures/binned_rate.png")
+    fig.savefig("binmethod.pdf")
     plt.close()
   
 
@@ -157,36 +114,27 @@ def plot_conv_rate(spike_times, sigma=0.05, max_t=30, dt=1e-4):
     avg_rate = np.mean(rates, axis=1)
     rate_std = np.std(rates, axis=1)
     
-    fig = plt.figure()
-    ax1 = fig.add_subplot(311)
-    ax2 = fig.add_subplot(312)
-    ax3 = fig.add_subplot(313)
+    fig, (ax1, ax2, ax3) = plt.subplots(3, 1, figsize=cm_size(figure_width, 1.2*figure_height))
 
     ax1.vlines(times[times < (100000*dt)], ymin=0, ymax=1, color="dodgerblue", lw=1.5)
-    ax1.set_ylabel("skpikes", fontsize=12)
-    ax1.set_xlim([0, 5])
-    set_axis_fontsize(ax1, 12)
+    ax1.set_ylabel('Skpikes')
+    ax1.set_xlim(0, 5)
 
     ax2.plot(time, rate, label="convolved rate, trial 1")
-    ax2.set_ylabel("firing rate [Hz]", fontsize=12)
-    ax2.legend(fontsize=12)
-    ax2.set_xlim([0, 5])
-    set_axis_fontsize(ax2, 12)
+    ax2.set_ylabel('Firing rate', 'Hz')
+    ax2.legend()
+    ax2.set_xlim(0, 5)
 
     ax3.fill_between(time, avg_rate+rate_std, avg_rate-rate_std, color="dodgerblue",
                      alpha=0.5, label="standard deviation")
     ax3.plot(time, avg_rate, label="average rate")
-    ax3.set_xlabel("times [s]", fontsize=12)
-    ax3.set_ylabel("firing rate [Hz]", fontsize=12)
-    ax3.legend(fontsize=12)
-    ax3.set_xlim([0, 5])
-    ax3.set_ylim([0, 450])
-    set_axis_fontsize(ax3, 12)
+    ax3.set_xlabel('Times', 's')
+    ax3.set_ylabel('Firing rate', 'Hz')
+    ax3.legend()
+    ax3.set_xlim(0, 5)
+    ax3.set_ylim(0, 450)
 
-    fig.set_size_inches(15, 10)
-    fig.subplots_adjust(left=0.1, bottom=0.125, top=0.95, right=0.95)
-    fig.set_facecolor("white")
-    fig.savefig("figures/convolved_rate.png")
+    fig.savefig("convmethod.pdf")
     plt.close()
 
 
@@ -196,78 +144,52 @@ def plot_comparison(spike_times, bin_width, sigma, max_t=30., dt=1e-4):
     time, inst_rate = get_instantaneous_rate(times)
     time, binn_rate = get_binned_rate(times, bin_width)
    
-    fig = plt.figure()
-    ax1 = fig.add_subplot(411)
-    ax2 = fig.add_subplot(412)
-    ax3 = fig.add_subplot(413)
-    ax4 = fig.add_subplot(414)
-    ax1.spines["right"].set_visible(False)
-    ax1.spines["top"].set_visible(False)
-    ax1.yaxis.set_ticks_position('left')
-    ax1.xaxis.set_ticks_position('bottom')
-    ax2.spines["right"].set_visible(False)
-    ax2.spines["top"].set_visible(False)
-    ax2.yaxis.set_ticks_position('left')
-    ax2.xaxis.set_ticks_position('bottom')
-    ax3.spines["right"].set_visible(False)
-    ax3.spines["top"].set_visible(False)
-    ax3.yaxis.set_ticks_position('left')
-    ax3.xaxis.set_ticks_position('bottom')
-    ax4.spines["right"].set_visible(False)
-    ax4.spines["top"].set_visible(False)
-    ax4.yaxis.set_ticks_position('left')
-    ax4.xaxis.set_ticks_position('bottom')
+    fig, (ax1, ax2, ax3, ax4) = plt.subplots(4, 1, figsize=cm_size(figure_width, 2.0*figure_height))
+    fig.subplots_adjust(**adjust_fs(fig, left=6.0, right=1.5, bottom=3.0, top=1.0))
 
-    ax1.vlines(times[times < (100000*dt)], ymin=0, ymax=1, color="dodgerblue", lw=1.5)
-    ax1.set_ylabel("spikes", fontsize=10)
-    ax1.set_xlim([2.5, 3.5])
-    ax1.set_ylim([0, 1])
+    ax1.vlines(times[times < (100000*dt)], ymin=0, ymax=1, color=colors['blue'], lw=1.5)
+    ax1.set_ylabel('Spikes')
+    ax1.set_xlim(2.5, 3.5)
+    ax1.set_ylim(0, 1)
     ax1.set_yticks([0, 1])
-    set_axis_fontsize(ax1, 10)
     ax1.set_xticks([2.5, 2.75, 3.0, 3.25, 3.5])
     ax1.set_xticklabels([])
 
-    ax2.plot(time, inst_rate, lw=1.5, label="instantaneous rate")
-    ax2.set_ylabel("firing rate [Hz]", fontsize=10)
-    ax2.legend(fontsize=10)
-    ax2.set_xlim([2.5, 3.5])
-    ax2.set_ylim([0, 300])
+    ax2.plot(time, inst_rate, color=colors['orange'], lw=1.5, label="instantaneous rate")
+    ax2.legend()
+    ax2.set_ylabel('Rate', 'Hz')
+    ax2.set_xlim(2.5, 3.5)
+    ax2.set_ylim(0, 300)
     ax2.set_yticks(np.arange(0, 400, 100))
-    set_axis_fontsize(ax2, 10)
     ax2.set_xticks([2.5, 2.75, 3.0, 3.25, 3.5])
     ax2.set_xticklabels([])
 
-    ax3.plot(time, binn_rate, lw=1.5, label="binned rate")
-    ax3.set_ylabel("firing rate [Hz]", fontsize=10)
-    ax3.legend(fontsize=10)
-    ax3.set_xlim([2.5, 3.5])
-    ax3.set_ylim([0, 300])
+    ax3.plot(time, binn_rate, color=colors['orange'], lw=1.5, label="binned rate")
+    ax3.set_ylabel('Rate', 'Hz')
+    ax3.legend()
+    ax3.set_xlim(2.5, 3.5)
+    ax3.set_ylim(0, 300)
     ax3.set_yticks(np.arange(0, 400, 100))
-    set_axis_fontsize(ax3, 10)
     ax3.set_xticks([2.5, 2.75, 3.0, 3.25, 3.5])
     ax3.set_xticklabels([])
 
-    ax4.plot(time, conv_rate, lw=1.5, label="convolved rate")
-    ax4.set_xlabel("time [s]", fontsize=10)
-    ax4.set_ylabel("firing rate [Hz]", fontsize=10)
-    ax4.legend(fontsize=10)
-    ax4.set_xlim([2.5, 3.5])
+    ax4.plot(time, conv_rate, color=colors['orange'], lw=1.5, label="convolved rate")
+    ax4.set_xlabel('Time', 's')
+    ax4.set_ylabel('Rate', 'Hz')
+    ax4.legend()
+    ax4.set_xlim(2.5, 3.5)
     ax4.set_xticks([2.5, 2.75, 3.0, 3.25, 3.5])
-    ax4.set_ylim([0, 300])
+    ax4.set_ylim(0, 300)
     ax4.set_yticks(np.arange(0, 400, 100))
-    set_axis_fontsize(ax4, 10)
 
-    fig.set_size_inches(7.5, 5)
-    fig.subplots_adjust(left=0.1, bottom=0.125, top=0.95, right=0.95, )
-    fig.set_facecolor("white")
     fig.savefig("firingrates.pdf")
     plt.close()
 
 
 if __name__ == "__main__":
     spike_times =  spio.loadmat('lifoustim.mat')["spikes"]
-    # plot_isi_rate(spike_times)
-    # plot_bin_rate(spike_times, 0.05)
-    # plot_conv_rate(spike_times, 0.025)
+    #plot_isi_rate(spike_times)
+    #plot_bin_rate(spike_times, 0.05)
+    #plot_conv_rate(spike_times, 0.025)
     plot_comparison(spike_times, 0.05, 0.025)
 
diff --git a/pointprocesses/lecture/isihexamples.py b/pointprocesses/lecture/isihexamples.py
index 7d8d30e..b1c7daa 100644
--- a/pointprocesses/lecture/isihexamples.py
+++ b/pointprocesses/lecture/isihexamples.py
@@ -1,5 +1,6 @@
 import numpy as np
 import matplotlib.pyplot as plt
+from plotstyle import *
 
 def hompoisson(rate, trials, duration) :
     spikes = []
@@ -56,9 +57,9 @@ def plotisih( ax, isis, binwidth=None ) :
     ax.text(0.9, 0.85, 'rate={:.0f}Hz'.format(1.0/np.mean(isis)), ha='right', transform=ax.transAxes)
     ax.text(0.9, 0.75, 'mean={:.0f}ms'.format(1000.0*np.mean(isis)), ha='right', transform=ax.transAxes)
     ax.text(0.9, 0.65, 'CV={:.2f}'.format(np.std(isis)/np.mean(isis)), ha='right', transform=ax.transAxes)
-    ax.set_xlabel('ISI [ms]')
-    ax.set_ylabel('p(ISI) [1/s]')
-    ax.bar( 1000.0*b[:-1], h, 1000.0*np.diff(b) )
+    ax.set_xlabel('ISI', 'ms')
+    ax.set_ylabel('p(ISI)', '1/s')
+    ax.bar( 1000.0*b[:-1], h, bar_fac*1000.0*np.diff(b), facecolor=colors['blue'])
 
 # parameter:
 rate = 20.0
@@ -83,19 +84,17 @@ x[x<0.0] = 0.0
 # pif spike trains:
 inhspikes = pifspikes(x, trials, dt, D=0.3)
 
-fig = plt.figure( figsize=(9,4) )
-ax = fig.add_subplot(1, 2, 1)
-ax.set_title('stationary')
-ax.set_xlim(0.0, 200.0)
-ax.set_ylim(0.0, 40.0)
-plotisih(ax, isis(homspikes))
+fig, (ax1, ax2) = plt.subplots(1, 2)
+fig.subplots_adjust(**adjust_fs(fig, top=1.5))
+ax1.set_title('stationary')
+ax1.set_xlim(0.0, 200.0)
+ax1.set_ylim(0.0, 40.0)
+plotisih(ax1, isis(homspikes))
 
-ax = fig.add_subplot(1, 2, 2)
-ax.set_title('non-stationary')
-ax.set_xlim(0.0, 200.0)
-ax.set_ylim(0.0, 40.0)
-plotisih(ax, isis(inhspikes))
+ax2.set_title('non-stationary')
+ax2.set_xlim(0.0, 200.0)
+ax2.set_ylim(0.0, 40.0)
+plotisih(ax2, isis(inhspikes))
 
-plt.tight_layout()
 plt.savefig('isihexamples.pdf')
 plt.close()
diff --git a/pointprocesses/lecture/isimethod.py b/pointprocesses/lecture/isimethod.py
index b0baa64..1fe81bb 100644
--- a/pointprocesses/lecture/isimethod.py
+++ b/pointprocesses/lecture/isimethod.py
@@ -1,20 +1,9 @@
-import matplotlib.pyplot as plt
 import numpy as np
-from IPython import embed
+import matplotlib.pyplot as plt
+from plotstyle import *
 
-figsize=(6,3)
 
-def set_rc():
-    plt.rcParams['xtick.major.size'] = 5
-    plt.rcParams['xtick.minor.size'] = 5
-    plt.rcParams['xtick.major.width'] = 2
-    plt.rcParams['xtick.minor.width'] = 2
-    plt.rcParams['ytick.major.size'] = 5
-    plt.rcParams['ytick.minor.size'] = 5
-    plt.rcParams['ytick.major.width'] = 2
-    plt.rcParams['ytick.minor.width'] = 2
-    plt.rcParams['xtick.direction'] = "out"
-    plt.rcParams['ytick.direction'] = "out"
+fig_size = cm_size(figure_width, 1.2*figure_height)
 
 
 def create_spikes(nspikes=11, duration=0.5, seed=1000):
@@ -38,27 +27,18 @@ def gaussian(sigma, dt):
 
     
 def setup_axis(spikes_ax, rate_ax):
-    spikes_ax.spines["left"].set_visible(False)
-    spikes_ax.spines["right"].set_visible(False)
-    spikes_ax.spines["top"].set_visible(False)
-    spikes_ax.yaxis.set_ticks_position('left')
-    spikes_ax.xaxis.set_ticks_position('bottom')
+    spikes_ax.show_spines('b')
     spikes_ax.set_yticks([])
     spikes_ax.set_ylim(-0.2, 1.0)
-    #spikes_ax.set_ylabel("Spikes")
-    spikes_ax.text(-0.1, 0.5, "Spikes", transform=spikes_ax.transAxes, rotation='vertical', va='center')
-    #spikes_ax.text(-0.125, 1.2, "A", transform=spikes_ax.transAxes)
+    spikes_ax.text(-0.1, 0.5, 'Spikes', transform=spikes_ax.transAxes, rotation='vertical', va='center')
     spikes_ax.set_xlim(-1, 500)
     #spikes_ax.set_xticklabels(np.arange(0., 600, 100))
 
-    rate_ax.spines["right"].set_visible(False)
-    rate_ax.spines["top"].set_visible(False)
-    rate_ax.yaxis.set_ticks_position('left')
-    rate_ax.xaxis.set_ticks_position('bottom')
-    rate_ax.set_xlabel('Time [ms]')
-    #rate_ax.set_ylabel('Firing rate [Hz]')
-    rate_ax.text(-0.1, 0.5, "Rate [Hz]", transform=rate_ax.transAxes, rotation='vertical', va='center')
-    #rate_ax.text(-0.125, 1.15, "B", transform=rate_ax.transAxes)   
+    spikes_ax.show_spines('lb')
+    rate_ax.set_xlabel('Time', 'ms')
+    #rate_ax.set_ylabel('Firing rate', 'Hz')
+    rate_ax.text(-0.1, 0.5, axis_label('Rate', 'Hz'), transform=rate_ax.transAxes,
+                 rotation='vertical', va='center')
     rate_ax.set_xlim(0, 500)
     #rate_ax.set_xticklabels(np.arange(0., 600, 100))
     rate_ax.set_ylim(0, 60)
@@ -75,29 +55,23 @@ def plot_bin_method():
     bins = np.arange(0, 0.55, 0.05)
     count, _ = np.histogram(spike_times, bins)
     
-    plt.xkcd()
-    set_rc()
-    fig = plt.figure()
-    fig.set_size_inches(*figsize)
-    fig.set_facecolor('white')
-    spikes = plt.subplot2grid((7,1), (0,0), rowspan=3, colspan=1)
+    fig = plt.figure(figsize=fig_size)
+    spikes_ax = plt.subplot2grid((7,1), (0,0), rowspan=3, colspan=1)
     rate_ax = plt.subplot2grid((7,1), (3,0), rowspan=4, colspan=1)
-    setup_axis(spikes, rate_ax)
+    setup_axis(spikes_ax, rate_ax)
 
     for ti in spike_times:
         ti *= 1000.0
-        spikes.plot([ti, ti], [0., 1.], '-b', lw=2)
+        spikes_ax.plot([ti, ti], [0., 1.], color=colors['blue'], lw=2)
 
-    #spikes.vlines(1000.0*spike_times, 0., 1., color="darkblue", lw=1.25)
     for tb in 1000.0*bins :
-        spikes.plot([tb, tb], [-2.0, 0.75], '-', color="#777777", lw=1, clip_on=False)
-    #spikes.vlines(1000.0*np.hstack((0,bins)), -2.0, 1.25, color="#777777", lw=1, linestyles='-', clip_on=False)
+        spikes_ax.plot([tb, tb], [-2.0, 0.75], '-', color="#777777", lw=1, clip_on=False)
     for i,c in enumerate(count):
-        spikes.text(1000.0*(bins[i]+0.5*bins[1]), 1.1, str(c), color='#CC0000', ha='center')
+        spikes_ax.text(1000.0*(bins[i]+0.5*bins[1]), 1.1, str(c), color=colors['red'],
+                       ha='center')
 
     rate = count / 0.05
-    rate_ax.step(1000.0*bins, np.hstack((rate, rate[-1])), color='#FF9900', where='post')
-    fig.tight_layout()
+    rate_ax.step(1000.0*bins, np.hstack((rate, rate[-1])), color=colors['orange'], where='post')
     fig.savefig("binmethod.pdf")
     plt.close()
 
@@ -114,59 +88,45 @@ def plot_conv_method():
     rate = np.convolve(rate, kernel, mode='same')
     rate = np.roll(rate, -1)
 
-    plt.xkcd()
-    set_rc()
-    fig = plt.figure()
-    fig.set_size_inches(*figsize)
-    fig.set_facecolor('white')
-    spikes = plt.subplot2grid((7,1), (0,0), rowspan=3, colspan=1)
+    fig = plt.figure(figsize=fig_size)
+    spikes_ax = plt.subplot2grid((7,1), (0,0), rowspan=3, colspan=1)
     rate_ax = plt.subplot2grid((7,1), (3,0), rowspan=4, colspan=1)
-    setup_axis(spikes, rate_ax)
+    setup_axis(spikes_ax, rate_ax)
     
-    #spikes.vlines(1000.0*spike_times, 0., 1., color="darkblue", lw=1.5, zorder=2)
     for ti in spike_times:
         ti *= 1000.0
-        spikes.plot([ti, ti], [0., 1.], '-b', lw=2)
-        spikes.plot(1000*kernel_time + ti, kernel/np.max(kernel), color='#cc0000', lw=1, zorder=1)
-
-    rate_ax.plot(1000.0*t, rate, color='#FF9900', lw=2, zorder=2)
-    rate_ax.fill_between(1000.0*t, rate, np.zeros(len(rate)), color='#FFFF66')
-    #rate_ax.fill_between(t, rate, np.zeros(len(rate)), color="red", alpha=0.5)
-    #rate_ax.set_ylim([0, 50])
-    #rate_ax.set_yticks(np.arange(0,75,25))
-    fig.tight_layout()
+        spikes_ax.plot([ti, ti], [0., 1.], color=colors['blue'], lw=2)
+        spikes_ax.plot(1000*kernel_time + ti, kernel/np.max(kernel), color=colors['red'],
+                       lw=1, zorder=1)
+
+    rate_ax.plot(1000.0*t, rate, color=colors['orange'], lw=2, zorder=2)
+    rate_ax.fill_between(1000.0*t, rate, np.zeros(len(rate)), color=colors['yellow'])
+    
     fig.savefig("convmethod.pdf")
 
 
 def plot_isi_method():
     spike_times = create_spikes()
     
-    plt.xkcd()
-    set_rc()
-    fig = plt.figure()
-    fig.set_size_inches(*figsize)
-    fig.set_facecolor('white')
-    
-    spikes = plt.subplot2grid((7,1), (0,0), rowspan=3, colspan=1)
+    fig = plt.figure(figsize=fig_size)
+    spikes_ax = plt.subplot2grid((7,1), (0,0), rowspan=3, colspan=1)
     rate = plt.subplot2grid((7,1), (3,0), rowspan=4, colspan=1)
-    setup_axis(spikes, rate)
+    setup_axis(spikes_ax, rate)
     
     spike_times = np.hstack((0.005, spike_times))
-    #spikes.vlines(1000*spike_times, 0., 1., color="blue", lw=2)
     for i in range(1, len(spike_times)):
         t_start = 1000*spike_times[i-1]
         t = 1000*spike_times[i]
-        spikes.plot([t_start, t_start], [0., 1.], '-b', lw=2)
-        spikes.annotate(s='', xy=(t_start, 0.5), xytext=(t,0.5), arrowprops=dict(arrowstyle='<->'), color='red')
-        spikes.text(0.5*(t_start+t), 0.75, 
+        spikes_ax.plot([t_start, t_start], [0., 1.], color=colors['blue'], lw=2)
+        spikes_ax.annotate(s='', xy=(t_start, 0.5), xytext=(t,0.5), arrowprops=dict(arrowstyle='<->'), color=colors['red'])
+        spikes_ax.text(0.5*(t_start+t), 0.75, 
                     "{0:.0f}".format((t - t_start)),
-                    color='#CC0000', ha='center')
+                    color=colors['red'], ha='center')
 
     #spike_times = np.hstack((0, spike_times))
     i_rate = 1./np.diff(spike_times)
-    rate.step(1000*spike_times, np.hstack((i_rate, i_rate[-1])),color='#FF9900', lw=2, where="post")
+    rate.step(1000*spike_times, np.hstack((i_rate, i_rate[-1])),color=colors['orange'], lw=2, where="post")
     
-    fig.tight_layout()
     fig.savefig("isimethod.pdf")
     
     
diff --git a/pointprocesses/lecture/returnmapexamples.py b/pointprocesses/lecture/returnmapexamples.py
index 829cf6d..db9326d 100644
--- a/pointprocesses/lecture/returnmapexamples.py
+++ b/pointprocesses/lecture/returnmapexamples.py
@@ -1,5 +1,6 @@
 import numpy as np
 import matplotlib.pyplot as plt
+from plotstyle import *
 
 def hompoisson(rate, trials, duration) :
     spikes = []
@@ -45,28 +46,13 @@ def isis( spikes ) :
         isi.extend(np.diff(spikes[k]))
     return np.array( isi )
 
-def plotisih( ax, isis, binwidth=None ) :
-    if binwidth == None :
-        nperbin = 200.0    # average number of isis per bin
-        bins = len(isis)/nperbin  # number of bins
-        binwidth = np.max(isis)/bins
-        if binwidth < 5e-4 :     # half a millisecond
-            binwidth = 5e-4
-    h, b = np.histogram(isis, np.arange(0.0, np.max(isis)+binwidth, binwidth), density=True)
-    ax.text(0.9, 0.85, 'rate={:.0f}Hz'.format(1.0/np.mean(isis)), ha='right', transform=ax.transAxes)
-    ax.text(0.9, 0.75, 'mean={:.0f}ms'.format(1000.0*np.mean(isis)), ha='right', transform=ax.transAxes)
-    ax.text(0.9, 0.65, 'CV={:.2f}'.format(np.std(isis)/np.mean(isis)), ha='right', transform=ax.transAxes)
-    ax.set_xlabel('ISI [ms]')
-    ax.set_ylabel('p(ISI) [1/s]')
-    ax.bar( 1000.0*b[:-1], h, 1000.0*np.diff(b) )
-
 def plotreturnmap(ax, isis, lag=1, max=None) :
-    ax.set_xlabel(r'ISI$_i$ [ms]')
-    ax.set_ylabel(r'ISI$_{i+1}$ [ms]')
+    ax.set_xlabel(r'ISI$_i$', 'ms')
+    ax.set_ylabel(r'ISI$_{i+1}$', 'ms')
     if max != None :
         ax.set_xlim(0.0, 1000.0*max)
         ax.set_ylim(0.0, 1000.0*max)
-    ax.scatter( 1000.0*isis[:-lag], 1000.0*isis[lag:] )
+    ax.scatter(1000.0*isis[:-lag], 1000.0*isis[lag:], c=colors['blue'])
 
 # parameter:
 rate = 20.0
@@ -91,15 +77,13 @@ x[x<0.0] = 0.0
 # pif spike trains:
 inhspikes = pifspikes(x, trials, dt, D=0.3)
 
-fig = plt.figure( figsize=(9,4) )
-ax = fig.add_subplot(1, 2, 1)
-ax.set_title('stationary')
-plotreturnmap(ax, isis(homspikes), 1, 0.3)
+fig, (ax1, ax2) = plt.subplots(1, 2)
+fig.subplots_adjust(**adjust_fs(fig, left=6.5, top=1.5))
+ax1.set_title('stationary')
+plotreturnmap(ax1, isis(homspikes), 1, 0.3)
 
-ax = fig.add_subplot(1, 2, 2)
-ax.set_title('non-stationary')
-plotreturnmap(ax, isis(inhspikes), 1, 0.3)
+ax2.set_title('non-stationary')
+plotreturnmap(ax2, isis(inhspikes), 1, 0.3)
 
-plt.tight_layout()
 plt.savefig('returnmapexamples.pdf')
-#plt.show()
+plt.close()
diff --git a/pointprocesses/lecture/serialcorrexamples.py b/pointprocesses/lecture/serialcorrexamples.py
index a2dc853..d45a6be 100644
--- a/pointprocesses/lecture/serialcorrexamples.py
+++ b/pointprocesses/lecture/serialcorrexamples.py
@@ -1,5 +1,6 @@
 import numpy as np
 import matplotlib.pyplot as plt
+from plotstyle import *
 
 def hompoisson(rate, trials, duration) :
     spikes = []
@@ -45,29 +46,6 @@ def isis( spikes ) :
         isi.extend(np.diff(spikes[k]))
     return np.array( isi )
 
-def plotisih( ax, isis, binwidth=None ) :
-    if binwidth == None :
-        nperbin = 200.0    # average number of isis per bin
-        bins = len(isis)/nperbin  # number of bins
-        binwidth = np.max(isis)/bins
-        if binwidth < 5e-4 :     # half a millisecond
-            binwidth = 5e-4
-    h, b = np.histogram(isis, np.arange(0.0, np.max(isis)+binwidth, binwidth), density=True)
-    ax.text(0.9, 0.85, 'rate={:.0f}Hz'.format(1.0/np.mean(isis)), ha='right', transform=ax.transAxes)
-    ax.text(0.9, 0.75, 'mean={:.0f}ms'.format(1000.0*np.mean(isis)), ha='right', transform=ax.transAxes)
-    ax.text(0.9, 0.65, 'CV={:.2f}'.format(np.std(isis)/np.mean(isis)), ha='right', transform=ax.transAxes)
-    ax.set_xlabel('ISI [ms]')
-    ax.set_ylabel('p(ISI) [1/s]')
-    ax.bar( 1000.0*b[:-1], h, 1000.0*np.diff(b) )
-
-def plotreturnmap(ax, isis, lag=1, max=None) :
-    ax.set_xlabel(r'ISI$_i$ [ms]')
-    ax.set_ylabel(r'ISI$_{i+1}$ [ms]')
-    if max != None :
-        ax.set_xlim(0.0, 1000.0*max)
-        ax.set_ylim(0.0, 1000.0*max)
-    ax.scatter( 1000.0*isis[:-lag], 1000.0*isis[lag:] )
-
 def plotserialcorr(ax, isis, maxlag=10) :
     lags = np.arange(maxlag+1)
     corr = [1.0]
@@ -77,7 +55,7 @@ def plotserialcorr(ax, isis, maxlag=10) :
     ax.set_ylabel(r'ISI correlation $\rho_k$')
     ax.set_xlim(0.0, maxlag)
     ax.set_ylim(-1.0, 1.0)
-    ax.plot(lags, corr, '.-', markersize=20)
+    ax.plot(lags, corr, '.-', markersize=15, c=colors['blue'])
 
 # parameter:
 rate = 20.0
@@ -102,16 +80,14 @@ x[x<0.0] = 0.0
 # pif spike trains:
 inhspikes = pifspikes(x, trials, dt, D=0.3)
 
-fig = plt.figure( figsize=(9,3) )
+fig, (ax1, ax2) = plt.subplots(1, 2)
+fig.subplots_adjust(**adjust_fs(fig, left=7.0, right=1.0))
 
-ax = fig.add_subplot(1, 2, 1)
-plotserialcorr(ax, isis(homspikes))
-ax.set_ylim(-0.2, 1.0)
+plotserialcorr(ax1, isis(homspikes))
+ax1.set_ylim(-0.2, 1.0)
 
-ax = fig.add_subplot(1, 2, 2)
-plotserialcorr(ax, isis(inhspikes))
-ax.set_ylim(-0.2, 1.0)
+plotserialcorr(ax2, isis(inhspikes))
+ax2.set_ylim(-0.2, 1.0)
 
-plt.tight_layout()
 plt.savefig('serialcorrexamples.pdf')
-#plt.show()
+plt.close()
diff --git a/pointprocesses/lecture/sta.py b/pointprocesses/lecture/sta.py
index 127c5ae..5bc9059 100644
--- a/pointprocesses/lecture/sta.py
+++ b/pointprocesses/lecture/sta.py
@@ -1,7 +1,7 @@
 import numpy as np
-import matplotlib.pyplot as plt
 import scipy.io as scio
-from IPython import embed
+import matplotlib.pyplot as plt
+from plotstyle import *
 
 
 def plot_sta(times, stim, dt, t_min=-0.1, t_max=.1):
@@ -38,30 +38,19 @@ def reconstruct_stimulus(spike_times, sta, stimulus, t_max=30., dt=1e-4):
 
 
 def plot_results(sta_time, st_average, stim_time, s_est, stimulus, duration, dt):
-    plt.xkcd()
     sta_ax = plt.subplot2grid((1, 3), (0, 0), rowspan=1, colspan=1)
     stim_ax = plt.subplot2grid((1, 3), (0, 1), rowspan=1, colspan=2)
     
     fig = plt.gcf()
-    fig.set_size_inches(8, 3)
-    fig.subplots_adjust(left=0.08, bottom=0.15, top=0.9, right=0.975)
-    fig.set_facecolor("white")
+    fig.subplots_adjust(**adjust_fs(fig, left=6.5, right=2.0, bottom=3.0, top=2.0))
 
-    sta_ax.plot(sta_time * 1000, st_average, color="#FF9900", lw=2.)
-    sta_ax.set_xlabel("Time (ms)")
-    sta_ax.set_ylabel("Stimulus")
+    sta_ax.plot(sta_time * 1000, st_average, color=colors['orange'], lw=2.)
+    sta_ax.set_xlabel('Time', 'ms')
+    sta_ax.set_ylabel('Stimulus')
     sta_ax.set_xlim(-40, 20)
     sta_ax.set_xticks(np.arange(-40, 21, 20))
     sta_ax.set_ylim(-0.1, 0.2)
     sta_ax.set_yticks(np.arange(-0.1, 0.21, 0.1))
-    #  sta_ax.xaxis.grid('off')
-    sta_ax.spines["right"].set_visible(False)
-    sta_ax.spines["top"].set_visible(False)
-    sta_ax.yaxis.set_ticks_position('left')
-    sta_ax.xaxis.set_ticks_position('bottom')
-    sta_ax.spines["bottom"].set_linewidth(2.0)
-    sta_ax.spines["left"].set_linewidth(2.0)
-    sta_ax.tick_params(direction="out", width=2.0)
     ylim = sta_ax.get_ylim()
     xlim = sta_ax.get_xlim()
     sta_ax.plot(list(xlim), [0., 0.], zorder=1, color='darkgray', ls='--', lw=1)
@@ -80,23 +69,15 @@ def plot_results(sta_time, st_average, stim_time, s_est, stimulus, duration, dt)
                     connectionstyle="angle3,angleA=60,angleB=-40") )
     #sta_ax.text(-0.25, 1.04, "A", transform=sta_ax.transAxes, size=24)
 
-    stim_ax.plot(stim_time * 1000, stimulus[:,1], label='stimulus', color='#0000FF', lw=2.)
-    stim_ax.plot(stim_time * 1000, s_est, label='reconstruction', color='#FF9900', lw=2)
-    stim_ax.set_xlabel('Time (ms)')
+    stim_ax.plot(stim_time * 1000, stimulus[:,1], label='stimulus', color=colors['blue'], lw=2.)
+    stim_ax.plot(stim_time * 1000, s_est, label='reconstruction', color=colors['orange'], lw=2)
+    stim_ax.set_xlabel('Time', 'ms')
     stim_ax.set_xlim(0.0, 200)
     stim_ax.set_ylim([-1., 1.])
-    stim_ax.legend(loc=(0.3, 0.85), frameon=False, fontsize=12)
+    stim_ax.legend(loc=(0.3, 0.85), frameon=False)
     stim_ax.plot([0.0, 250], [0., 0.], color="darkgray", lw=1, ls='--', zorder=1)
-    stim_ax.spines["right"].set_visible(False)
-    stim_ax.spines["top"].set_visible(False)
-    stim_ax.yaxis.set_ticks_position('left')
-    stim_ax.xaxis.set_ticks_position('bottom')
-    stim_ax.spines["bottom"].set_linewidth(2.0)
-    stim_ax.spines["left"].set_linewidth(2.0)
-    stim_ax.tick_params(direction="out", width=2.0)
     #stim_ax.text(-0.1, 1.04, "B", transform=stim_ax.transAxes, size=24)
 
-    fig.tight_layout()    
     fig.savefig("sta.pdf")
     plt.close()