This repository has been archived on 2021-05-17. You can view files and clone it, but cannot push or open issues or pull requests.
scientificComputing/pointprocesses/lecture/firingrates.py

274 lines
9.4 KiB
Python

import numpy as np
import matplotlib.pyplot as plt
import scipy.io as spio
import scipy.stats as spst
import scipy as sp
from IPython import embed
def set_axis_fontsize(axis, label_size, tick_label_size=None, legend_size=None):
    """
    Sets axis label, tick label and legend font sizes to the desired size.

    :param axis: the axes object
    :param label_size: the size of the x- and y-axis labels
    :param tick_label_size: the size of the tick labels. If None, label_size is used.
    :param legend_size: the size of the font used in the legend. If None, label_size is used.
    """
    if tick_label_size is None:
        tick_label_size = label_size
    if legend_size is None:
        legend_size = label_size
    axis.xaxis.get_label().set_fontsize(label_size)
    axis.yaxis.get_label().set_fontsize(label_size)
    # Tick.label was only an alias of Tick.label1 and is gone in newer
    # matplotlib releases; label1 is available in old and new versions.
    for tick in axis.xaxis.get_major_ticks() + axis.yaxis.get_major_ticks():
        tick.label1.set_fontsize(tick_label_size)
    legend = axis.get_legend()
    # get_legend() returns None when the axes has no legend
    if legend is not None:
        for text in legend.get_texts():
            text.set_fontsize(legend_size)
def get_instantaneous_rate(times, max_t=30., dt=1e-4):
    """
    Estimates the instantaneous firing rate as the inverse of the interspike intervals.

    Every time step between two successive spikes is assigned the inverse of
    the interval between these two spikes. Time steps before the first and
    after the last spike are left at zero.

    :param times: the spike times, expected in ascending order and in the same unit as dt
    :param max_t: the duration covered by the returned time axis
    :param dt: the temporal resolution of the returned time axis
    :return: the time axis and the instantaneous rate, two arrays of equal length
    """
    times = np.asarray(times, dtype=float)
    time = np.arange(0., max_t, dt)
    indices = np.asarray(times / dt, dtype=int)
    intervals = np.diff(times)
    inst_rate = np.zeros(time.shape)
    # Pair each interval with the indices of the two spikes enclosing it, so
    # that the last interval is filled as well.
    for start, stop, interval in zip(indices[:-1], indices[1:], intervals):
        # spikes falling into the same time step span no samples; skipping
        # them also avoids dividing by a zero interval
        if stop > start:
            inst_rate[start:stop] = 1. / interval
    return time, inst_rate
def plot_isi_rate(spike_times, max_t=30, dt=1e-4):
    """
    Plots the spikes and the instantaneous rate of the first trial as well as the
    rate averaged across all trials and saves the figure to
    "figures/instantaneous_rate.png".

    Only the first 50000 spike times of each trial are used and the rate is
    estimated on a time axis of 50000 steps of size dt.

    :param spike_times: the spike times of all trials; trial i is read from spike_times[i][0]
    :param max_t: not used, the analysed duration is fixed to 50000 * dt
    :param dt: the temporal resolution of the rate estimate
    """
    duration = 50000 * dt
    trials = [np.squeeze(trial[0])[:50000] for trial in spike_times]
    times = trials[0]
    # dt has to be handed on, otherwise the estimate is computed with the
    # default resolution and does not cover the requested duration
    estimates = [get_instantaneous_rate(trial, max_t=duration, dt=dt) for trial in trials]
    time, rate = estimates[0]
    rates = np.column_stack([trial_rate for _, trial_rate in estimates])
    avg_rate = np.mean(rates, axis=1)
    rate_std = np.std(rates, axis=1)

    fig = plt.figure()
    ax1 = fig.add_subplot(311)
    ax2 = fig.add_subplot(312)
    ax3 = fig.add_subplot(313)

    ax1.vlines(times[times < duration], ymin=0, ymax=1, color="dodgerblue", lw=1.5)
    ax1.set_ylabel("spikes", fontsize=12)
    set_axis_fontsize(ax1, 12)
    ax1.set_xlim([0, 5])

    ax2.plot(time, rate, label="instantaneous rate, trial 1")
    ax2.set_ylabel("firing rate [Hz]", fontsize=12)
    ax2.legend(fontsize=12)
    set_axis_fontsize(ax2, 12)

    ax3.fill_between(time, avg_rate+rate_std, avg_rate-rate_std, color="dodgerblue",
                     alpha=0.5, label="standard deviation")
    ax3.plot(time, avg_rate, label="average rate")
    ax3.set_xlabel("times [s]", fontsize=12)
    ax3.set_ylabel("firing rate [Hz]", fontsize=12)
    ax3.legend(fontsize=12)
    ax3.set_ylim([0, 450])
    set_axis_fontsize(ax3, 12)

    fig.set_size_inches(15, 10)
    fig.subplots_adjust(left=0.1, bottom=0.125, top=0.95, right=0.95)
    fig.set_facecolor("white")
    fig.savefig("figures/instantaneous_rate.png")
    plt.close(fig)
def get_binned_rate(times, bin_width=0.05, max_t=30., dt=1e-4):
    """
    Estimates the firing rate by counting the spikes in consecutive time bins.

    The spike count of each bin divided by the bin width is assigned to all
    time steps that belong to this bin. Time steps after the last complete bin
    (when max_t is not a multiple of bin_width) are left at zero.

    :param times: the spike times, in the same unit as bin_width, max_t and dt
    :param bin_width: the width of the bins
    :param max_t: the duration covered by the returned time axis
    :param dt: the temporal resolution of the returned time axis
    :return: the time axis and the binned rate, two arrays of equal length
    """
    time = np.arange(0., max_t, dt)
    # Half a bin of slack makes sure the edge at max_t is included, so the
    # last bin is counted too, without being sensitive to rounding errors.
    bins = np.arange(0., max_t + 0.5 * bin_width, bin_width)
    # slice bounds have to be integers; round to avoid truncating 499.99...
    bin_indices = np.round(bins / dt).astype(int)
    hist, _ = np.histogram(times, bins)
    rate = np.zeros(time.shape)
    # bin k spans the samples between edge k and edge k + 1
    for start, stop, count in zip(bin_indices[:-1], bin_indices[1:], hist):
        rate[start:stop] = count / bin_width
    return time, rate
def plot_bin_rate(spike_times, bin_width, max_t=30, dt=1e-4):
    """
    Plots the spikes and the binned rate of the first trial as well as the rate
    averaged across all trials and saves the figure to "figures/binned_rate.png".

    :param spike_times: the spike times of all trials; trial i is read from spike_times[i][0]
    :param bin_width: the width of the bins used for the rate estimate
    :param max_t: the duration covered by the rate estimate
    :param dt: the temporal resolution of the rate estimate
    """
    trials = [np.squeeze(trial[0]) for trial in spike_times]
    times = trials[0]
    # hand the arguments on, otherwise the estimate silently falls back to the
    # defaults of get_binned_rate
    estimates = [get_binned_rate(trial, bin_width=bin_width, max_t=max_t, dt=dt)
                 for trial in trials]
    time, rate = estimates[0]
    rates = np.column_stack([trial_rate for _, trial_rate in estimates])
    avg_rate = np.mean(rates, axis=1)
    rate_std = np.std(rates, axis=1)

    fig = plt.figure()
    ax1 = fig.add_subplot(311)
    ax2 = fig.add_subplot(312)
    ax3 = fig.add_subplot(313)

    ax1.vlines(times[times < (100000*dt)], ymin=0, ymax=1, color="dodgerblue", lw=1.5)
    ax1.set_ylabel("spikes", fontsize=12)
    ax1.set_xlim([0, 5])
    set_axis_fontsize(ax1, 12)

    ax2.plot(time, rate, label="binned rate, trial 1")
    ax2.set_ylabel("firing rate [Hz]", fontsize=12)
    ax2.legend(fontsize=12)
    ax2.set_xlim([0, 5])
    set_axis_fontsize(ax2, 12)

    ax3.fill_between(time, avg_rate+rate_std, avg_rate-rate_std, color="dodgerblue",
                     alpha=0.5, label="standard deviation")
    ax3.plot(time, avg_rate, label="average rate")
    ax3.set_xlabel("times [s]", fontsize=12)
    ax3.set_ylabel("firing rate [Hz]", fontsize=12)
    ax3.legend(fontsize=12)
    ax3.set_xlim([0, 5])
    ax3.set_ylim([0, 450])
    set_axis_fontsize(ax3, 12)

    fig.set_size_inches(15, 10)
    fig.subplots_adjust(left=0.1, bottom=0.125, top=0.95, right=0.95)
    fig.set_facecolor("white")
    fig.savefig("figures/binned_rate.png")
    plt.close(fig)
def get_convolved_rate(times, sigma, max_t=30., dt=1.e-4):
    """
    Estimates the firing rate by convolving the spike train with a Gaussian kernel.

    The kernel is a normal probability density with standard deviation sigma
    that is evaluated within +- 8 sigma around zero. Spikes that fall outside
    of the time axis are ignored.

    :param times: the spike times, in the same unit as sigma, max_t and dt
    :param sigma: the standard deviation of the Gaussian kernel
    :param max_t: the duration covered by the returned time axis
    :param dt: the temporal resolution of the returned time axis
    :return: the time axis and the convolved rate, two arrays of equal length
    """
    time = np.arange(0., max_t, dt)
    # An odd number of kernel samples centred on zero keeps the kernel peak
    # exactly on the spike; an even-sized kernel shifts the rate by one step.
    half_width = int(np.round(8 * sigma / dt))
    kernel = spst.norm.pdf(np.arange(-half_width, half_width + 1) * dt, loc=0, scale=sigma)
    indices = np.asarray(np.asarray(times, dtype=float).ravel() / dt, dtype=int)
    # drop spikes outside of the time axis instead of failing or wrapping around
    indices = indices[(indices >= 0) & (indices < len(time))]
    # count the spikes per time step so that coincident spikes all contribute
    spike_train = np.bincount(indices, minlength=len(time)).astype(float)
    # Cropping the full convolution by the kernel half width aligns it with the
    # time axis and yields len(time) samples even for kernels longer than that.
    conv_rate = np.convolve(spike_train, kernel, mode="full")[half_width:half_width + len(time)]
    return time, conv_rate
def plot_conv_rate(spike_times, sigma=0.05, max_t=30, dt=1e-4):
    """
    Plots the spikes and the convolved rate of the first trial as well as the rate
    averaged across all trials and saves the figure to "figures/convolved_rate.png".

    :param spike_times: the spike times of all trials; trial i is read from spike_times[i][0]
    :param sigma: the standard deviation of the Gaussian kernel
    :param max_t: the duration covered by the rate estimate
    :param dt: the temporal resolution of the rate estimate
    """
    trials = [np.squeeze(trial[0]) for trial in spike_times]
    times = trials[0]
    # hand max_t and dt on, otherwise the estimate silently falls back to the
    # defaults of get_convolved_rate
    estimates = [get_convolved_rate(trial, sigma, max_t=max_t, dt=dt) for trial in trials]
    time, rate = estimates[0]
    rates = np.column_stack([trial_rate for _, trial_rate in estimates])
    avg_rate = np.mean(rates, axis=1)
    rate_std = np.std(rates, axis=1)

    fig = plt.figure()
    ax1 = fig.add_subplot(311)
    ax2 = fig.add_subplot(312)
    ax3 = fig.add_subplot(313)

    ax1.vlines(times[times < (100000*dt)], ymin=0, ymax=1, color="dodgerblue", lw=1.5)
    ax1.set_ylabel("spikes", fontsize=12)
    ax1.set_xlim([0, 5])
    set_axis_fontsize(ax1, 12)

    ax2.plot(time, rate, label="convolved rate, trial 1")
    ax2.set_ylabel("firing rate [Hz]", fontsize=12)
    ax2.legend(fontsize=12)
    ax2.set_xlim([0, 5])
    set_axis_fontsize(ax2, 12)

    ax3.fill_between(time, avg_rate+rate_std, avg_rate-rate_std, color="dodgerblue",
                     alpha=0.5, label="standard deviation")
    ax3.plot(time, avg_rate, label="average rate")
    ax3.set_xlabel("times [s]", fontsize=12)
    ax3.set_ylabel("firing rate [Hz]", fontsize=12)
    ax3.legend(fontsize=12)
    ax3.set_xlim([0, 5])
    ax3.set_ylim([0, 450])
    set_axis_fontsize(ax3, 12)

    fig.set_size_inches(15, 10)
    fig.subplots_adjust(left=0.1, bottom=0.125, top=0.95, right=0.95)
    fig.set_facecolor("white")
    fig.savefig("figures/convolved_rate.png")
    plt.close(fig)
def plot_comparison(spike_times, bin_width, sigma, max_t=30., dt=1e-4):
    """
    Plots the spikes of the first trial together with its instantaneous, binned
    and convolved firing rate in four stacked panels showing the time range
    from 2.5 to 3.5 and saves the figure to "firingrates.pdf".

    :param spike_times: the spike times of all trials; only spike_times[0][0] is used
    :param bin_width: the width of the bins used for the binned rate. It also
        sets the width of the Gaussian kernel, see below.
    :param sigma: accepted but not used. The kernel width is derived from
        bin_width instead.
    :param max_t: the duration covered by the rate estimates
    :param dt: the temporal resolution of the rate estimates
    """
    times = np.squeeze(spike_times[0][0])
    # bin_width / sqrt(12) is the standard deviation of a box of width
    # bin_width, i.e. the Gaussian kernel is matched to the bins.
    kernel_sigma = bin_width/np.sqrt(12.)
    # max_t and dt are handed on so that all three estimates share the
    # requested time axis instead of the defaults of the estimators
    time, conv_rate = get_convolved_rate(times, kernel_sigma, max_t=max_t, dt=dt)
    _, inst_rate = get_instantaneous_rate(times, max_t=max_t, dt=dt)
    _, binn_rate = get_binned_rate(times, bin_width, max_t=max_t, dt=dt)

    x_limits = [2.5, 3.5]
    x_ticks = [2.5, 2.75, 3.0, 3.25, 3.5]

    fig = plt.figure()
    axes = [fig.add_subplot(4, 1, row) for row in range(1, 5)]
    spike_ax, rate_axes = axes[0], axes[1:]

    # show only the left and bottom spine of every panel
    for ax in axes:
        ax.spines["right"].set_visible(False)
        ax.spines["top"].set_visible(False)
        ax.yaxis.set_ticks_position('left')
        ax.xaxis.set_ticks_position('bottom')

    spike_ax.vlines(times[times < (100000*dt)], ymin=0, ymax=1, color="dodgerblue", lw=1.5)
    spike_ax.set_ylabel("spikes", fontsize=10)
    spike_ax.set_xlim(x_limits)
    spike_ax.set_ylim([0, 1])
    spike_ax.set_yticks([0, 1])
    spike_ax.set_xticks(x_ticks)
    set_axis_fontsize(spike_ax, 10)

    rate_panels = [(inst_rate, "instantaneous rate"),
                   (binn_rate, "binned rate"),
                   (conv_rate, "convolved rate")]
    for ax, (rate, label) in zip(rate_axes, rate_panels):
        ax.plot(time, rate, lw=1.5, label=label)
        ax.set_ylabel("firing rate [Hz]", fontsize=10)
        ax.legend(fontsize=10)
        ax.set_xlim(x_limits)
        ax.set_xticks(x_ticks)
        ax.set_ylim([0, 300])
        ax.set_yticks(np.arange(0, 400, 100))
        # the ticks are set first so that the font size reaches their labels
        set_axis_fontsize(ax, 10)

    # only the bottom panel carries tick labels and the label of the x-axis
    for ax in axes[:-1]:
        ax.set_xticklabels([])
    axes[-1].set_xlabel("time [s]", fontsize=10)

    fig.set_size_inches(7.5, 5)
    fig.subplots_adjust(left=0.1, bottom=0.125, top=0.95, right=0.95)
    fig.set_facecolor("white")
    fig.savefig("firingrates.pdf")
    plt.close(fig)
if __name__ == "__main__":
    # Read the spike times stored under the key "spikes" and create the
    # comparison figure. The figures of the individual estimates are
    # available as well:
    #   plot_isi_rate(spike_times)
    #   plot_bin_rate(spike_times, 0.05)
    #   plot_conv_rate(spike_times, 0.025)
    mat_content = spio.loadmat('lifoustim.mat')
    plot_comparison(mat_content["spikes"], 0.05, 0.025)