added python files for figures

This commit is contained in:
Jan Benda 2025-05-16 08:54:32 +02:00
parent ba45a72488
commit 923982d43f
8 changed files with 1756 additions and 0 deletions

225
ampullaryexamplecell.py Normal file
View File

@ -0,0 +1,225 @@
import sys
sys.path.insert(0, 'ephys') # for analysing data
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
from spectral import diag_projection, peakedness
from plotstyle import plot_style
from punitexamplecell import load_baseline, load_noise, load_spectra
from punitexamplecell import plot_colorbar
cell_name = '2012-05-15-ac'
run1 = 3 # 4
run2 = 1
base_path = Path('ephys')
data_path = base_path / 'data'
results_path = base_path / 'results'
def plot_isih(ax, s, rate, cv, isis, pdf):
ax.show_spines('b')
ax.fill_between(1000*isis, pdf, facecolor=s.cell_color1)
ax.set_xlim(0, 12)
ax.set_xticks_delta(4)
ax.set_xlabel('ISI', 'ms')
ax.text(0, 1.08, 'Ampullary:', transform=ax.transAxes, color=s.cell_color1,
fontsize='large')
ax.text(0.95, 1.08, f'$r={rate:.0f}$Hz, CV$_{{\\rm base}}$={cv:.2f}',
transform=ax.transAxes)
def plot_response_spectrum(ax, s, eodf, rate, freqs, prr):
rate_i = np.argmax(prr[freqs < 0.7*eodf])
eod_i = np.argmax(prr[freqs > 500]) + np.argmax(freqs > 500)
power_db = 10*np.log10(prr/np.max(prr))
ax.show_spines('b')
mask = (freqs > 30) & (freqs < 890)
ax.plot(freqs[mask], power_db[mask], **s.lsC1)
ax.plot(freqs[eod_i], power_db[eod_i], **s.psA1c)
ax.plot(freqs[rate_i], power_db[rate_i], **s.psA2c)
ax.set_xlim(0, 900)
ax.set_ylim(-25, 5)
ax.set_xticks_delta(300)
ax.set_xlabel('$f$', 'Hz')
ax.text(freqs[eod_i], power_db[eod_i] + 2, '$f_{\\rm EOD}$',
ha='center')
ax.text(freqs[rate_i], power_db[rate_i] + 2, '$r$',
ha='center')
ax.yscalebar(1.05, 0, 10, 'dB', ha='right')
def plot_response(ax, s, eodf, time1, stimulus1, contrast1, spikes1, contrast2, spikes2):
t0 = 0.3
t1 = 0.4
maxtrials = 8
trials = np.arange(maxtrials)
ax.show_spines('')
ax.eventplot(spikes1[2:2+maxtrials], lineoffsets=trials - maxtrials + 1,
linelength=0.8, linewidths=1, color=s.cell_color1)
ax.eventplot(spikes2[2:2+maxtrials], lineoffsets=trials - 2*maxtrials,
linelength=0.8, linewidths=1, color=s.cell_color2)
am = contrast1*stimulus1
eod = np.sin(2*np.pi*eodf*time1) + am
ax.plot(time1, 4*eod + 7, **s.lsEOD)
ax.plot(time1, 4*(1 + am) + 7, **s.lsAM)
ax.set_xlim(t0, t1)
ax.set_ylim(-2*maxtrials - 0.5, 14)
ax.xscalebar(1, -0.05, 0.01, None, '10\\,ms', ha='right')
ax.text(t1 + 0.003, -0.5*maxtrials, f'${100*contrast1:.0f}$\\,\\%',
va='center', color=s.cell_color1)
ax.text(t1 + 0.003, -1.55*maxtrials, f'${100*contrast2:.0f}$\\,\\%',
va='center', color=s.cell_color2)
def plot_gain(ax, s, fbase, contrast1, freqs1, gain1,
contrast2, freqs2, gain2, fcutoff):
ax.axvline(fbase, **s.lsGrid)
ax.plot(freqs2, gain2, label=f'{100*contrast2:.0f}', **s.lsC2)
ax.plot(freqs1, gain1, label=f'{100*contrast1:.0f}', **s.lsC1)
ax.set_xlim(0, fcutoff)
ax.set_ylim(0, 1500)
ax.set_xticks_delta(50)
ax.set_xlabel('$f$', 'Hz')
ax.set_ylabel(r'$|\chi_1|$', 'Hz')
ax.text(fbase, 1550, '$r$', ha='center')
def plot_chi2(ax, s, contrast, freqs, chi2, fcutoff, vmax):
ax.set_aspect('equal')
if vmax is None:
vmax = np.quantile(1e-3*chi2, 0.99)
pc = ax.pcolormesh(freqs, freqs, 1e-3*chi2, vmin=0, vmax=vmax,
cmap='viridis', rasterized=True, zorder=10)
ax.set_xlim(0, fcutoff)
ax.set_ylim(0, fcutoff)
df = 100 if fcutoff == 300 else 50
ax.set_xticks_delta(df)
ax.set_yticks_delta(df)
ax.set_xlabel('$f_1$', 'Hz')
ax.set_ylabel('$f_2$', 'Hz')
return pc
def plot_diagonals(ax, s, fbase, contrast1, freqs1, chi21, contrast2, freqs2, chi22, fcutoff):
diags = []
nlis = []
nlips = []
nlifs = []
for contrast, freqs, chi2 in [[contrast1, freqs1, chi21], [contrast2, freqs2, chi22]]:
dfreqs, diag = diag_projection(freqs, chi2, 2*fcutoff)
diags.append([dfreqs, diag])
nli, nlif = peakedness(dfreqs, diag, fbase, median=False)
nlip = diag[np.argmin(np.abs(dfreqs - nlif))]
nlis.append(nli)
nlips.append(nlip)
nlifs.append(nlif)
print(f' SI at {100*contrast:.1f}% contrast: {nli:.2f}')
ax.axvline(fbase, **s.lsGrid)
ax.plot(diags[1][0], 1e-3*diags[1][1], **s.lsC2)
ax.plot(diags[0][0], 1e-3*diags[0][1], **s.lsC1)
ax.plot(nlifs[1], 1e-3*nlips[1], **s.psC2)
ax.plot(nlifs[0], 1e-3*nlips[0], **s.psC1)
ax.set_xlim(0, 2*fcutoff)
ax.set_ylim(0, 1.7)
ax.set_xticks_delta(100)
ax.set_yticks_delta(1)
ax.set_xlabel('$f_1 + f_2$', 'Hz')
#ax.set_ylabel(r'$|\chi_2|$', 'kHz')
ax.text(nlifs[1] - 25, 1e-3*nlips[1], f'{100*contrast2:.0f}\\%',
ha='right')
ax.text(nlifs[1] + 35, 1e-3*nlips[1], f'SI={nlis[1]:.1f}')
ax.text(nlifs[0] - 25, 1e-3*nlips[0], f'{100*contrast1:.0f}\\%',
ha='right')
ax.text(nlifs[0] + 35, 1e-3*nlips[0], f'SI={nlis[0]:.1f}')
ax.text(fbase, 1.75, '$r$', ha='center')
if __name__ == '__main__':
"""
from thunderlab.tabledata import TableData
data = TableData('Apteronotus_leptorhynchus-Ampullary-data.csv')
data = data[(data('fcutoff') > 140) & (data('fcutoff') < 160), :]
data = data[(data('nli') > 2) & (data('nli') < 2.5), :]
data = data[(data('respmod2') > 20) & (data('respmod2') < 100), :]
data = data[(data('cvbase') > 0.05) & (data('cvbase') < 0.2), :]
data = data[(data('ratebase') > 100) & (data('ratebase') < 180), :]
for k in range(data.rows()):
print(f'{data[k, "cell"]:<22s} s{data[k, "stimindex"]:02.0f}: {100*data[k, "contrast"]:3g}%, {data[k, "respmod2"]:3.0f}Hz, nli={data[k, "nli"]:5.2f}')
print()
#exit()
"""
print('Example Ampullary cell:', cell_name)
eodf, rate, cv, isis, pdf, freqs, prr = load_baseline(results_path, cell_name)
print(f' baseline firing rate: {rate:.0f}Hz')
print(f' baseline firing CV : {cv:.2f}')
contrast1, time1, stimulus1, spikes1 = load_noise(data_path, cell_name, run1)
contrast2, time2, stimulus2, spikes2 = load_noise(data_path, cell_name, run2)
fcutoff1, contrast1, freqs1, gain1, chi21 = load_spectra(results_path, cell_name, run1)
fcutoff2, contrast2, freqs2, gain2, chi22 = load_spectra(results_path, cell_name, run2)
s = plot_style()
s.cell_color1 = s.ampul_color1
s.cell_color2 = s.ampul_color2
s.lsC1 = s.lsA1
s.lsC2 = s.lsA2
s.psC1 = s.psA1
s.psC2 = s.psA2
fig, (ax1, ax2, ax3) = plt.subplots(3, 1, height_ratios=[3, 0, 3, 0.5, 3],
cmsize=(s.plot_width, 0.8*s.plot_width))
fig.subplots_adjust(leftm=8, rightm=9, topm=2, bottomm=4,
wspace=0.4, hspace=0.5)
axi, axp, axr = ax1.subplots(1, 3, width_ratios=[2, 3, 0, 10])
axg, axc1, axc2, axd = ax2.subplots(1, 4, wspace=0.4)
axg = axg.subplots(1, 1, width_ratios=[1, 0.1])
axd = axd.subplots(1, 1, width_ratios=[0.2, 1])
axs = ax3.subplots(1, 4, wspace=0.4)
plot_isih(axi, s, rate, cv, isis, pdf)
plot_response_spectrum(axp, s, eodf, rate, freqs, prr)
plot_response(axr, s, eodf, time1, stimulus1, contrast1, spikes1,
contrast2, spikes2)
plot_gain(axg, s, rate, contrast1, freqs1, gain1,
contrast2, freqs2, gain2, fcutoff1)
pc = plot_chi2(axc1, s, contrast2, freqs2, chi22, fcutoff2, 1.7)
axc1.plot([0, fcutoff2], [0, fcutoff2], zorder=20, **s.lsDiag)
axc1.set_title(f'$c$={100*contrast2:g}\\,\\%',
fontsize='medium', color=s.cell_color2)
pc = plot_chi2(axc2, s, contrast1, freqs1, chi21, fcutoff1, 1.7)
axc2.set_title(f'$c$={100*contrast1:g}\\,\\%',
fontsize='medium', color=s.cell_color1)
axc2.plot([0, fcutoff1], [0, fcutoff1], zorder=20, **s.lsDiag)
plot_colorbar(axc2, pc, 1)
plot_diagonals(axd, s, rate, contrast1, freqs1, chi21,
contrast2, freqs2, chi22, fcutoff1)
fig.common_yticks(axc1, axc2)
fig.tag([axi, axp, axr], xoffs=-3, yoffs=-1)
fig.tag([axg, axc1, axc2, axd], xoffs=-3, yoffs=2)
print('Additional example cells:')
example_cells = [
['2010-11-26-an', 0],
['2011-10-25-ac', 0],
['2011-02-18-ab', 1],
['2014-01-16-aj', 5],
]
for k, (cell, run) in enumerate(example_cells):
eodf, rate, cv, _, _, _, _ = load_baseline(results_path, cell)
fcutoff, contrast, freqs, gain, chi2 = load_spectra(results_path, cell, run)
dfreqs, diag = diag_projection(freqs, chi2, 2*fcutoff)
nli, nlif = peakedness(dfreqs, diag, rate, median=False)
print(f' {cell:<22s}: run={run:2d}, fbase={rate:3.0f}Hz, CV={cv:.2f}, SI={nli:3.1f}')
pc = plot_chi2(axs[k], s, contrast, freqs, chi2, fcutoff, 1.2)
axs[k].set_title(f'$r={rate:.0f}$Hz, CV$_{{\\rm base}}$={cv:.2f}', fontsize='medium')
axs[k].text(0.95, 0.9, f'SI($r$)={nli:.1f}', ha='right', zorder=50,
color='white', fontsize='medium',
transform=axs[k].transAxes)
plot_colorbar(axs[-1], pc, 0.4)
fig.common_yticks(axs)
fig.tag(axs, xoffs=-3, yoffs=2)
fig.savefig()

234
dataoverview.py Normal file
View File

@ -0,0 +1,234 @@
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import pearsonr, gaussian_kde
from scipy.stats import mannwhitneyu
from thunderlab.tabledata import TableData
from plotstyle import plot_style, lighter, significance_str
def plot_corr(ax, data, xcol, ycol, zcol, zmin, zmax, xpdfmax, cmap, color,
nli_thresh):
ax.axhline(nli_thresh, color='k', ls=':', lw=0.5)
"""
for c in np.unique(data('cell')):
xdata = data[data('cell') == c, xcol]
ydata = data[data('cell') == c, ycol]
contrasts = data[data('cell') == c, 'contrast']
idx = np.argsort(contrasts)
if len(idx) > 1:
ax.plot(xdata[idx], ydata[idx], '-k', alpha=0.2, zorder=10)
"""
xmax = ax.get_xlim()[1]
ymax = ax.get_ylim()[1]
mask = (data(xcol) < xmax) & (data(ycol) < ymax)
sc = ax.scatter(data[mask, xcol], data[mask, ycol], c=data[mask, zcol],
s=3, clip_on=False, cmap=cmap, vmin=zmin, vmax=zmax, zorder=20)
# color bar:
fig = ax.get_figure()
cax = ax.inset_axes([1.3, 0, 0.04, 1])
cax.set_spines_outward('lrbt', 0)
cb = fig.colorbar(sc, cax=cax)
cb.outline.set_color('none')
cb.outline.set_linewidth(0)
# pdf x-axis:
kde = gaussian_kde(data(xcol), 0.02*xmax/np.std(data(xcol), ddof=1))
xx = np.linspace(0, ax.get_xlim()[1], 400)
pdf = kde(xx)
xax = ax.inset_axes([0, 1.05, 1, 0.2])
xax.show_spines('')
xax.fill_between(xx, pdf, facecolor=color, edgecolor='none')
#xax.plot(xx, np.zeros(len(xx)), clip_on=False, color=color, lw=0.5)
xax.set_ylim(bottom=0)
xax.set_ylim(0, xpdfmax)
# pdf y-axis:
kde = gaussian_kde(data(ycol), 0.02*ymax/np.std(data(ycol), ddof=1))
xx = np.linspace(0, ax.get_ylim()[1], 400)
pdf = kde(xx)
yax = ax.inset_axes([1.05, 0, 0.2, 1])
yax.show_spines('')
yax.fill_betweenx(xx, pdf, facecolor=color, edgecolor='none')
#yax.plot(np.zeros(len(xx)), xx, clip_on=False, color=color, lw=0.5)
yax.set_xlim(left=0)
# threshold:
if 'cvbase' in xcol:
ax.text(xmax, 0.4*ymax, f'{100*np.sum(data(ycol) > nli_thresh)/data.rows():.0f}\\%',
ha='right', va='bottom', fontsize='small')
ax.text(xmax, 0.3, f'{100*np.sum(data(ycol) < nli_thresh)/data.rows():.0f}\\%',
ha='right', va='center', fontsize='small')
# statistics:
r, p = pearsonr(data(xcol), data(ycol))
ax.text(1, 0.9, f'$R={r:.2f}$ **', ha='right',
transform=ax.transAxes, fontsize='small')
#ax.text(1, 0.77, f'{significance_str(p)}', ha='right',
# transform=ax.transAxes, fontsize='small')
if 'cvbase' in xcol:
ax.text(1, 0.77, f'$n={data.rows()}$', ha='right',
transform=ax.transAxes, fontsize='small')
print(f' correlation {xcol:<8s} - {ycol}: r={r:5.2f}, p={p:.2g}')
return cax
def nli_stats(title, data, column, nli_thresh):
print(title)
print(f' nli threshold: {nli_thresh:.1f}')
nrecs = data.rows()
ncells = len(np.unique(data('cell')))
print(f' cells: {ncells}')
print(f' recordings: {nrecs}')
hcells = np.unique(data[data(column) > nli_thresh, 'cell'])
print(f' high nli cells: n={len(hcells):3d}, {100*len(hcells)/ncells:4.1f}%')
print(f' high nli recordings: n={np.sum(data(column) > nli_thresh):3d}, '
f'{100*np.sum(data(column) > nli_thresh)/nrecs:4.1f}%')
nsegs = data('nsegs')
print(f' number of segments: {np.min(nsegs):4.0f} - {np.max(nsegs):4.0f}, median={np.median(nsegs):4.0f}, mean={np.mean(nsegs):4.0f}, std={np.std(nsegs):4.0f}')
def plot_cvbase_nli_punit(ax, data, ycol, nli_thresh, color):
ax.set_xlabel('CV$_{\\rm base}$')
ax.set_ylabel('SI($r$)')
ax.set_xlim(0, 1.5)
ax.set_ylim(0, 6)
ax.set_yticks_delta(2)
cax = plot_corr(ax, data, 'cvbase', ycol, 'respmod2', 0, 250, 3,
'coolwarm', color, nli_thresh)
cax.set_ylabel('Response mod.', 'Hz')
def plot_cvstim_nli_punit(ax, data, ycol, nli_thresh, color):
ax.set_xlabel('CV$_{\\rm stim}$')
ax.set_ylabel('SI($r$)')
ax.set_xlim(0, 1.6)
ax.set_ylim(0, 6)
ax.set_xticks_delta(0.5)
ax.set_yticks_delta(2)
#cax = plot_corr(ax, data, 'cvstim', ycol, 'respmod2', 0, 250, 2,
# 'coolwarm', color, nli_thresh)
#cax.set_ylabel('Response mod.', 'Hz')
#cax = plot_corr(ax, data, 'cvstim', ycol, 'cvbase', 0, 1.5, 2,
# 'coolwarm', color, nli_thresh)
#cax.set_ylabel('CV$_{\\rm base}$')
cax = plot_corr(ax, data, 'cvstim', ycol, 'ratebase', 50, 450, 2,
'coolwarm', color, nli_thresh)
cax.set_ylabel('$r$', 'Hz')
def plot_mod_nli_punit(ax, data, ycol, nli_thresh, color):
ax.set_xlabel('Response modulation', 'Hz')
ax.set_ylabel('SI($r$)')
ax.set_xlim(0, 250)
ax.set_ylim(0, 6)
ax.set_yticks_delta(2)
cax = plot_corr(ax, data, 'respmod2', ycol, 'cvbase', 0, 1.5, 0.016,
'coolwarm', color, nli_thresh)
cax.set_ylabel('CV$_{\\rm base}$')
def plot_cvbase_nli_ampul(ax, data, ycol, nli_thresh, color):
ax.set_xlabel('CV$_{\\rm base}$')
ax.set_ylabel('SI($r$)')
ax.set_xlim(0, 0.2)
ax.set_ylim(0, 15)
ax.set_xticks_delta(0.1)
ax.set_yticks_delta(5)
cax = plot_corr(ax, data, 'cvbase', ycol, 'respmod2', 0, 80, 20,
'coolwarm', color, nli_thresh)
cax.set_ylabel('Response mod.', 'Hz')
def plot_cvstim_nli_ampul(ax, data, ycol, nli_thresh, color):
ax.set_xlabel('CV$_{\\rm stim}$')
ax.set_ylabel('SI($r$)')
ax.set_xlim(0, 0.85)
ax.set_ylim(0, 15)
ax.set_xticks_delta(0.2)
ax.set_yticks_delta(5)
#cax = plot_corr(ax, data, 'cvstim', ycol, 'respmod2', 0, 80, 6,
# 'coolwarm', color, nli_thresh)
#cax.set_ylabel('Response mod.', 'Hz')
#cax = plot_corr(ax, data, 'cvstim', ycol, 'cvbase', 0, 0.2, 6,
# 'coolwarm', color, nli_thresh)
#cax.set_ylabel('CV$_{\\rm base}$')
#cax.set_yticks_delta(0.1)
cax = plot_corr(ax, data, 'cvstim', ycol, 'ratebase', 90, 180, 6,
'coolwarm', color, nli_thresh)
cax.set_ylabel('$r$', 'Hz')
cax.set_yticks_delta(30)
def plot_mod_nli_ampul(ax, data, ycol, nli_thresh, color):
ax.set_xlabel('Response modulation', 'Hz')
ax.set_ylabel('SI($r$)')
ax.set_xlim(0, 80)
ax.set_ylim(0, 15)
ax.set_xticks_delta(20)
ax.set_yticks_delta(5)
cax = plot_corr(ax, data, 'respmod2', ycol, 'cvbase', 0, 0.2, 0.06,
'coolwarm', color, nli_thresh)
cax.set_ylabel('CV$_{\\rm base}$')
cax.set_yticks_delta(0.1)
if __name__ == '__main__':
punit_model = TableData('summarychi2noise.csv', sep=';')
punit_model = punit_model[punit_model('contrast') > 1e-6, :]
punit_data = TableData('Apteronotus_leptorhynchus-Punit-data.csv', sep=';')
ampul_data = TableData('Apteronotus_leptorhynchus-Ampullary-data.csv')
nli_thresh = 1.8
u, p = mannwhitneyu(punit_model('cvbase'), punit_data('cvbase'))
print('CV differs between P-unit models and data:')
print(f' U={u:g}, p={p:g}')
print(f' median model: {np.median(punit_model("cvbase")):.2f}')
print(f' median data: {np.median(punit_data("cvbase")):.2f}')
print()
u, p = mannwhitneyu(punit_model('respmod2'), punit_data('respmod2'))
print('Response modulation differs between P-unit models and data:')
print(f' U={u:g}, p={p:g}')
print(f' median model: {np.median(punit_model("respmod2")):.2f}')
print(f' median data: {np.median(punit_data("respmod2")):.2f}')
print()
u, p = mannwhitneyu(punit_model('dnli100'), punit_data('nli'))
print('NLI does not differ between P-unit models and data:')
print(f' U={u:g}, p={p:g}')
print(f' median model: {np.median(punit_model("dnli100")):.1f}')
print(f' median data: {np.median(punit_data("nli")):.1f}')
print()
s = plot_style()
fig, axs = plt.subplots(3, 3, cmsize=(s.plot_width, 0.75*s.plot_width),
height_ratios=[1, 0, 1, 0.3, 1])
fig.subplots_adjust(leftm=6.5, rightm=13.5, topm=4.5, bottomm=4,
wspace=1.1, hspace=0.6)
nli_stats('P-unit model:', punit_model, 'dnli100', nli_thresh)
axs[0, 0].text(0, 1.35, 'P-unit models',
transform=axs[0, 0].transAxes, color=s.model_color1)
plot_cvbase_nli_punit(axs[0, 0], punit_model, 'dnli100', nli_thresh, s.model_color2)
plot_mod_nli_punit(axs[0, 1], punit_model, 'dnli100', nli_thresh, s.model_color2)
plot_cvstim_nli_punit(axs[0, 2], punit_model, 'dnli100', nli_thresh, s.model_color2)
print()
nli_stats('P-unit data:', punit_data, 'nli', nli_thresh)
axs[1, 0].text(0, 1.35, 'P-unit data',
transform=axs[1, 0].transAxes, color=s.punit_color1)
plot_cvbase_nli_punit(axs[1, 0], punit_data, 'nli', nli_thresh, s.punit_color2)
plot_mod_nli_punit(axs[1, 1], punit_data, 'nli', nli_thresh, s.punit_color2)
plot_cvstim_nli_punit(axs[1, 2], punit_data, 'nli', nli_thresh, s.punit_color2)
print()
nli_stats('Ampullary data:', ampul_data, 'nli', nli_thresh)
axs[2, 0].text(0, 1.35, 'Ampullary data',
transform=axs[2, 0].transAxes, color=s.ampul_color1)
plot_cvbase_nli_ampul(axs[2, 0], ampul_data, 'nli', nli_thresh, s.ampul_color2)
plot_mod_nli_ampul(axs[2, 1], ampul_data, 'nli', nli_thresh, s.ampul_color2)
plot_cvstim_nli_ampul(axs[2, 2], ampul_data, 'nli', nli_thresh, s.ampul_color2)
print()
fig.common_xticks(axs[:2, 0])
fig.common_xticks(axs[:2, 1])
fig.common_xticks(axs[:2, 2])
fig.common_yticks(axs[0, :])
fig.common_yticks(axs[1, :])
fig.common_yticks(axs[2, :])
fig.tag(xoffs=-3.5, yoffs=2)
fig.savefig()

183
modelsusceptcontrasts.py Normal file
View File

@ -0,0 +1,183 @@
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import pearsonr, linregress, gaussian_kde
from thunderlab.tabledata import TableData
from pathlib import Path
from plotstyle import plot_style, labels_params, significance_str
data_path = Path('newdata3')
def sort_files(cell_name, all_files, n):
files = [fn for fn in all_files if '-'.join(fn.stem.split('-')[2:-n]) == cell_name]
if len(files) == 0:
return None, 0
nums = [int(fn.stem.split('-')[-1]) for fn in files]
idxs = np.argsort(nums)
files = [files[i] for i in idxs]
nums = [nums[i] for i in idxs]
return files, nums
def plot_chi2(ax, s, data_file):
data = np.load(data_file)
n = data['n']
alpha = data['alpha']
freqs = data['freqs']
pss = data['pss']
dt_fix = 1 # 0.0005
prss = np.abs(data['prss'])/dt_fix*0.5/np.sqrt(pss.reshape(1, -1)*pss.reshape(-1, 1))
ax.set_visible(True)
ax.set_aspect('equal')
i0 = np.argmin(freqs < -300)
i0 = np.argmin(freqs < 0)
i1 = np.argmax(freqs > 300)
if i1 == 0:
i1 = len(freqs)
freqs = freqs[i0:i1]
prss = prss[i0:i1, i0:i1]
vmax = np.quantile(prss, 0.996)
ten = 10**np.floor(np.log10(vmax))
for fac, delta in zip([1, 2, 3, 4, 6, 8, 10],
[0.5, 1, 1, 2, 3, 4, 5]):
if fac*ten >= vmax:
vmax = fac*ten
ten *= delta
break
pc = ax.pcolormesh(freqs, freqs, prss, vmin=0, vmax=vmax,
cmap='viridis', rasterized=True)
if 'noise_frac' in data:
ax.set_title('$c$=0\\,\\%', fontsize='medium')
else:
ax.set_title(f'$c$={100*alpha:g}\\,\\%', fontsize='medium')
ax.set_xlim(0, 300)
ax.set_ylim(0, 300)
ax.set_xticks_delta(100)
ax.set_yticks_delta(100)
ax.set_xlabel('$f_1$', 'Hz')
ax.set_ylabel('$f_2$', 'Hz')
cax = ax.inset_axes([1.04, 0, 0.05, 1])
cax.set_spines_outward('lrbt', 0)
if alpha == 0.1:
cb = fig.colorbar(pc, cax=cax, label=r'$|\chi_2|$ [Hz]')
else:
cb = fig.colorbar(pc, cax=cax)
cb.outline.set_color('none')
cb.outline.set_linewidth(0)
cax.set_yticks_delta(ten)
def plot_chi2_contrasts(axs, s, cell_name):
print(cell_name)
files, nums = sort_files(cell_name,
data_path.glob(f'chi2-split-{cell_name}-*.npz'), 1)
plot_chi2(axs[0], s, files[-1])
for k, alphastr in enumerate(['010', '030', '100']):
files, nums = sort_files(cell_name,
data_path.glob(f'chi2-noisen-{cell_name}-{alphastr}-*.npz'), 2)
plot_chi2(axs[k + 1], s, files[-1])
def plot_nli_cv(ax, s, data, alpha, cells):
data = data[data('contrast') == alpha, :]
r, p = pearsonr(data('cvbase'), data[:, 'dnli'])
l = linregress(data('cvbase'), data[:, 'dnli'])
x = np.linspace(0, 1, 10)
ax.set_visible(True)
ax.set_title(f'$c$={100*alpha:g}\\,\\%', fontsize='medium')
ax.axhline(1, **s.lsLine)
ax.plot(x, l.slope*x + l.intercept, **s.lsGrid)
mask = data('triangle') > 0.5
ax.plot(data[mask, 'cvbase'], data[mask, 'dnli'],
clip_on=False, zorder=30, label='strong', **s.psA1m)
mask = data[:, 'border'] > 0.5
ax.plot(data[mask, 'cvbase'], data[mask, 'dnli'],
zorder=20, label='weak', **s.psA2m)
ax.plot(data[:, 'cvbase'], data[:, 'dnli'], clip_on=False,
zorder=10, label='none', **s.psB1m)
for cell_name in cells:
mask = data[:, 'cell'] == cell_name
color = s.psB1m['color']
if data[mask, 'border']:
color = s.psA2m['color']
elif data[mask, 'triangle']:
color = s.psA1m['color']
ax.plot(data[mask, 'cvbase'], data[mask, 'dnli'],
zorder=40, marker='o', ms=s.psB1m['markersize'],
mfc=color, mec='k', mew=0.8)
ax.set_ylim(0, 8)
ax.set_xlim(0, 1)
ax.set_minor_yticks_delta(1)
ax.set_xlabel('CV$_{\\rm base}$')
ax.set_ylabel('SI')
ax.set_yticks_delta(4)
ax.text(1, 0.9, f'$r={r:.2f}$', transform=ax.transAxes,
ha='right', fontsize='small')
ax.text(1, 0.7, significance_str(p), transform=ax.transAxes,
ha='right', fontsize='small')
if alpha == 0:
ax.legend(loc='upper left', bbox_to_anchor=(1.15, 1.05),
title='triangle', handlelength=0.5,
handletextpad=0.5, labelspacing=0.2)
kde = gaussian_kde(data('dnli'), 0.15/np.std(data('dnli'), ddof=1))
nli = np.linspace(0, 8, 100)
pdf = kde(nli)
dax = ax.inset_axes([1.04, 0, 0.3, 1])
dax.show_spines('')
dax.fill_betweenx(nli, pdf, **s.fsB1a)
dax.plot(pdf, nli, clip_on=False, **s.lsB1m)
def plot_summary_contrasts(axs, s, cells):
nli_thresh = 1.2
data = TableData('summarychi2noise.csv')
plot_nli_cv(axs[0], s, data, 0, cells)
print('split:')
nli_split = data[data('contrast') == 0, 'dnli']
print(f' mean NLI = {np.mean(nli_split):.2f}, stdev = {np.std(nli_split):.2f}')
n = np.sum(nli_split > nli_thresh)
print(f' {n} cells ({100*n/len(nli_split):.1f}%) have NLI > {nli_thresh:.1f}')
print(f' triangle cells have nli >= {np.min(nli_split[data[data("contrast") == 0, "triangle"] > 0.5])}')
print()
for i, a in enumerate([0.01, 0.03, 0.1]):
plot_nli_cv(axs[1 + i], s, data, a, cells)
print(f'contrast {100*a:2g}%:')
cdata = data[data('contrast') == a, :]
nli = cdata('dnli')
r, p = pearsonr(nli_split, nli)
print(f' correlation with split: r={r:.2f}, p={p:.1e}')
print(f' mean NLI = {np.mean(nli):.2f}, stdev = {np.std(nli):.2f}')
n = np.sum(nli > nli_thresh)
print(f' {n} cells ({100*n/len(nli):.1f}%) have NLI > {nli_thresh:.1f}')
print( ' CVs:', cdata[nli > nli_thresh, 'cvbase'])
print( ' names:', cdata[nli > nli_thresh, 'cell'])
print()
print('lowest baseline CV:', np.unique(data('cvbase'))[:3])
if __name__ == '__main__':
cells = ['2017-07-18-ai-invivo-1', # strong triangle
'2012-12-13-ao-invivo-1', # triangle
'2012-12-20-ac-invivo-1', # weak border triangle
'2013-01-08-ab-invivo-1'] # no triangle
s = plot_style()
#labels_params(xlabelloc='right', ylabelloc='top')
fig, axs = plt.subplots(6, 4, cmsize=(s.plot_width, 0.95*s.plot_width),
height_ratios=[1, 1, 1, 1, 0, 1])
fig.subplots_adjust(leftm=7, rightm=8, topm=2, bottomm=3.5,
wspace=1, hspace=0.7)
for ax in axs.flat:
ax.set_visible(False)
for k in range(len(cells)):
plot_chi2_contrasts(axs[k], s, cells[k])
for k in range(4):
fig.common_yticks(axs[k, :])
fig.common_xticks(axs[:4, k])
plot_summary_contrasts(axs[5], s, cells)
fig.common_yticks(axs[5, :])
fig.tag(axs, xoffs=-4.5, yoffs=1.8)
fig.savefig()

199
modelsusceptlown.py Normal file
View File

@ -0,0 +1,199 @@
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import pearsonr, linregress, gaussian_kde
from thunderlab.tabledata import TableData
from pathlib import Path
from plotstyle import plot_style, labels_params, significance_str
data_path = Path('newdata3')
def sort_files(cell_name, all_files, n):
files = [fn for fn in all_files if '-'.join(fn.stem.split('-')[2:-n]) == cell_name]
if len(files) == 0:
return None, 0
nums = [int(fn.stem.split('-')[-1]) for fn in files]
idxs = np.argsort(nums)
files = [files[i] for i in idxs]
nums = [nums[i] for i in idxs]
return files, nums
def plot_chi2(ax, s, data_file):
data = np.load(data_file)
n = data['n']
alpha = data['alpha']
freqs = data['freqs']
pss = data['pss']
dt_fix = 1 # 0.0005
prss = np.abs(data['prss'])/dt_fix*0.5/np.sqrt(pss.reshape(1, -1)*pss.reshape(-1, 1))
ax.set_visible(True)
ax.set_aspect('equal')
i0 = np.argmin(freqs < -300)
i0 = np.argmin(freqs < 0)
i1 = np.argmax(freqs > 300)
if i1 == 0:
i1 = len(freqs)
freqs = freqs[i0:i1]
prss = prss[i0:i1, i0:i1]
vmax = np.quantile(prss, 0.996)
ten = 10**np.floor(np.log10(vmax))
for fac, delta in zip([1, 2, 3, 4, 6, 8, 10],
[0.5, 1, 1, 2, 3, 4, 5]):
if fac*ten >= vmax:
vmax = fac*ten
ten *= delta
break
pc = ax.pcolormesh(freqs, freqs, prss, vmin=0, vmax=vmax,
cmap='viridis', rasterized=True)
ns = f'$N={n}$' if n <= 100 else f'$N=10^{np.log10(n):.0f}$'
if 'noise_frac' in data:
ax.set_title(f'$c$=0\\,\\%, {ns}', fontsize='medium')
else:
ax.set_title(f'$c$={100*alpha:g}\\,\\%, {ns}', fontsize='medium')
ax.set_xlim(0, 300)
ax.set_ylim(0, 300)
ax.set_xticks_delta(100)
ax.set_yticks_delta(100)
ax.set_xlabel('$f_1$', 'Hz')
ax.set_ylabel('$f_2$', 'Hz')
cax = ax.inset_axes([1.04, 0, 0.05, 1])
cax.set_spines_outward('lrbt', 0)
if alpha == 0.1:
cb = fig.colorbar(pc, cax=cax, label=r'$|\chi_2|$ [Hz]')
else:
cb = fig.colorbar(pc, cax=cax)
cb.outline.set_color('none')
cb.outline.set_linewidth(0)
cax.set_yticks_delta(ten)
def plot_chi2_contrasts(axs, s, cell_name, n=None):
print(cell_name)
files, nums = sort_files(cell_name,
data_path.glob(f'chi2-split-{cell_name}-*.npz'), 1)
idx = -1 if n is None else nums.index(n)
plot_chi2(axs[0], s, files[idx])
for k, alphastr in enumerate(['010', '030', '100']):
files, nums = sort_files(cell_name,
data_path.glob(f'chi2-noisen-{cell_name}-{alphastr}-*.npz'), 2)
idx = -1 if n is None else nums.index(n)
plot_chi2(axs[k + 1], s, files[idx])
def plot_nli_diags(ax, s, data, alphax, alphay, xthresh, ythresh, cell_name):
datax = data[data('contrast') == alphax, :]
datay = data[data('contrast') == alphay, :]
nlix = datax('dnli')
nliy = datay('dnli100')
nfp = np.sum((nliy > ythresh) & (nlix < xthresh))
ntp = np.sum((nliy > ythresh) & (nlix > xthresh))
ntn = np.sum((nliy < ythresh) & (nlix < xthresh))
nfn = np.sum((nliy < ythresh) & (nlix > xthresh))
print(f' {ntp:2d} ({100*ntp/len(nlix):2.0f}%) true positive')
print(f' {nfp:2d} ({100*nfp/len(nlix):2.0f}%) false positive')
print(f' {ntn:2d} ({100*ntn/len(nlix):2.0f}%) true negative')
print(f' {nfn:2d} ({100*nfn/len(nlix):2.0f}%) false negative')
r, p = pearsonr(nlix, nliy)
l = linregress(nlix, nliy)
x = np.linspace(0, 10, 10)
ax.set_visible(True)
ax.set_title(f'$c$={100*alphay:g}\\,\\%', fontsize='medium')
ax.plot(x, x, **s.lsLine)
ax.plot(x, l.slope*x + l.intercept, **s.lsGrid)
ax.axhline(ythresh, **s.lsLine)
ax.axvline(xthresh, 0, 0.5, **s.lsLine)
if alphax == 0:
mask = datax('triangle') > 0.5
ax.plot(nlix[mask], nliy[mask], zorder=30, label='strong', **s.psA1m)
mask = datax('border') > 0.5
ax.plot(nliy[mask], nliy[mask], zorder=20, label='weak', **s.psA2m)
ax.plot(nlix, nliy, zorder=10, label='none', **s.psB1m)
# mark cell:
mask = datax('cell') == cell_name
color = s.psB1m['color']
if alphax == 0:
if datax[mask, 'border']:
color = s.psA2m['color']
elif datax[mask, 'triangle']:
color = s.psA1m['color']
ax.plot(nlix[mask], nliy[mask], zorder=40, marker='o',
ms=s.psB1m['markersize'], mfc=color, mec='k', mew=0.8)
box = dict(boxstyle='square,pad=0.1', fc='white', ec='none')
ax.text(1.0, 0.0, f'{ntn}', ha='right', fontsize='small', bbox=box)
ax.text(7.5, 0.0, f'{nfn}', ha='right', fontsize='small', bbox=box)
ax.text(1.0, 3.7, f'{nfp}', ha='right', fontsize='small', bbox=box)
ax.text(7.5, 3.7, f'{ntp}', ha='right', fontsize='small', bbox=box)
ax.set_ylim(0, 9)
ax.set_xlim(0, 9)
n = datax[0, 'nsegs']
if alphax == 0:
ax.set_xlabel(f'SI, $c=0$, $N=10^{np.log10(n):.0f}$')
else:
ax.set_xlabel(f'SI, $N=10^{np.log10(n):.0f}$')
ax.set_ylabel('SI, $N=100$')
ax.set_xticks_delta(4)
ax.set_yticks_delta(4)
ax.set_minor_xticks_delta(1)
ax.set_minor_yticks_delta(1)
ax.text(0, 0.9, f'$r={r:.2f}$', transform=ax.transAxes, fontsize='small')
ax.text(0, 0.7, significance_str(p), transform=ax.transAxes,
fontsize='small')
if alphax == 0 and alphay == 0.01:
ax.legend(loc='upper left', bbox_to_anchor=(-1.5, 1),
title='triangle', handlelength=0.5,
handletextpad=0.5, labelspacing=0.2)
kde = gaussian_kde(nliy, 0.15/np.std(nliy, ddof=1))
nli = np.linspace(0, 8, 100)
pdf = kde(nli)
dax = ax.inset_axes([1.04, 0, 0.3, 1])
dax.show_spines('')
dax.fill_betweenx(nli, pdf, **s.fsB1a)
dax.plot(pdf, nli, clip_on=False, **s.lsB1m)
def plot_summary_contrasts(axs, s, xthresh, ythresh, cell_name):
print(f'against contrast with thresholds: x={xthresh} and y={ythresh}')
data = TableData('summarychi2noise.csv')
for i, a in enumerate([0.01, 0.03, 0.1]):
print(f'contrast {100*a:2g}%:')
plot_nli_diags(axs[1 + i], s, data, a, a, xthresh, ythresh, cell_name)
print()
def plot_summary_diags(axs, s, xthresh, ythresh, cell_name):
print(f'against split with thresholds: x={xthresh} and y={ythresh}')
data = TableData('summarychi2noise.csv')
for i, a in enumerate([0.01, 0.03, 0.1]):
print(f'contrast {100*a:2g}%:')
plot_nli_diags(axs[1 + i], s, data, 0, a, xthresh, ythresh, cell_name)
if __name__ == '__main__':
xthresh = 1.2
ythresh = 1.8
s = plot_style()
fig, axs = plt.subplots(6, 4, cmsize=(s.plot_width, 0.85*s.plot_width),
height_ratios=[1, 1, 0, 1, 0, 1])
fig.subplots_adjust(leftm=7, rightm=8, topm=2, bottomm=3.5,
wspace=1, hspace=1)
for ax in axs.flat:
ax.set_visible(False)
cell_name = '2012-12-21-ak-invivo-1'
plot_chi2_contrasts(axs[0], s, cell_name)
plot_chi2_contrasts(axs[1], s, cell_name, 10)
for k in range(2):
fig.common_yticks(axs[k, :])
for k in range(4):
fig.common_xticks(axs[:2, k])
plot_summary_contrasts(axs[3], s, xthresh, ythresh, cell_name)
plot_summary_diags(axs[5], s, xthresh, ythresh, cell_name)
fig.common_yticks(axs[3, 1:])
fig.common_yticks(axs[5, 1:])
fig.tag(axs, xoffs=-4.5, yoffs=1.8)
axs[1, 0].set_visible(False)
#plt.show()
fig.savefig()

165
modelsusceptovern.py Normal file
View File

@ -0,0 +1,165 @@
import numpy as np
from scipy.stats import linregress
import matplotlib.pyplot as plt
from pathlib import Path
from plotstyle import plot_style, labels_params
data_path = Path('newdata3')
def sort_files(cell_name, all_files, n):
files = [fn for fn in all_files if '-'.join(fn.stem.split('-')[2:-n]) == cell_name]
if len(files) == 0:
return None, 0
nums = [int(fn.stem.split('-')[-1]) for fn in files]
idxs = np.argsort(nums)
files = [files[i] for i in idxs]
nums = [nums[i] for i in idxs]
return files, nums
def plot_chi2(ax, s, data_file):
data = np.load(data_file)
n = data['n']
alpha = data['alpha']
freqs = data['freqs']
pss = data['pss']
dt_fix = 1 # 0.0005
prss = np.abs(data['prss'])/dt_fix*0.5/np.sqrt(pss.reshape(1, -1)*pss.reshape(-1, 1))
ax.set_visible(True)
ax.set_aspect('equal')
i0 = np.argmin(freqs < -300)
i0 = np.argmin(freqs < 0)
i1 = np.argmax(freqs > 300)
if i1 == 0:
i1 = len(freqs)
freqs = freqs[i0:i1]
prss = prss[i0:i1, i0:i1]
vmax = np.quantile(prss, 0.996)
ten = 10**np.floor(np.log10(vmax))
for fac, delta in zip([1, 2, 3, 4, 6, 8, 10],
[0.5, 1, 1, 2, 3, 4, 5]):
if fac*ten >= vmax:
vmax = fac*ten
ten *= delta
break
pc = ax.pcolormesh(freqs, freqs, prss, vmin=0, vmax=vmax,
cmap='viridis', rasterized=True)
ax.set_title(f'$N=10^{np.log10(n):.0f}$', fontsize='medium')
ax.set_xlim(0, 300)
ax.set_ylim(0, 300)
ax.set_xticks_delta(300)
ax.set_minor_xticks(3)
ax.set_yticks_delta(300)
ax.set_minor_yticks(3)
ax.set_xlabel('$f_1$', 'Hz')
ax.set_ylabel('$f_2$', 'Hz')
cax = ax.inset_axes([1.04, 0, 0.05, 1])
cax.set_spines_outward('lrbt', 0)
cb = fig.colorbar(pc, cax=cax)
cb.outline.set_color('none')
cb.outline.set_linewidth(0)
cax.set_yticks_delta(ten)
def plot_overn(ax, s, files, nmax=1e6, title=False):
ns = []
stats = []
for fname in files:
data = np.load(fname)
if not 'n' in data:
return
n = data['n']
if nmax is not None and n > nmax:
continue
alpha = data['alpha']
freqs = data['freqs']
pss = data['pss']
dt_fix = 1 # 0.0005
chi2 = np.abs(data['prss'])/dt_fix*0.5/np.sqrt(pss.reshape(1, -1)*pss.reshape(-1, 1))
ns.append(n)
i0 = np.argmin(freqs < 0)
i1 = np.argmax(freqs > 300)
if i1 == 0:
i1 = len(freqs)
chi2 = chi2[i0:i1, i0:i1]
stats.append(np.quantile(chi2, [0, 0.001, 0.05, 0.25, 0.5,
0.75, 0.95, 0.998, 1.0]))
ns = np.array(ns)
stats = np.array(stats)
indx = np.argsort(ns)
ns = ns[indx]
stats = stats[indx]
ax.set_visible(True)
ax.plot(ns, stats[:, 7], '0.5', lw=1, zorder=50, label='99.8\\%')
ax.fill_between(ns, stats[:, 2], stats[:, 6], fc='0.85', zorder=40, label='5--95\\%')
ax.fill_between(ns, stats[:, 3], stats[:, 5], fc='0.5', zorder=45, label='25-75\\%')
ax.plot(ns, stats[:, 4], zorder=50, label='median', **s.lsSpine)
#ax.plot(ns, stats[:, 8], '0.0')
if title:
if 'noise_frac' in data:
ax.set_title('$c$=0\\,\\%', fontsize='medium')
else:
ax.set_title(f'$c$={100*alpha:g}\\,\\%', fontsize='medium')
ax.set_xlim(1e1, nmax)
ax.set_xscale('log')
ax.set_yscale('log')
ax.set_yticks_log(numticks=3)
if nmax > 1e6:
ax.set_ylim(3e-1, 5e3)
ax.set_minor_yticks_log(numticks=5)
ax.set_xticks_log(numticks=4)
ax.set_minor_xticks_log(numticks=8)
else:
ax.set_ylim(5e0, 1e4)
ax.set_minor_yticks_log(numticks=5)
ax.set_xticks_log(numticks=3)
ax.set_minor_xticks_log(numticks=6)
ax.set_xlabel('segments')
ax.set_ylabel('$|\\chi_2|$ [Hz]')
if alpha == 0.10:
ax.legend(loc='upper left', bbox_to_anchor=(1.4, 1.3),
markerfirst=False, title='$|\\chi_2|$ percentiles')
def plot_chi2_overn(axs, s, cell_name):
print(cell_name)
files, nums = sort_files(cell_name,
data_path.glob(f'chi2-split-{cell_name}-*.npz'), 1)
for k, n in enumerate([1e1, 1e2, 1e3, 1e6]):
plot_chi2(axs[k], s, files[nums.index(int(n))])
plot_overn(axs[-1], s, files)
if __name__ == '__main__':
cells = ['2017-07-18-ai-invivo-1', # strong triangle
'2012-12-13-ao-invivo-1', # triangle
'2012-12-20-ac-invivo-1', # weak border triangle
'2013-01-08-ab-invivo-1'] # no triangle
s = plot_style()
fig, axs = plt.subplots(6, 6, cmsize=(s.plot_width, 0.9*s.plot_width),
width_ratios=[1, 1, 1, 1, 0, 1],
height_ratios=[1, 1, 1, 1, 0, 1])
fig.subplots_adjust(leftm=8, rightm=0.5, topm=2, bottomm=3.5,
wspace=1, hspace=0.8)
for ax in axs.flat:
ax.set_visible(False)
for k in range(len(cells)):
plot_chi2_overn(axs[k], s, cells[k])
cell_name = cells[0]
files, nums = sort_files(cell_name,
data_path.glob(f'chi2-split-{cell_name}-*.npz'), 1)
plot_overn(axs[-1, 0], s, files, 1e7, True)
for k, alphastr in enumerate(['010', '030', '100']):
files, nums = sort_files(cell_name,
data_path.glob(f'chi2-noisen-{cell_name}-{alphastr}-*.npz'), 2)
plot_overn(axs[-1, k + 1], s, files, 1e7, True)
for k in range(4):
fig.common_yticks(axs[k, :4])
fig.common_xticks(axs[:4, k])
fig.common_xticks(axs[:4, -1])
fig.align_ylabels(axs[:4, -1], dist=12)
fig.common_yticks(axs[-1, :4])
fig.tag(axs, xoffs=-2.5, yoffs=1.8)
fig.savefig()

126
plotstyle.py Normal file
View File

@ -0,0 +1,126 @@
import matplotlib as mpl
import plottools.plottools as pt
from plottools.spines import spines_params
from plottools.labels import labels_params
from plottools.colors import lighter, darker
def significance_str(p):
if p > 0.05:
return f'$p={p:.2f}$'
elif p > 0.01:
return '$p<0.05$'
elif p > 0.001:
return '$p<0.01$'
else:
return '$p<0.001$'
def plot_style():
palette = pt.palettes['muted']
lwthick = 1.0
lwthin = 0.5
lwspines = 1.0
names = ['A1', 'A2', 'A3',
'B1', 'B2', 'B3', 'B4',
'C1', 'C2', 'C3', 'C4']
colors = [palette['red'], palette['orange'], palette['yellow'],
palette['blue'], palette['purple'], palette['magenta'], palette['lightblue'],
palette['lightgreen'], palette['green'], palette['darkgreen'], palette['cyan']]
dashes = ['-', '-', '-',
'-', '-', '-', '-.',
'-', '-', '-', '-']
markers = [('o', 1.0), ('p', 1.1), ('h', 1.1),
((3, 1, 60), 1.25), ((3, 1, 0), 1.25), ((3, 1, 90), 1.25), ((3, 1, 30), 1.25),
('s', 0.9), ('D', 0.85), ('*', 1.6), ((4, 1, 45), 1.4)]
class ns: pass
ns.colors = palette
ns.lwthick = lwthick
ns.lwthin = lwthin
ns.plot_width = 16.5
pt.make_linepointfill_styles(ns, names, colors, dashes, markers,
lwthick=lwthick, lwthin=lwthin,
markerlarge=5, markersmall=4,
mec=0.0, mew=0.5, fillalpha=0.4)
pt.make_line_styles(ns, 'ls', 'Spine', '', palette['black'], '-',
lwspines, clip_on=False)
pt.make_line_styles(ns, 'ls', 'Grid', '', palette['gray'], '--',
0.7*lwthin)
pt.make_line_styles(ns, 'ls', 'Dotted', '', palette['gray'], ':',
0.7*lwthin)
pt.make_line_styles(ns, 'ls', 'Marker', '', palette['black'], '-',
lwthick, clip_on=False)
pt.make_line_styles(ns, 'ls', 'Line', '', palette['black'], '-',
lwthin)
pt.make_line_styles(ns, 'ls', 'EOD', '', palette['gray'], '-', lwthin)
pt.make_line_styles(ns, 'ls', 'AM', '', palette['red'], '-', lwthick)
ns.model_color1 = palette['purple']
ns.model_color2 = lighter(ns.model_color1, 0.6)
ns.punit_color1 = palette['blue']
ns.punit_color2 = lighter(ns.punit_color1, 0.6)
ns.ampul_color1 = palette['green']
ns.ampul_color2 = lighter(ns.ampul_color1, 0.6)
pt.make_line_styles(ns, 'ls', 'M%d', '', [ns.model_color1, ns.model_color2], '-', lwthick)
pt.make_point_styles(ns, 'ps', 'M%d', '', [ns.model_color1, ns.model_color2], '-', markers=('o', 1), markersizes=4)
pt.make_line_styles(ns, 'ls', 'P%d', '', [ns.punit_color1, ns.punit_color2], '-', lwthick)
pt.make_point_styles(ns, 'ps', 'P%d', '', [ns.punit_color1, ns.punit_color2], '-', markers=('o', 1), markersizes=4)
pt.make_line_styles(ns, 'ls', 'A%d', '', [ns.ampul_color1, ns.ampul_color2], '-', lwthick)
pt.make_point_styles(ns, 'ps', 'A%d', '', [ns.ampul_color1, ns.ampul_color2], '-', markers=('o', 1), markersizes=4)
pt.make_line_styles(ns, 'ls', 'Diag', '', palette['white'], '--', lwthin)
pt.arrow_style(ns, 'Line', dist=3.0, style='>', shrink=0, lw=0.6,
color=palette['black'], head_length=4, head_width=4,
bbox=dict(boxstyle='round,pad=0.1', facecolor='white',
edgecolor='none', alpha=1.0))
pt.arrow_style(ns, 'Hertz', dist=3.0, style='>', shrink=0, lw=0.6,
color=palette['black'], head_length=3, head_width=3,
heads='<>', text='%.0f\u2009Hz', rotation='vertical',
fontsize='x-small',
bbox=dict(boxstyle='round,pad=0.1', facecolor='white',
edgecolor='none', alpha=0.6))
pt.arrow_style(ns, 'Point', dist=3.0, style='>>', shrink=0,
lw=1.1, color=palette['black'], head_length=5,
head_width=4)
pt.arrow_style(ns, 'PointSmall', dist=3.0, style='>>', shrink=0,
lw=1, color=palette['black'], head_length=5,
head_width=5, fontsize='small')
pt.arrow_style(ns, 'Marker', dist=3.0, style='>>', shrink=0,
lw=0.9, color=palette['black'], head_length=5,
head_width=5, fontsize='small', ha='center',
va='center',
bbox=dict(boxstyle='round, pad=0.1, rounding_size=0.4',
facecolor=palette['white'],
edgecolor='none', alpha=0.6))
# rc settings:
mpl.rcdefaults()
pt.axes_params(0.0, 0.0, 0.0, color='none')
cmcolors = [pt.lighter(palette['yellow'], 0.2),
pt.lighter(palette['orange'], 0.5), palette['orange'],
palette['red'], palette['red']]
cmvalues = [0.0, 0.3, 0.7, 0.95, 1.0]
pt.colormap('YR', cmcolors, cmvalues)
cycle_colors = ['blue', 'red', 'orange', 'lightgreen', 'magenta',
'yellow', 'cyan', 'pink']
pt.colors_params(palette, cycle_colors, 'RdYlBu')
pt.figure_params(palette['white'], format='pdf', compression=6,
fonttype=3, stripfonts=True)
pt.labels_params('{label} [{unit}]', labelsize='medium', labelpad=6)
pt.legend_params(fontsize='small', frameon=False, borderpad=0.0,
handlelength=1.5, handletextpad=1,
numpoints=1, scatterpoints=1,
labelspacing=0.5, columnspacing=0.5)
pt.scalebars_params(format_large='%.0f', format_small='%.1f',
lw=2.2, color=palette['black'], capsize=0,
clw=0.5, font=dict(fontsize='medium',
fontstyle='normal'))
pt.spines_params(spines='lb', spines_offsets={'lrtb': 3},
spines_bounds={'lrtb': 'full'})
pt.tag_params(xoffs='auto', yoffs='auto', label='%A',
minor_label=r'%A$_{\text{%mi}}$',
font=dict(fontsize='x-large', fontstyle='normal',
fontweight='normal'))
pt.text_params(font_size=7, font_family='sans-serif', latex=True,
preamble=('p:sfmath', 'p:marvosym', 'p:xfrac',
'p:SIunits'))
pt.ticks_params(xtick_dir='out', xtick_size=3)
return ns

275
punitexamplecell.py Normal file
View File

@ -0,0 +1,275 @@
import sys
sys.path.insert(0, 'ephys') # for analysing data
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
from spectral import diag_projection, peakedness
from plotstyle import plot_style
cell_name = '2020-10-27-ag-invivo-1'
run1 = 0
run2 = 1
base_path = Path('ephys')
data_path = base_path / 'data'
results_path = base_path / 'results'
def load_baseline(path, cell_name):
d = path / f'{cell_name}-baseline.npz'
data = np.load(d)
['eodf', 'isis', 'isih', 'lags', 'corrs', 'freqs', 'prr']
eodf = float(data['eodf'])
rate = float(data['ratebase/Hz'])
cv = float(data['cvbase'])
isis = data['isis']
pdf = data['isih']
freqs = data['freqs']
prr = data['prr']
return eodf, rate, cv, isis, pdf, freqs, prr
def load_noise(path, cell_name, run):
data = np.load(path / f'{cell_name}-spectral-data-s{run:02d}.npz')
contrast = data['contrast']
time = data['time']
stimulus = data['stimulus']
name = str(data['stimulus_name'])
fcutoff = float(name.lower().replace('blwn', '').replace('inputarr_', '').replace('gwn', '').split('h')[0])
spikes = []
for k in range(1000):
key = f'spikes_{k:03d}'
if not key in data.keys():
break
spikes.append(data[key])
return contrast, time, stimulus, spikes
def load_spectra(path, cell_name, run=None):
if run is None:
data = np.load(cell_name)
else:
d = list(path.glob(f'{cell_name}-spectral*-s{run:02d}.npz'))
data = np.load(d[0])
contrast = float(data['alpha'])
fcutoff = float(data['fcutoff'])
freqs = data['freqs']
pss = data['pss']
prs = data['prs']
prss = data['prss']
nsegs = int(data['n'])
gain = np.abs(prs)/pss
chi2 = np.abs(prss)*0.5/np.sqrt(pss.reshape(1, -1)*pss.reshape(-1, 1))
return fcutoff, contrast, freqs, gain, chi2
def plot_isih(ax, s, rate, cv, isis, pdf):
ax.show_spines('b')
ax.fill_between(1000*isis, pdf, facecolor=s.cell_color1)
ax.set_xlim(0, 8)
ax.set_xticks_delta(2)
ax.set_xlabel('ISI', 'ms')
ax.text(0, 1.08, 'P-unit:', transform=ax.transAxes, color=s.cell_color1,
fontsize='large')
ax.text(0.6, 1.08, f'$r={rate:.0f}$Hz, CV$_{{\\rm base}}$={cv:.2f}',
transform=ax.transAxes)
def plot_response_spectrum(ax, s, eodf, rate, freqs, prr):
rate_i = np.argmax(prr[freqs < 0.7*eodf])
eod_i = np.argmax(prr[freqs > 500]) + np.argmax(freqs > 500)
power_db = 10*np.log10(prr/np.max(prr))
ax.show_spines('b')
mask = freqs < 890
ax.plot(freqs[mask], power_db[mask], **s.lsC1)
ax.plot(freqs[eod_i], power_db[eod_i], **s.psA1c)
ax.plot(freqs[rate_i], power_db[rate_i], **s.psA2c)
ax.set_xlim(0, 900)
ax.set_ylim(-25, 5)
ax.set_xticks_delta(300)
ax.set_xlabel('$f$', 'Hz')
ax.text(freqs[eod_i], power_db[eod_i] + 2, '$f_{\\rm EOD}$',
ha='center')
ax.text(freqs[rate_i], power_db[rate_i] + 2, '$r$',
ha='center')
ax.yscalebar(1.05, 0, 10, 'dB', ha='right')
def plot_response(ax, s, eodf, time1, stimulus1, contrast1, spikes1, contrast2, spikes2):
t0 = 0.3
t1 = 0.4
#print(len(spikes1), len(spikes2))
maxtrials = 8
trials = np.arange(maxtrials)
ax.show_spines('')
ax.eventplot(spikes1[2:2+maxtrials], lineoffsets=trials - maxtrials + 1,
linelength=0.8, linewidths=1, color=s.cell_color1)
ax.eventplot(spikes2[2:2+maxtrials], lineoffsets=trials - 2*maxtrials,
linelength=0.8, linewidths=1, color=s.cell_color2)
am = 1 + contrast1*stimulus1
eod = np.sin(2*np.pi*eodf*time1) * am
ax.plot(time1, 4*eod + 7, **s.lsEOD)
ax.plot(time1, 4*am + 7, **s.lsAM)
ax.set_xlim(t0, t1)
ax.set_ylim(-2*maxtrials - 0.5, 14)
ax.xscalebar(1, -0.05, 0.01, None, '10\\,ms', ha='right')
ax.text(t1 + 0.003, -0.5*maxtrials, f'${100*contrast1:.0f}$\\,\\%',
va='center', color=s.cell_color1)
ax.text(t1 + 0.003, -1.55*maxtrials, f'${100*contrast2:.0f}$\\,\\%',
va='center', color=s.cell_color2)
def plot_gain(ax, s, contrast1, freqs1, gain1, contrast2, freqs2, gain2, fcutoff):
ax.plot(freqs2, gain2, label=f'{100*contrast2:.0f}', **s.lsC2)
ax.plot(freqs1, gain1, label=f'{100*contrast1:.0f}', **s.lsC1)
ax.set_xlim(0, fcutoff)
ax.set_ylim(0, 800)
ax.set_xticks_delta(100)
ax.set_xlabel('$f$', 'Hz')
ax.set_ylabel(r'$|\chi_1|$', 'Hz')
def plot_colorbar(ax, pc, dc=None):
cax = ax.inset_axes([1.04, 0, 0.05, 1])
cax.set_spines_outward('lrbt', 0)
cb = cax.get_figure().colorbar(pc, cax=cax, label=r'$|\chi_2|$ [kHz]')
cb.outline.set_color('none')
cb.outline.set_linewidth(0)
if dc is not None:
cax.set_yticks_delta(dc)
def plot_chi2(ax, s, contrast, freqs, chi2, fcutoff, vmax):
ax.set_aspect('equal')
if vmax is None:
vmax = np.quantile(1e-3*chi2, 0.99)
pc = ax.pcolormesh(freqs, freqs, 1e-3*chi2, vmin=0, vmax=vmax,
cmap='viridis', rasterized=True, zorder=10)
ax.set_xlim(0, fcutoff)
ax.set_ylim(0, fcutoff)
df = 100 if fcutoff == 300 else 50
ax.set_xticks_delta(df)
ax.set_yticks_delta(df)
ax.set_xlabel('$f_1$', 'Hz')
ax.set_ylabel('$f_2$', 'Hz')
return pc
def plot_diagonals(ax, s, fbase, contrast1, freqs1, chi21, contrast2, freqs2, chi22, fcutoff):
diags = []
nlis = []
nlips = []
nlifs = []
for contrast, freqs, chi2 in [[contrast1, freqs1, chi21], [contrast2, freqs2, chi22]]:
dfreqs, diag = diag_projection(freqs, chi2, 2*fcutoff)
diags.append([dfreqs, diag])
nli, nlif = peakedness(dfreqs, diag, fbase, median=False)
nlip = diag[np.argmin(np.abs(dfreqs - nlif))]
nlis.append(nli)
nlips.append(nlip)
nlifs.append(nlif)
print(f' SI at {100*contrast:.1f}% contrast: {nli:.2f}')
ax.axvline(fbase, **s.lsGrid)
ax.plot(diags[1][0], 1e-3*diags[1][1], **s.lsC2)
ax.plot(diags[0][0], 1e-3*diags[0][1], **s.lsC1)
ax.plot(nlifs[1], 1e-3*nlips[1], **s.psC2)
ax.plot(nlifs[0], 1e-3*nlips[0], **s.psC1)
ax.set_xlim(0, 2*fcutoff)
ax.set_ylim(0, 4.2)
ax.set_xticks_delta(300)
ax.set_yticks_delta(1)
ax.set_xlabel('$f_1 + f_2$', 'Hz')
#ax.set_ylabel(r'$|\chi_2|$', 'kHz')
ax.text(nlifs[1] - 50, 1e-3*nlips[1], f'{100*contrast2:.0f}\\%',
ha='right')
ax.text(nlifs[1] + 70, 1e-3*nlips[1], f'SI={nlis[1]:.1f}')
ax.text(nlifs[0] - 50, 1e-3*nlips[0], f'{100*contrast1:.0f}\\%',
ha='right')
ax.text(nlifs[0] + 70, 1e-3*nlips[0], f'SI={nlis[0]:.1f}')
ax.text(fbase, 4.3, '$r$', ha='center')
if __name__ == '__main__':
print('Example P-unit:', cell_name)
eodf, rate, cv, isis, pdf, freqs, prr = load_baseline(results_path, cell_name)
print(f' baseline firing rate: {rate:.0f}Hz')
print(f' baseline firing CV : {cv:.2f}')
contrast1, time1, stimulus1, spikes1 = load_noise(data_path, cell_name, run1)
contrast2, time2, stimulus2, spikes2 = load_noise(data_path, cell_name, run2)
fcutoff1, contrast1, freqs1, gain1, chi21 = load_spectra(results_path, cell_name, run1)
fcutoff2, contrast2, freqs2, gain2, chi22 = load_spectra(results_path, cell_name, run2)
s = plot_style()
s.cell_color1 = s.punit_color1
s.cell_color2 = s.punit_color2
s.lsC1 = s.lsP1
s.lsC2 = s.lsP2
s.psC1 = s.psP1
s.psC2 = s.psP2
fig, (ax1, ax2, ax3) = plt.subplots(3, 1, height_ratios=[3, 0, 3, 0.5, 3],
cmsize=(s.plot_width, 0.8*s.plot_width))
fig.subplots_adjust(leftm=8, rightm=9, topm=2, bottomm=4,
wspace=0.4, hspace=0.5)
axi, axp, axr = ax1.subplots(1, 3, width_ratios=[2, 3, 0, 10])
axg, axc1, axc2, axd = ax2.subplots(1, 4, wspace=0.4)
axg = axg.subplots(1, 1, width_ratios=[1, 0.1])
axd = axd.subplots(1, 1, width_ratios=[0.2, 1])
axs = ax3.subplots(1, 4, wspace=0.4)
plot_isih(axi, s, rate, cv, isis, pdf)
plot_response_spectrum(axp, s, eodf, rate, freqs, prr)
plot_response(axr, s, eodf, time1, stimulus1, contrast1, spikes1,
contrast2, spikes2)
plot_gain(axg, s, contrast1, freqs1, gain1,
contrast2, freqs2, gain2, fcutoff1)
pc = plot_chi2(axc1, s, contrast2, freqs2, chi22, fcutoff2, 4)
axc1.plot([0, fcutoff2], [0, fcutoff2], zorder=20, **s.lsDiag)
axc1.set_title(f'$c$={100*contrast2:g}\\,\\%',
fontsize='medium', color=s.cell_color2)
pc = plot_chi2(axc2, s, contrast1, freqs1, chi21, fcutoff1, 4)
axc2.set_title(f'$c$={100*contrast1:g}\\,\\%',
fontsize='medium', color=s.cell_color1)
axc2.plot([0, fcutoff1], [0, fcutoff1], zorder=20, **s.lsDiag)
plot_colorbar(axc2, pc)
plot_diagonals(axd, s, rate, contrast1, freqs1, chi21,
contrast2, freqs2, chi22, fcutoff1)
fig.common_yticks(axc1, axc2)
fig.tag([axi, axp, axr], xoffs=-3, yoffs=-1)
fig.tag([axg, axc1, axc2, axd], xoffs=-3, yoffs=2)
print('Additional example cells:')
example_cells = [
['2021-06-18-ae-invivo-1', 3], # 98Hz, 1%, ok
['2012-03-30-ah', 2], # 177Hz, 2.5%, 2.0, nice
##['2012-07-03-ak', 0], # 120Hz, 2.5%, 1.8, broader
##['2012-12-20-ac', 0], # 213Hz, 2.5%, 2.1, ok
#['2017-07-18-ai-invivo-1', 1], # 78Hz, 5%, 2.3, weak
##['2019-06-28-ae', 0], # 477Hz, 10%, 2.6, weak
##['2020-10-27-aa-invivo-1', 4], # 259Hz, 0.5%, 2.0, ok
##['2020-10-27-ae-invivo-1', 4], # 375Hz, 0.5%, 4.3, nice, additional low freq line
###['2020-10-27-ag-invivo-1', 2], # 405Hz, 5%, 3.9, strong, is already the example
##['2021-08-03-ab-invivo-1', 1], # 140Hz, 0.5%, ok
['2020-10-29-ag-invivo-1', 2], # 164Hz, 5%, 1.6, no diagonal
##['2010-08-31-ag', 1], # 269Hz, 5%, no diagonal
['2018-08-24-ak', 1], # 145Hz, 5%, no diagonal
##['2018-08-29-af', 1], # 383Hz, 5%, no diagonal
]
for k, (cell, run) in enumerate(example_cells):
eodf, rate, cv, _, _, _, _ = load_baseline(results_path, cell)
fcutoff, contrast, freqs, gain, chi2 = load_spectra(results_path, cell, run)
dfreqs, diag = diag_projection(freqs, chi2, 2*fcutoff)
nli, nlif = peakedness(dfreqs, diag, rate, median=False)
print(f' {cell:<22s}: run={run:2d}, fbase={rate:3.0f}Hz, CV={cv:.2f}, SI={nli:3.1f}')
pc = plot_chi2(axs[k], s, contrast, freqs, chi2, fcutoff, 1.3)
axs[k].set_title(f'$r={rate:.0f}$Hz, CV$_{{\\rm base}}$={cv:.2f}', fontsize='medium')
axs[k].text(0.95, 0.9, f'SI($r$)={nli:.1f}', ha='right', zorder=50,
color='white', fontsize='medium',
transform=axs[k].transAxes)
plot_colorbar(axs[-1], pc)
fig.common_yticks(axs)
fig.tag(axs, xoffs=-3, yoffs=2)
fig.savefig()

349
regimes.py Normal file
View File

@ -0,0 +1,349 @@
import os
import numpy as np
from scipy.stats import linregress
import matplotlib.pyplot as plt
from numba import jit
from thunderlab.tabledata import TableData
from plotstyle import plot_style, lighter, darker
def load_models(file):
""" Load model parameter from csv file.
Parameters
----------
file: string
Name of file with model parameters.
Returns
-------
parameters: list of dict
For each cell a dictionary with model parameters.
"""
parameters = []
with open(file, 'r') as file:
header_line = file.readline()
header_parts = header_line.strip().split(",")
keys = header_parts
for line in file:
line_parts = line.strip().split(",")
parameter = {}
for i in range(len(keys)):
parameter[keys[i]] = float(line_parts[i]) if i > 0 else line_parts[i]
parameters.append(parameter)
return parameters
def cell_parameters(parameters, cell_name):
for params in parameters:
if params['cell'] == cell_name:
return params
print('cell', cell_name, 'not found!')
exit()
return None
@jit(nopython=True)
def simulate(stimulus, deltat=0.00005, v_zero=0.0, a_zero=2.0,
threshold=1.0, v_base=0.0, delta_a=0.08, tau_a=0.1,
v_offset=-10.0, mem_tau=0.015, noise_strength=0.05,
input_scaling=60.0, dend_tau=0.001, ref_period=0.001):
""" Simulate a P-unit.
Returns
-------
spike_times: 1-D array
Simulated spike times in seconds.
"""
# initial conditions:
v_dend = stimulus[0]
v_mem = v_zero
adapt = a_zero
# prepare noise:
noise = np.random.randn(len(stimulus))
noise *= noise_strength / np.sqrt(deltat)
# rectify stimulus array:
stimulus = stimulus.copy()
stimulus[stimulus < 0.0] = 0.0
# integrate:
spike_times = []
for i in range(len(stimulus)):
v_dend += (-v_dend + stimulus[i]) / dend_tau * deltat
v_mem += (v_base - v_mem + v_offset + (
v_dend * input_scaling) - adapt + noise[i]) / mem_tau * deltat
adapt += -adapt / tau_a * deltat
# refractory period:
if len(spike_times) > 0 and (deltat * i) - spike_times[-1] < ref_period + deltat/2:
v_mem = v_base
# threshold crossing:
if v_mem > threshold:
v_mem = v_base
spike_times.append(i * deltat)
adapt += delta_a / tau_a
return np.array(spike_times)
def punit_spikes(parameter, alpha, beatf1, beatf2, tmax, trials):
tini = 0.2
model_params = dict(parameter)
cell = model_params.pop('cell')
eodf0 = model_params.pop('EODf')
time = np.arange(-tini, tmax, model_params['deltat'])
stimulus = np.sin(2*np.pi*eodf0*time)
stimulus += alpha*np.sin(2*np.pi*(eodf0 + beatf1)*time)
stimulus += alpha*np.sin(2*np.pi*(eodf0 + beatf2)*time)
spikes = []
for i in range(trials):
model_params['v_zero'] = np.random.rand()
model_params['a_zero'] += 0.02*parameter['a_zero']*np.random.randn()
spiket = simulate(stimulus, **model_params)
spikes.append(spiket[spiket > tini] - tini)
return spikes
def plot_am(ax, s, alpha, beatf1, beatf2, tmax):
time = np.arange(0, tmax, 0.0001)
am = alpha*np.sin(2*np.pi*beatf1*time)
am += alpha*np.sin(2*np.pi*beatf2*time)
ax.show_spines('l')
ax.plot(1000*time, -100*am, **s.lsStim)
ax.set_xlim(0, 1000*tmax)
ax.set_ylim(-50, 50)
#ax.set_xlabel('Time', 'ms')
ax.set_ylabel('AM', r'\%')
ax.text(1, 1.2, f'Contrast = {100*alpha:g}\\,\\%',
transform=ax.transAxes, ha='right')
def plot_raster(ax, s, spikes, tmax):
spikes_ms = [1000*s[s<tmax] for s in spikes[:16]]
ax.show_spines('')
ax.eventplot(spikes_ms, linelengths=0.9, **s.lsRaster)
ax.set_xlim(0, 1000*tmax)
#ax.set_xlabel('Time', 'ms')
#ax.set_ylabel('Trials')
def compute_power(spikes, nfft, dt):
psds = []
time = np.arange(nfft + 1)*dt
tmax = nfft*dt
rates = []
cvs = []
for s in spikes:
rates.append(len(s)/tmax)
isis = np.diff(s)
cvs.append(np.std(isis)/np.mean(isis))
b, _ = np.histogram(s, time)
fourier = np.fft.rfft(b - np.mean(b))
psds.append(np.abs(fourier)**2)
#psds.append(fourier)
freqs = np.fft.rfftfreq(nfft, dt)
#print('mean rate', np.mean(rates))
#print('CV', np.mean(cvs))
return freqs, np.mean(psds, 0)
#return freqs, np.abs(np.mean(psds, 0))**2/dt
def decibel(x):
return 10*np.log10(x/1e8)
def plot_psd(ax, s, spikes, nfft, dt, beatf1, beatf2):
offs = 3
freqs, psd = compute_power(spikes, nfft, dt)
psd /= freqs[1]
ax.plot(freqs, decibel(psd), **s.lsPower)
ax.plot(beatf2, decibel(peak_ampl(freqs, psd, beatf2)) + offs,
label=r'$f_{\rm base}$', clip_on=False, **s.psF0)
ax.plot(beatf1, decibel(peak_ampl(freqs, psd, beatf1)) + offs,
label=r'$\Delta f_1$', clip_on=False, **s.psF01)
ax.plot(beatf2, decibel(peak_ampl(freqs, psd, beatf2)) + offs + 4.5,
label=r'$\Delta f_2$', clip_on=False, **s.psF02)
ax.plot(beatf2 - beatf1, decibel(peak_ampl(freqs, psd, beatf2 - beatf1)) + offs,
label=r'$\Delta f_2 - \Delta f_1$', clip_on=False, **s.psF01_2)
ax.plot(beatf1 + beatf2, decibel(peak_ampl(freqs, psd, beatf1 + beatf2)) + offs,
label=r'$\Delta f_1 + \Delta f_2$', clip_on=False, **s.psF012)
ax.set_xlim(0, 300)
ax.set_ylim(-40, 0)
ax.set_xlabel('Frequency', 'Hz')
ax.set_ylabel('Power [dB]')
def plot_example(axs, axr, axp, s, cell, alpha, beatf1, beatf2, nfft, trials):
dt = 0.0001
tmax = nfft*dt
t1 = 0.1
spikes = punit_spikes(cell, alpha, beatf1, beatf2, tmax, trials)
plot_am(axs, s, alpha, beatf1, beatf2, t1)
plot_raster(axr, s, spikes, t1)
plot_psd(axp, s, spikes, nfft, dt, beatf1, beatf2)
def peak_ampl(freqs, psd, f):
df = 2
psd_snippet = psd[(freqs > f - df) & (freqs < f + df)]
return np.max(psd_snippet)
def compute_peaks(name, cell, alpha_max, beatf1, beatf2, nfft, trials):
file_name = f'{name}-contrastpeaks.csv'
if os.path.exists(file_name):
data = TableData(file_name)
return data
dt = 0.0001
tmax = nfft*dt
alphas = np.linspace(0, alpha_max, 200)
ampl_f1 = np.zeros(len(alphas))
ampl_f2 = np.zeros(len(alphas))
ampl_sum = np.zeros(len(alphas))
ampl_diff = np.zeros(len(alphas))
for k, alpha in enumerate(alphas):
print(alpha)
spikes = punit_spikes(cell, alpha, beatf1, beatf2, tmax, trials)
freqs, psd = compute_power(spikes, nfft, dt)
ampl_f1[k] = peak_ampl(freqs, psd, beatf1)
ampl_f2[k] = peak_ampl(freqs, psd, beatf2)
ampl_sum[k] = peak_ampl(freqs, psd, beatf1 + beatf2)
ampl_diff[k] = peak_ampl(freqs, psd, beatf2 - beatf1)
data = TableData()
data.append('contrast', '%', '%.1f', 100*alphas)
data.append('f1', 'Hz', '%g', ampl_f1)
data.append('f2', 'Hz', '%g', ampl_f2)
data.append('f1+f2', 'Hz', '%g', ampl_sum)
data.append('f2-f1', 'Hz', '%g', ampl_diff)
data.write(file_name)
return data
def amplitude(power):
power -= power[0]
power[power<0] = 0
return np.sqrt(power)
def amplitude_linearfit(contrast, power, max_contrast):
power -= power[0]
power[power<0] = 0
ampl = np.sqrt(power)
a = ampl[contrast <= max_contrast]
c = contrast[contrast <= max_contrast]
r = linregress(c, a)
return r.intercept + r.slope*contrast
def amplitude_squarefit(contrast, power, max_contrast):
power -= power[0]
power[power<0] = 0
ampl = np.sqrt(power)
a = np.sqrt(ampl[contrast <= max_contrast])
c = contrast[contrast <= max_contrast]
r = linregress(c, a)
return (r.intercept + r.slope*contrast)**2
def plot_peaks(ax, s, data, alphas):
contrast = data[:, 'contrast']
ax.plot(contrast, amplitude_linearfit(contrast, data[:, 'f1'], 4), **s.lsF01m)
ax.plot(contrast, amplitude_linearfit(contrast, data[:, 'f2'], 2), **s.lsF02m)
ax.plot(contrast, amplitude_squarefit(contrast, data[:, 'f1+f2'], 4), **s.lsF012m)
ax.plot(contrast, amplitude_squarefit(contrast, data[:, 'f2-f1'], 4), **s.lsF01_2m)
ax.plot(contrast, amplitude(data[:, 'f1']), **s.lsF01)
ax.plot(contrast, amplitude(data[:, 'f2']), **s.lsF02)
ax.plot(contrast, amplitude(data[:, 'f1+f2']), **s.lsF012)
ax.plot(contrast, amplitude(data[:, 'f2-f1']), **s.lsF01_2)
for alpha, tag in zip(alphas, ['A', 'B', 'C', 'D']):
contrast = 100*alpha
ax.plot(contrast, 630, 'vk', ms=4, clip_on=False)
ax.text(contrast, 660, tag, ha='center')
#ax.axvline(contrast, **s.lsGrid)
#ax.text(contrast, 630, tag, ha='center')
ax.axvline(1.5, **s.lsLine)
ax.axvline(4, **s.lsLine)
yoffs = 340
ax.text(1.5/2, yoffs, 'linear\nregime',
ha='center', va='center')
ax.text((1.5 + 4)/2, yoffs, 'weakly\nnonlinear\nregime',
ha='center', va='center')
ax.text(10, yoffs, 'strongly\nnonlinear\nregime',
ha='center', va='center')
ax.set_xlim(0, 16.5)
ax.set_ylim(0, 600)
ax.set_xticks_delta(5)
ax.set_yticks_delta(300)
ax.set_xlabel('Contrast', r'\%')
ax.set_ylabel('Amplitude', 'Hz')
if __name__ == '__main__':
parameters = load_models('models.csv')
cell_name = '2013-01-08-aa-invivo-1' # 138Hz, CV=0.26: perfect!
beatf1 = 40
beatf2 = 138
# cell_name = '2012-07-03-ak-invivo-1' # 128Hz, CV=0.24
# cell_name = '2018-05-08-ae-invivo-1' # 142Hz, CV=0.48
"""
parameters = load_models('models_big_fit_d_right.csv')
cell_name = '2013-01-08-aa-invivo-1' # 131Hz, CV=0.04: wrong!
beatf1 = 30
beatf2 = 132
"""
cell = cell_parameters(parameters, cell_name)
for k in cell:
print(k, cell[k])
s = plot_style()
s.lwmid = 1.0
s.lwthick = 1.6
s.lsStim = dict(color='gray', lw=s.lwmid)
s.lsRaster = dict(color='black', lw=s.lwthin)
s.lsPower = dict(color='gray', lw=s.lwmid)
s.lsF0 = dict(color='blue', lw=s.lwthick)
s.lsF01 = dict(color='green', lw=s.lwthick)
s.lsF02 = dict(color='purple', lw=s.lwthick)
s.lsF012 = dict(color='orange', lw=s.lwthick)
s.lsF01_2 = dict(color='red', lw=s.lwthick)
s.lsF0m = dict(color=lighter('blue', 0.5), lw=s.lwthin)
s.lsF01m = dict(color=lighter('green', 0.6), lw=s.lwthin)
s.lsF02m = dict(color=lighter('purple', 0.5), lw=s.lwthin)
s.lsF012m = dict(color=darker('orange', 0.9), lw=s.lwthin)
s.lsF01_2m = dict(color=darker('red', 0.9), lw=s.lwthin)
s.psF0 = dict(color='blue', marker='o', linestyle='none', markersize=5, mec='none', mew=0)
s.psF01 = dict(color='green', marker='o', linestyle='none', markersize=5, mec='none', mew=0)
s.psF02 = dict(color='purple', marker='o', linestyle='none', markersize=5, mec='none', mew=0)
s.psF012 = dict(color='orange', marker='o', linestyle='none', markersize=5, mec='none', mew=0)
s.psF01_2 = dict(color='red', marker='o', linestyle='none', markersize=5, mec='none', mew=0)
nfft = 2**18
fig, axs = plt.subplots(5, 4, cmsize=(s.plot_width, 0.8*s.plot_width),
height_ratios=[1, 1.5, 2, 1.5, 4])
fig.subplots_adjust(leftm=8, rightm=2, topm=2, bottomm=3.5,
wspace=0.3, hspace=0.3)
ax0 = fig.merge(axs[3, :])
ax0.set_visible(False)
axa = fig.merge(axs[4, :])
fig.show_spines('lb')
alphas = [0.01, 0.03, 0.05, 0.16]
#alphas = [0.002, 0.01, 0.05, 0.1]
for c, alpha in enumerate(alphas):
plot_example(axs[0, c], axs[1, c], axs[2, c], s, cell,
alpha, beatf1, beatf2, nfft, 100)
axs[1, 0].xscalebar(1, -0.1, 30, 'ms', ha='right')
axs[2, 0].legend(loc='center left', bbox_to_anchor=(0, -0.8),
ncol=5, columnspacing=2)
data = compute_peaks(cell_name, cell, 0.2, beatf1, beatf2, nfft, 1000)
plot_peaks(axa, s, data, alphas)
fig.common_yspines(axs[0, :])
fig.common_yticks(axs[2, :])
#fig.common_xlabels(axs[2, :])
fig.tag(axs[0, :], xoffs=-2, yoffs=1.6)
fig.tag(axa)
fig.savefig()