changed example punit

This commit is contained in:
Jan Benda 2025-05-20 00:12:23 +02:00
parent 2e3d373100
commit ee2b8f98b7
18 changed files with 232 additions and 149 deletions

View File

@ -9,7 +9,7 @@ TXTFILE=$(TEXBASE).txt
PDFFIGURES=$(shell sed -n -e '/^[^%].*includegraphics/{s/^.*includegraphics.*{\([^}]*\)}.*/\1.pdf/;p}' $(TEXFILE)) PDFFIGURES=$(shell sed -n -e '/^[^%].*includegraphics/{s/^.*includegraphics.*{\([^}]*\)}.*/\1.pdf/;p}' $(TEXFILE))
PT=$(wildcard *.py) PT=$(wildcard *.py)
PYTHONFILES=$(filter-out plotstyle.py spectral.py, $(PT)) PYTHONFILES=$(filter-out plotstyle.py spectral.py examplecells.py, $(PT))
PYTHONPDFFILES=$(PYTHONFILES:.py=.pdf) PYTHONPDFFILES=$(PYTHONFILES:.py=.pdf)
REVISION=e3814a1be539f9424c17b7bd7ef45a8826a9f1e2 REVISION=e3814a1be539f9424c17b7bd7ef45a8826a9f1e2

View File

@ -7,9 +7,8 @@ from punitexamplecell import load_baseline, load_noise, load_spectra
from punitexamplecell import plot_colorbar from punitexamplecell import plot_colorbar
cell_name = '2012-05-15-ac' example_cell = [['2012-05-15-ac', 3],
run1 = 3 # 4 ['2012-05-15-ac', 1]]
run2 = 1
example_cells = [ example_cells = [
['2010-11-26-an', 0], ['2010-11-26-an', 0],
@ -37,9 +36,10 @@ def plot_isih2(ax, s, rate, cv, isis, pdf):
ax.show_spines('b') ax.show_spines('b')
ax.fill_between(1000*isis, pdf, facecolor=s.cell_color1) ax.fill_between(1000*isis, pdf, facecolor=s.cell_color1)
ax.set_xlim(0, 20) ax.set_xlim(0, 20)
ax.set_xticks_delta(5) #ax.set_xticks_delta(5)
ax.set_xticks_blank() #ax.set_xticks_blank()
#ax.set_xlabel('ISI', 'ms') #ax.set_xticks_fixed([0, 5, 10, 15, 20], ['0', '', '', '', '20\\,ms'])
ax.set_xticks_fixed([0, 5, 10, 15, 20], ['0', '5', '10', '15', '20\\,ms'])
ax.text(1, 1.1, f'CV$_{{\\rm base}}$={cv:.2f}', ha='right', ax.text(1, 1.1, f'CV$_{{\\rm base}}$={cv:.2f}', ha='right',
transform=ax.transAxes) transform=ax.transAxes)
ax.text(1, 0.6, f'$r={rate:.0f}$Hz', ha='right', transform=ax.transAxes) ax.text(1, 0.6, f'$r={rate:.0f}$Hz', ha='right', transform=ax.transAxes)
@ -154,26 +154,33 @@ def plot_diagonals(ax, s, fbase, contrast1, freqs1, chi21, contrast2, freqs2, ch
if __name__ == '__main__': if __name__ == '__main__':
""" """
from thunderlab.tabledata import TableData from thunderlab.tabledata import TableData
data = TableData('Apteronotus_leptorhynchus-Ampullary-data.csv') data = TableData('data/Apteronotus_leptorhynchus-Ampullary-data.csv')
data = data[(data('fcutoff') > 140) & (data('fcutoff') < 160), :] data = data[(data['fcutoff'] > 140) & (data['fcutoff'] < 160), :]
data = data[(data('nli') > 2) & (data('nli') < 2.5), :] data = data[(data['nli'] > 2) & (data['nli'] < 2.5), :]
data = data[(data('respmod2') > 20) & (data('respmod2') < 100), :] data = data[(data['respmod2'] > 20) & (data['respmod2'] < 100), :]
data = data[(data('cvbase') > 0.05) & (data('cvbase') < 0.2), :] data = data[(data['cvbase'] > 0.05) & (data['cvbase'] < 0.2), :]
data = data[(data('ratebase') > 100) & (data('ratebase') < 180), :] data = data[(data['ratebase'] > 100) & (data['ratebase'] < 180), :]
for k in range(data.rows()): for k in range(data.rows()):
print(f'{data[k, "cell"]:<22s} s{data[k, "stimindex"]:02.0f}: {100*data[k, "contrast"]:3g}%, {data[k, "respmod2"]:3.0f}Hz, nli={data[k, "nli"]:5.2f}') print(f'{data[k, "cell"]:<22s} s{data[k, "stimindex"]:02.0f}: '
f'{100*data[k, "contrast"]:3g}%, {data[k, "respmod2"]:3.0f}Hz, '
f'nli={data[k, "nli"]:5.2f}')
print() print()
#exit() exit()
""" """
cell_name = example_cell[0][0]
print('Example Ampullary cell:', cell_name) print('Example Ampullary cell:', cell_name)
eodf, rate, cv, isis, pdf, freqs, prr = load_baseline(data_path, cell_name) eodf, rate, cv, isis, pdf, freqs, prr = load_baseline(data_path, cell_name)
print(f' baseline firing rate: {rate:.0f}Hz') print(f' baseline firing rate: {rate:.0f}Hz')
print(f' baseline firing CV : {cv:.2f}') print(f' baseline firing CV : {cv:.2f}')
contrast1, time1, stimulus1, spikes1 = load_noise(data_path, cell_name, run1) contrast1, time1, stimulus1, spikes1 = load_noise(data_path,
contrast2, time2, stimulus2, spikes2 = load_noise(data_path, cell_name, run2) *example_cell[0])
fcutoff1, contrast1, freqs1, gain1, chi21 = load_spectra(data_path, cell_name, run1) contrast2, time2, stimulus2, spikes2 = load_noise(data_path,
fcutoff2, contrast2, freqs2, gain2, chi22 = load_spectra(data_path, cell_name, run2) *example_cell[1])
fcutoff1, contrast1, freqs1, gain1, chi21 = load_spectra(data_path,
*example_cell[0])
fcutoff2, contrast2, freqs2, gain2, chi22 = load_spectra(data_path,
*example_cell[1])
s = plot_style() s = plot_style()
s.cell_color1 = s.ampul_color1 s.cell_color1 = s.ampul_color1
@ -183,7 +190,7 @@ if __name__ == '__main__':
s.psC1 = s.psA1 s.psC1 = s.psA1
s.psC2 = s.psA2 s.psC2 = s.psA2
fig, (ax1, ax2, ax3) = \ fig, (ax1, ax2, ax3) = \
plt.subplots(3, 1, height_ratios=[3, 0, 3, 0.2, 4.5], plt.subplots(3, 1, height_ratios=[3, 0, 3, 0.2, 4.7],
cmsize=(s.plot_width, 0.85*s.plot_width)) cmsize=(s.plot_width, 0.85*s.plot_width))
fig.subplots_adjust(leftm=8, rightm=9, topm=2, bottomm=4, fig.subplots_adjust(leftm=8, rightm=9, topm=2, bottomm=4,
wspace=0.4, hspace=0.4) wspace=0.4, hspace=0.4)
@ -191,7 +198,7 @@ if __name__ == '__main__':
axg, axc1, axc2, axd = ax2.subplots(1, 4, wspace=0.4) axg, axc1, axc2, axd = ax2.subplots(1, 4, wspace=0.4)
axg = axg.subplots(1, 1, width_ratios=[1, 0.1]) axg = axg.subplots(1, 1, width_ratios=[1, 0.1])
axd = axd.subplots(1, 1, width_ratios=[0.2, 1]) axd = axd.subplots(1, 1, width_ratios=[0.2, 1])
axs = ax3.subplots(2, 4, wspace=0.4, hspace=0.2, height_ratios=[1, 4]) axs = ax3.subplots(2, 4, wspace=0.4, hspace=0.35, height_ratios=[1, 4])
plot_isih(axi, s, rate, cv, isis, pdf) plot_isih(axi, s, rate, cv, isis, pdf)
plot_response_spectrum(axp, s, eodf, rate, freqs, prr) plot_response_spectrum(axp, s, eodf, rate, freqs, prr)
@ -232,8 +239,8 @@ if __name__ == '__main__':
axs[0, 0].text(0, 1.6, 'Ampullary cells:', transform=axs[0, 0].transAxes, axs[0, 0].text(0, 1.6, 'Ampullary cells:', transform=axs[0, 0].transAxes,
color=s.cell_color1, color=s.cell_color1,
fontsize='large') fontsize='large')
axs[0, -1].text(0.97, -0.45, '5\\,ms', ha='right', #axs[0, -1].text(0.97, -0.45, '5\\,ms', ha='right',
transform=axs[0, -1].transAxes) # transform=axs[0, -1].transAxes)
plot_colorbar(axs[1, -1], pc, 0.4) plot_colorbar(axs[1, -1], pc, 0.4)
fig.common_yticks(axs[1, :]) fig.common_yticks(axs[1, :])
fig.tag([axs[0, :]], xoffs=-3, yoffs=1) fig.tag([axs[0, :]], xoffs=-3, yoffs=1)

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

View File

@ -10,41 +10,30 @@ from plotstyle import plot_style, lighter, significance_str
data_path = Path('data') data_path = Path('data')
punit_example = [['2020-10-27-ag-invivo-1', 0], from punitexamplecell import example_cell as punit_example
['2020-10-27-ag-invivo-1', 1]]
ampul_example = [['2012-05-15-ac', 3],
['2012-05-15-ac', 1]]
from punitexamplecell import example_cells as punit_examples from punitexamplecell import example_cells as punit_examples
from ampullaryexamplecell import example_cell as ampul_example
from ampullaryexamplecell import example_cells as ampul_examples from ampullaryexamplecell import example_cells as ampul_examples
def plot_corr(ax, data, xcol, ycol, zcol, zmin, zmax, xpdfmax, cmap, color, def plot_corr(ax, data, xcol, ycol, zcol, zmin, zmax, xpdfmax, cmap, color,
nli_thresh, example=[], examples=[]): nli_thresh, example=[], examples=[]):
ax.axhline(nli_thresh, color='k', ls=':', lw=0.5) ax.axhline(nli_thresh, color='k', ls=':', lw=0.5)
"""
for c in np.unique(data('cell')):
xdata = data[data('cell') == c, xcol]
ydata = data[data('cell') == c, ycol]
contrasts = data[data('cell') == c, 'contrast']
idx = np.argsort(contrasts)
if len(idx) > 1:
ax.plot(xdata[idx], ydata[idx], '-k', alpha=0.2, zorder=10)
"""
xmax = ax.get_xlim()[1] xmax = ax.get_xlim()[1]
ymax = ax.get_ylim()[1] ymax = ax.get_ylim()[1]
mask = (data(xcol) < xmax) & (data(ycol) < ymax) mask = (data[xcol] < xmax) & (data[ycol] < ymax)
sc = ax.scatter(data[mask, xcol], data[mask, ycol], c=data[mask, zcol], sc = ax.scatter(data[mask, xcol], data[mask, ycol], c=data[mask, zcol],
s=4, marker='o', linewidth=0, edgecolors='none', s=4, marker='o', linewidth=0, edgecolors='none',
clip_on=False, cmap=cmap, vmin=zmin, vmax=zmax, zorder=20) clip_on=False, cmap=cmap, vmin=zmin, vmax=zmax, zorder=20)
if 'stimindex' in data: if 'stimindex' in data:
for cell, run in example: for cell, run in example:
mask = (data('cell') == cell) & (data('stimindex') == run) mask = (data['cell'] == cell) & (data['stimindex'] == run)
ax.scatter(data[mask, xcol], data[mask, ycol], c=data[mask, zcol], ax.scatter(data[mask, xcol], data[mask, ycol], c=data[mask, zcol],
s=6, marker='^', linewidth=0.5, edgecolors='black', s=6, marker='^', linewidth=0.5, edgecolors='black',
clip_on=False, cmap=cmap, vmin=zmin, vmax=zmax, clip_on=False, cmap=cmap, vmin=zmin, vmax=zmax,
zorder=20) zorder=20)
for cell, run in examples: for cell, run in examples:
mask = (data('cell') == cell) & (data('stimindex') == run) mask = (data['cell'] == cell) & (data['stimindex'] == run)
ax.scatter(data[mask, xcol], data[mask, ycol], c=data[mask, zcol], ax.scatter(data[mask, xcol], data[mask, ycol], c=data[mask, zcol],
s=5, marker='o', linewidth=0.5, edgecolors='black', s=5, marker='o', linewidth=0.5, edgecolors='black',
clip_on=False, cmap=cmap, vmin=zmin, vmax=zmax, clip_on=False, cmap=cmap, vmin=zmin, vmax=zmax,
@ -57,7 +46,7 @@ def plot_corr(ax, data, xcol, ycol, zcol, zmin, zmax, xpdfmax, cmap, color,
cb.outline.set_color('none') cb.outline.set_color('none')
cb.outline.set_linewidth(0) cb.outline.set_linewidth(0)
# pdf x-axis: # pdf x-axis:
kde = gaussian_kde(data(xcol), 0.02*xmax/np.std(data(xcol), ddof=1)) kde = gaussian_kde(data[xcol], 0.02*xmax/np.std(data[xcol], ddof=1))
xx = np.linspace(0, ax.get_xlim()[1], 400) xx = np.linspace(0, ax.get_xlim()[1], 400)
pdf = kde(xx) pdf = kde(xx)
xax = ax.inset_axes([0, 1.05, 1, 0.2]) xax = ax.inset_axes([0, 1.05, 1, 0.2])
@ -67,7 +56,7 @@ def plot_corr(ax, data, xcol, ycol, zcol, zmin, zmax, xpdfmax, cmap, color,
xax.set_ylim(bottom=0) xax.set_ylim(bottom=0)
xax.set_ylim(0, xpdfmax) xax.set_ylim(0, xpdfmax)
# pdf y-axis: # pdf y-axis:
kde = gaussian_kde(data(ycol), 0.02*ymax/np.std(data(ycol), ddof=1)) kde = gaussian_kde(data[ycol], 0.02*ymax/np.std(data[ycol], ddof=1))
xx = np.linspace(0, ax.get_ylim()[1], 400) xx = np.linspace(0, ax.get_ylim()[1], 400)
pdf = kde(xx) pdf = kde(xx)
yax = ax.inset_axes([1.05, 0, 0.2, 1]) yax = ax.inset_axes([1.05, 0, 0.2, 1])
@ -77,12 +66,12 @@ def plot_corr(ax, data, xcol, ycol, zcol, zmin, zmax, xpdfmax, cmap, color,
yax.set_xlim(left=0) yax.set_xlim(left=0)
# threshold: # threshold:
if 'cvbase' in xcol: if 'cvbase' in xcol:
ax.text(xmax, 0.4*ymax, f'{100*np.sum(data(ycol) > nli_thresh)/data.rows():.0f}\\%', ax.text(xmax, 0.4*ymax, f'{100*np.sum(data[ycol] > nli_thresh)/len(data):.0f}\\%',
ha='right', va='bottom', fontsize='small') ha='right', va='bottom', fontsize='small')
ax.text(xmax, 0.3, f'{100*np.sum(data(ycol) < nli_thresh)/data.rows():.0f}\\%', ax.text(xmax, 0.3, f'{100*np.sum(data[ycol] < nli_thresh)/len(data):.0f}\\%',
ha='right', va='center', fontsize='small') ha='right', va='center', fontsize='small')
# statistics: # statistics:
r, p = pearsonr(data(xcol), data(ycol)) r, p = pearsonr(data[xcol], data[ycol])
ax.text(1, 0.9, f'$R={r:.2f}$ **', ha='right', ax.text(1, 0.9, f'$R={r:.2f}$ **', ha='right',
transform=ax.transAxes, fontsize='small') transform=ax.transAxes, fontsize='small')
#ax.text(1, 0.77, f'{significance_str(p)}', ha='right', #ax.text(1, 0.77, f'{significance_str(p)}', ha='right',
@ -98,14 +87,14 @@ def nli_stats(title, data, column, nli_thresh):
print(title) print(title)
print(f' nli threshold: {nli_thresh:.1f}') print(f' nli threshold: {nli_thresh:.1f}')
nrecs = data.rows() nrecs = data.rows()
ncells = len(np.unique(data('cell'))) ncells = len(np.unique(data['cell']))
print(f' cells: {ncells}') print(f' cells: {ncells}')
print(f' recordings: {nrecs}') print(f' recordings: {nrecs}')
hcells = np.unique(data[data(column) > nli_thresh, 'cell']) hcells = np.unique(data[data(column) > nli_thresh, 'cell'])
print(f' high nli cells: n={len(hcells):3d}, {100*len(hcells)/ncells:4.1f}%') print(f' high nli cells: n={len(hcells):3d}, {100*len(hcells)/ncells:4.1f}%')
print(f' high nli recordings: n={np.sum(data(column) > nli_thresh):3d}, ' print(f' high nli recordings: n={np.sum(data(column) > nli_thresh):3d}, '
f'{100*np.sum(data(column) > nli_thresh)/nrecs:4.1f}%') f'{100*np.sum(data(column) > nli_thresh)/nrecs:4.1f}%')
nsegs = data('nsegs') nsegs = data['nsegs']
print(f' number of segments: {np.min(nsegs):4.0f} - {np.max(nsegs):4.0f}, median={np.median(nsegs):4.0f}, mean={np.mean(nsegs):4.0f}, std={np.std(nsegs):4.0f}') print(f' number of segments: {np.min(nsegs):4.0f} - {np.max(nsegs):4.0f}, median={np.median(nsegs):4.0f}, mean={np.mean(nsegs):4.0f}, std={np.std(nsegs):4.0f}')
@ -132,14 +121,18 @@ def plot_cvstim_nli_punit(ax, data, ycol, nli_thresh, color):
# 'coolwarm', color, nli_thresh, # 'coolwarm', color, nli_thresh,
# punit_example, punit_examples) # punit_example, punit_examples)
#cax.set_ylabel('Response mod.', 'Hz') #cax.set_ylabel('Response mod.', 'Hz')
#cax = plot_corr(ax, data, 'cvstim', ycol, 'cvbase', 0, 1.5, 2, cax = plot_corr(ax, data, 'cvstim', ycol, 'cvbase', 0, 1.5, 2,
# 'coolwarm', color, nli_thresh,
# punit_example, punit_examples)
#cax.set_ylabel('CV$_{\\rm base}$')
cax = plot_corr(ax, data, 'cvstim', ycol, 'ratebase', 50, 450, 2,
'coolwarm', color, nli_thresh, 'coolwarm', color, nli_thresh,
punit_example, punit_examples) punit_example, punit_examples)
cax.set_ylabel('$r$', 'Hz') cax.set_ylabel('CV$_{\\rm base}$')
#cax = plot_corr(ax, data, 'cvstim', ycol, 'ratebase', 50, 450, 2,
# 'coolwarm', color, nli_thresh,
# punit_example, punit_examples)
#cax.set_ylabel('$r$', 'Hz')
#cax = plot_corr(ax, data, 'cvstim', ycol, 'serialcorr1', -0.6, 0, 2,
# 'coolwarm', color, nli_thresh,
# punit_example, punit_examples)
#cax.set_ylabel('$\\rho_1$')
def plot_mod_nli_punit(ax, data, ycol, nli_thresh, color): def plot_mod_nli_punit(ax, data, ycol, nli_thresh, color):
@ -178,16 +171,16 @@ def plot_cvstim_nli_ampul(ax, data, ycol, nli_thresh, color):
# 'coolwarm', color, nli_thresh, # 'coolwarm', color, nli_thresh,
# ampul_example, ampul_examples) # ampul_example, ampul_examples)
#cax.set_ylabel('Response mod.', 'Hz') #cax.set_ylabel('Response mod.', 'Hz')
#cax = plot_corr(ax, data, 'cvstim', ycol, 'cvbase', 0, 0.2, 6, cax = plot_corr(ax, data, 'cvstim', ycol, 'cvbase', 0, 0.2, 6,
# 'coolwarm', color, nli_thresh,
# ampul_example, ampul_examples)
#cax.set_ylabel('CV$_{\\rm base}$')
#cax.set_yticks_delta(0.1)
cax = plot_corr(ax, data, 'cvstim', ycol, 'ratebase', 90, 180, 6,
'coolwarm', color, nli_thresh, 'coolwarm', color, nli_thresh,
ampul_example, ampul_examples) ampul_example, ampul_examples)
cax.set_ylabel('$r$', 'Hz') cax.set_ylabel('CV$_{\\rm base}$')
cax.set_yticks_delta(30) cax.set_yticks_delta(0.1)
#cax = plot_corr(ax, data, 'cvstim', ycol, 'ratebase', 90, 180, 6,
# 'coolwarm', color, nli_thresh,
# ampul_example, ampul_examples)
#cax.set_ylabel('$r$', 'Hz')
#cax.set_yticks_delta(30)
def plot_mod_nli_ampul(ax, data, ycol, nli_thresh, color): def plot_mod_nli_ampul(ax, data, ycol, nli_thresh, color):
@ -208,7 +201,7 @@ if __name__ == '__main__':
punit_model = TableData(data_path / punit_model = TableData(data_path /
'Apteronotus_leptorhynchus-Punit-models.csv', 'Apteronotus_leptorhynchus-Punit-models.csv',
sep=';') sep=';')
punit_model = punit_model[punit_model('contrast') > 1e-6, :] punit_model = punit_model[punit_model['contrast'] > 1e-6, :]
punit_data = TableData(data_path / punit_data = TableData(data_path /
'Apteronotus_leptorhynchus-Punit-data.csv', 'Apteronotus_leptorhynchus-Punit-data.csv',
sep=';') sep=';')
@ -217,23 +210,23 @@ if __name__ == '__main__':
sep=';') sep=';')
nli_thresh = 1.8 nli_thresh = 1.8
u, p = mannwhitneyu(punit_model('cvbase'), punit_data('cvbase')) u, p = mannwhitneyu(punit_model['cvbase'], punit_data['cvbase'])
print('CV differs between P-unit models and data:') print('CV differs between P-unit models and data:')
print(f' U={u:g}, p={p:g}') print(f' U={u:g}, p={p:.2g}')
print(f' median model: {np.median(punit_model("cvbase")):.2f}') print(f' median model: {np.median(punit_model["cvbase"]):.2f}')
print(f' median data: {np.median(punit_data("cvbase")):.2f}') print(f' median data: {np.median(punit_data["cvbase"]):.2f}')
print() print()
u, p = mannwhitneyu(punit_model('respmod2'), punit_data('respmod2')) u, p = mannwhitneyu(punit_model['respmod2'], punit_data['respmod2'])
print('Response modulation differs between P-unit models and data:') print('Response modulation differs between P-unit models and data:')
print(f' U={u:g}, p={p:g}') print(f' U={u:g}, p={p:.2g}')
print(f' median model: {np.median(punit_model("respmod2")):.2f}') print(f' median model: {np.median(punit_model["respmod2"]):.2f}')
print(f' median data: {np.median(punit_data("respmod2")):.2f}') print(f' median data: {np.median(punit_data["respmod2"]):.2f}')
print() print()
u, p = mannwhitneyu(punit_model('dnli100'), punit_data('nli')) u, p = mannwhitneyu(punit_model['dnli100'], punit_data['nli'])
print('NLI does not differ between P-unit models and data:') print('NLI does not differ between P-unit models and data:')
print(f' U={u:g}, p={p:g}') print(f' U={u:g}, p={p:.2g}')
print(f' median model: {np.median(punit_model("dnli100")):.1f}') print(f' median model: {np.median(punit_model["dnli100"]):.1f}')
print(f' median data: {np.median(punit_data("nli")):.1f}') print(f' median data: {np.median(punit_data["nli"]):.1f}')
print() print()
s = plot_style() s = plot_style()

46
examplecells.py Normal file
View File

@ -0,0 +1,46 @@
import numpy as np
from pathlib import Path
from importlib import import_module
exclude = ['examplecells.py', 'plotstyle.py', 'spectral.py']
data_cells = []
model_cells = []
for pf in sorted(Path('.').glob('*.py'), key=lambda x: x.stem):
if pf.name in exclude:
continue
print(pf.name)
figure = import_module(pf.stem)
if hasattr(figure, 'example_cell'):
name = figure.example_cell
while isinstance(name, list):
name = name[0]
print(f' found example_cell: {name}')
data_cells.append(name)
if hasattr(figure, 'example_cells'):
for name in figure.example_cells:
while isinstance(name, list):
name = name[0]
print(f' found example_cells: {name}')
data_cells.append(name)
if hasattr(figure, 'model_cell'):
name = figure.model_cell
while isinstance(name, list):
name = name[0]
print(f' found model_cell: {name}')
model_cells.append(name)
if hasattr(figure, 'model_cells'):
for name in figure.model_cells:
while isinstance(name, list):
name = name[0]
print(f' found model_cells: {name}')
model_cells.append(name)
print()
print('The following cell data are used in the plots:')
for cell in np.unique(data_cells):
print(cell)
print()
print('The following cell models are used in the plots:')
for cell in np.unique(model_cells):
print(cell)

View File

@ -6,6 +6,11 @@ from pathlib import Path
from plotstyle import plot_style, labels_params, significance_str from plotstyle import plot_style, labels_params, significance_str
model_cells = ['2017-07-18-ai-invivo-1', # strong triangle
'2012-12-13-ao-invivo-1', # triangle
'2012-12-20-ac-invivo-1', # weak border triangle
'2013-01-08-ab-invivo-1'] # no triangle
data_path = Path('data') data_path = Path('data')
sims_path = data_path / 'simulations' sims_path = data_path / 'simulations'
@ -70,7 +75,7 @@ def plot_chi2(ax, s, data_file):
def plot_chi2_contrasts(axs, s, cell_name): def plot_chi2_contrasts(axs, s, cell_name):
print(cell_name) print(f' {cell_name}')
files, nums = sort_files(cell_name, files, nums = sort_files(cell_name,
sims_path.glob(f'chi2-split-{cell_name}-*.npz'), 1) sims_path.glob(f'chi2-split-{cell_name}-*.npz'), 1)
plot_chi2(axs[0], s, files[-1]) plot_chi2(axs[0], s, files[-1])
@ -81,21 +86,21 @@ def plot_chi2_contrasts(axs, s, cell_name):
def plot_nli_cv(ax, s, data, alpha, cells): def plot_nli_cv(ax, s, data, alpha, cells):
data = data[data('contrast') == alpha, :] data = data[data['contrast'] == alpha, :]
r, p = pearsonr(data('cvbase'), data[:, 'dnli']) r, p = pearsonr(data['cvbase'], data['dnli'])
l = linregress(data('cvbase'), data[:, 'dnli']) l = linregress(data['cvbase'], data['dnli'])
x = np.linspace(0, 1, 10) x = np.linspace(0, 1, 10)
ax.set_visible(True) ax.set_visible(True)
ax.set_title(f'$c$={100*alpha:g}\\,\\%', fontsize='medium') ax.set_title(f'$c$={100*alpha:g}\\,\\%', fontsize='medium')
ax.axhline(1, **s.lsLine) ax.axhline(1, **s.lsLine)
ax.plot(x, l.slope*x + l.intercept, **s.lsGrid) ax.plot(x, l.slope*x + l.intercept, **s.lsGrid)
mask = data('triangle') > 0.5 mask = data['triangle'] > 0.5
ax.plot(data[mask, 'cvbase'], data[mask, 'dnli'], ax.plot(data[mask, 'cvbase'], data[mask, 'dnli'],
clip_on=False, zorder=30, label='strong', **s.psA1m) clip_on=False, zorder=30, label='strong', **s.psA1m)
mask = data[:, 'border'] > 0.5 mask = data['border'] > 0.5
ax.plot(data[mask, 'cvbase'], data[mask, 'dnli'], ax.plot(data[mask, 'cvbase'], data[mask, 'dnli'],
zorder=20, label='weak', **s.psA2m) zorder=20, label='weak', **s.psA2m)
ax.plot(data[:, 'cvbase'], data[:, 'dnli'], clip_on=False, ax.plot(data['cvbase'], data['dnli'], clip_on=False,
zorder=10, label='none', **s.psB1m) zorder=10, label='none', **s.psB1m)
for cell_name in cells: for cell_name in cells:
@ -124,7 +129,7 @@ def plot_nli_cv(ax, s, data, alpha, cells):
title='triangle', handlelength=0.5, title='triangle', handlelength=0.5,
handletextpad=0.5, labelspacing=0.2) handletextpad=0.5, labelspacing=0.2)
kde = gaussian_kde(data('dnli'), 0.15/np.std(data('dnli'), ddof=1)) kde = gaussian_kde(data['dnli'], 0.15/np.std(data['dnli'], ddof=1))
nli = np.linspace(0, 8, 100) nli = np.linspace(0, 8, 100)
pdf = kde(nli) pdf = kde(nli)
dax = ax.inset_axes([1.04, 0, 0.3, 1]) dax = ax.inset_axes([1.04, 0, 0.3, 1])
@ -137,34 +142,35 @@ def plot_summary_contrasts(axs, s, cells):
nli_thresh = 1.2 nli_thresh = 1.2
data = TableData(data_path / 'Apteronotus_leptorhynchus-Punit-models.csv') data = TableData(data_path / 'Apteronotus_leptorhynchus-Punit-models.csv')
plot_nli_cv(axs[0], s, data, 0, cells) plot_nli_cv(axs[0], s, data, 0, cells)
print('split:') print('noise split:')
nli_split = data[data('contrast') == 0, 'dnli'] cdata = data[data['contrast'] == 0, :]
print(f' mean NLI = {np.mean(nli_split):.2f}, stdev = {np.std(nli_split):.2f}') nli_split = cdata['dnli']
print(f' mean SI = {np.mean(nli_split):.2f}, stdev = {np.std(nli_split):.2f}')
n = np.sum(nli_split > nli_thresh) n = np.sum(nli_split > nli_thresh)
print(f' {n} cells ({100*n/len(nli_split):.1f}%) have NLI > {nli_thresh:.1f}') print(f' {n} cells ({100*n/len(nli_split):.1f}%) have SI > {nli_thresh:.1f}:')
print(f' triangle cells have nli >= {np.min(nli_split[data[data("contrast") == 0, "triangle"] > 0.5])}') for name, cv in cdata[nli_split > nli_thresh, ['cell', 'cvbase']].row_data():
print(f' {name:<22} CV={cv:4.2f}')
print(f' triangle cells have SI >= {np.min(nli_split[cdata["triangle"] > 0.5]):.2f}')
print() print()
for i, a in enumerate([0.01, 0.03, 0.1]): for i, a in enumerate([0.01, 0.03, 0.1]):
plot_nli_cv(axs[1 + i], s, data, a, cells) plot_nli_cv(axs[1 + i], s, data, a, cells)
print(f'contrast {100*a:2g}%:') print(f'contrast {100*a:2g}%:')
cdata = data[data('contrast') == a, :] cdata = data[data['contrast'] == a, :]
nli = cdata('dnli') nli = cdata['dnli']
r, p = pearsonr(nli_split, nli) r, p = pearsonr(nli_split, nli)
print(f' correlation with split: r={r:.2f}, p={p:.1e}') print(f' correlation with split: r={r:.2f}, p={p:.1e}')
print(f' mean NLI = {np.mean(nli):.2f}, stdev = {np.std(nli):.2f}') print(f' mean SI = {np.mean(nli):.2f}, stdev = {np.std(nli):.2f}')
n = np.sum(nli > nli_thresh) n = np.sum(nli > nli_thresh)
print(f' {n} cells ({100*n/len(nli):.1f}%) have NLI > {nli_thresh:.1f}') print(f' {n} cells ({100*n/len(nli):.1f}%) have SI > {nli_thresh:.1f}:')
print( ' CVs:', cdata[nli > nli_thresh, 'cvbase']) for name, cv in cdata[nli > nli_thresh, ['cell', 'cvbase']].row_data():
print( ' names:', cdata[nli > nli_thresh, 'cell']) print(f' {name:<22} CV={cv:4.2f}')
print(f' triangle cells have SI >= {np.min(nli[cdata["triangle"] > 0.5]):.2f}')
print() print()
print('lowest baseline CV:', np.unique(data('cvbase'))[:3]) print('overall lowest baseline CV:',
' '.join([f'{cv:.2f}' for cv in np.unique(data['cvbase'])[:5]]))
if __name__ == '__main__': if __name__ == '__main__':
cells = ['2017-07-18-ai-invivo-1', # strong triangle
'2012-12-13-ao-invivo-1', # triangle
'2012-12-20-ac-invivo-1', # weak border triangle
'2013-01-08-ab-invivo-1'] # no triangle
s = plot_style() s = plot_style()
#labels_params(xlabelloc='right', ylabelloc='top') #labels_params(xlabelloc='right', ylabelloc='top')
fig, axs = plt.subplots(6, 4, cmsize=(s.plot_width, 0.95*s.plot_width), fig, axs = plt.subplots(6, 4, cmsize=(s.plot_width, 0.95*s.plot_width),
@ -173,12 +179,14 @@ if __name__ == '__main__':
wspace=1, hspace=0.7) wspace=1, hspace=0.7)
for ax in axs.flat: for ax in axs.flat:
ax.set_visible(False) ax.set_visible(False)
for k in range(len(cells)): print('Example cells:')
plot_chi2_contrasts(axs[k], s, cells[k]) for k in range(len(model_cells)):
plot_chi2_contrasts(axs[k], s, model_cells[k])
for k in range(4): for k in range(4):
fig.common_yticks(axs[k, :]) fig.common_yticks(axs[k, :])
fig.common_xticks(axs[:4, k]) fig.common_xticks(axs[:4, k])
plot_summary_contrasts(axs[5], s, cells) print()
plot_summary_contrasts(axs[5], s, model_cells)
fig.common_yticks(axs[5, :]) fig.common_yticks(axs[5, :])
fig.tag(axs, xoffs=-4.5, yoffs=1.8) fig.tag(axs, xoffs=-4.5, yoffs=1.8)
fig.savefig() fig.savefig()

View File

@ -6,6 +6,8 @@ from pathlib import Path
from plotstyle import plot_style, labels_params, significance_str from plotstyle import plot_style, labels_params, significance_str
model_cell = '2012-12-21-ak-invivo-1'
data_path = Path('data') data_path = Path('data')
sims_path = data_path / 'simulations' sims_path = data_path / 'simulations'
@ -71,7 +73,7 @@ def plot_chi2(ax, s, data_file):
def plot_chi2_contrasts(axs, s, cell_name, n=None): def plot_chi2_contrasts(axs, s, cell_name, n=None):
print(cell_name) print(f' {cell_name}')
files, nums = sort_files(cell_name, files, nums = sort_files(cell_name,
sims_path.glob(f'chi2-split-{cell_name}-*.npz'), 1) sims_path.glob(f'chi2-split-{cell_name}-*.npz'), 1)
idx = -1 if n is None else nums.index(n) idx = -1 if n is None else nums.index(n)
@ -84,10 +86,10 @@ def plot_chi2_contrasts(axs, s, cell_name, n=None):
def plot_nli_diags(ax, s, data, alphax, alphay, xthresh, ythresh, cell_name): def plot_nli_diags(ax, s, data, alphax, alphay, xthresh, ythresh, cell_name):
datax = data[data('contrast') == alphax, :] datax = data[data['contrast'] == alphax, :]
datay = data[data('contrast') == alphay, :] datay = data[data['contrast'] == alphay, :]
nlix = datax('dnli') nlix = datax['dnli']
nliy = datay('dnli100') nliy = datay['dnli100']
nfp = np.sum((nliy > ythresh) & (nlix < xthresh)) nfp = np.sum((nliy > ythresh) & (nlix < xthresh))
ntp = np.sum((nliy > ythresh) & (nlix > xthresh)) ntp = np.sum((nliy > ythresh) & (nlix > xthresh))
ntn = np.sum((nliy < ythresh) & (nlix < xthresh)) ntn = np.sum((nliy < ythresh) & (nlix < xthresh))
@ -106,13 +108,13 @@ def plot_nli_diags(ax, s, data, alphax, alphay, xthresh, ythresh, cell_name):
ax.axhline(ythresh, **s.lsLine) ax.axhline(ythresh, **s.lsLine)
ax.axvline(xthresh, 0, 0.5, **s.lsLine) ax.axvline(xthresh, 0, 0.5, **s.lsLine)
if alphax == 0: if alphax == 0:
mask = datax('triangle') > 0.5 mask = datax['triangle'] > 0.5
ax.plot(nlix[mask], nliy[mask], zorder=30, label='strong', **s.psA1m) ax.plot(nlix[mask], nliy[mask], zorder=30, label='strong', **s.psA1m)
mask = datax('border') > 0.5 mask = datax['border'] > 0.5
ax.plot(nliy[mask], nliy[mask], zorder=20, label='weak', **s.psA2m) ax.plot(nliy[mask], nliy[mask], zorder=20, label='weak', **s.psA2m)
ax.plot(nlix, nliy, zorder=10, label='none', **s.psB1m) ax.plot(nlix, nliy, zorder=10, label='none', **s.psB1m)
# mark cell: # mark cell:
mask = datax('cell') == cell_name mask = datax['cell'] == cell_name
color = s.psB1m['color'] color = s.psB1m['color']
if alphax == 0: if alphax == 0:
if datax[mask, 'border']: if datax[mask, 'border']:
@ -183,15 +185,16 @@ if __name__ == '__main__':
wspace=1, hspace=1) wspace=1, hspace=1)
for ax in axs.flat: for ax in axs.flat:
ax.set_visible(False) ax.set_visible(False)
cell_name = '2012-12-21-ak-invivo-1' print('Example cells:')
plot_chi2_contrasts(axs[0], s, cell_name) plot_chi2_contrasts(axs[0], s, model_cell)
plot_chi2_contrasts(axs[1], s, cell_name, 10) plot_chi2_contrasts(axs[1], s, model_cell, 10)
for k in range(2): for k in range(2):
fig.common_yticks(axs[k, :]) fig.common_yticks(axs[k, :])
for k in range(4): for k in range(4):
fig.common_xticks(axs[:2, k]) fig.common_xticks(axs[:2, k])
plot_summary_contrasts(axs[3], s, xthresh, ythresh, cell_name) print()
plot_summary_diags(axs[5], s, xthresh, ythresh, cell_name) plot_summary_contrasts(axs[3], s, xthresh, ythresh, model_cell)
plot_summary_diags(axs[5], s, xthresh, ythresh, model_cell)
fig.common_yticks(axs[3, 1:]) fig.common_yticks(axs[3, 1:])
fig.common_yticks(axs[5, 1:]) fig.common_yticks(axs[5, 1:])
fig.tag(axs, xoffs=-4.5, yoffs=1.8) fig.tag(axs, xoffs=-4.5, yoffs=1.8)

View File

@ -5,6 +5,10 @@ from spectral import whitenoise
from plotstyle import plot_style from plotstyle import plot_style
#example_cell = ['2012-07-03-ak-invivo-1', 0]
example_cell = ['2017-07-18-ai-invivo-1', 1] # Take this! at 3% model, 5% data
model_cell = example_cell
base_path = Path('data') base_path = Path('data')
data_path = base_path / 'cells' data_path = base_path / 'cells'
sims_path = base_path / 'simulations' sims_path = base_path / 'simulations'
@ -225,8 +229,6 @@ def plot_noise_split(ax, contrast, noise_contrast, noise_frac,
if __name__ == '__main__': if __name__ == '__main__':
#cell_name = ['2012-07-03-ak-invivo-1', 0]
cell_name = ['2017-07-18-ai-invivo-1', 1] # Take this! at 3% model, 5% data
nsmall = 100 nsmall = 100
nlarge = 1000000 nlarge = 1000000
contrast = 0.01 contrast = 0.01
@ -249,16 +251,16 @@ if __name__ == '__main__':
axss = axs[0] axss = axs[0]
axss[1].text(xt, yt, 'P-unit data', fontsize='large', axss[1].text(xt, yt, 'P-unit data', fontsize='large',
transform=axss[1].transAxes, color=s.punit_color1) transform=axss[1].transAxes, color=s.punit_color1)
data_contrast, ratebase, eodf = plot_chi2_data(axss[1], s, cell_name[0], data_contrast, ratebase, eodf = plot_chi2_data(axss[1], s, example_cell[0],
cell_name[1]) example_cell[1])
plot_ram(axss[0], data_contrast, eodf, wtime, wnoise) plot_ram(axss[0], data_contrast, eodf, wtime, wnoise)
axss[1].text(xt + 0.9, yt, f'$r={ratebase:.0f}$\\,Hz', axss[1].text(xt + 0.9, yt, f'$r={ratebase:.0f}$\\,Hz',
transform=axss[1].transAxes, fontsize='large') transform=axss[1].transAxes, fontsize='large')
# model 5%: # model 5%:
axss = axs[1] axss = axs[1]
data_files = sims_path.glob(f'chi2-noisen-{cell_name[0]}-{1000*data_contrast:03.0f}-*.npz') data_files = sims_path.glob(f'chi2-noisen-{example_cell[0]}-{1000*data_contrast:03.0f}-*.npz')
files, nums = sort_files(cell_name[0], data_files, 2) files, nums = sort_files(example_cell[0], data_files, 2)
axss[1].text(xt, yt, 'P-unit model', fontsize='large', axss[1].text(xt, yt, 'P-unit model', fontsize='large',
transform=axs[1, 1].transAxes, color=s.model_color1) transform=axs[1, 1].transAxes, color=s.model_color1)
plot_chi2_contrast(axss[1], axss[2], s, files, nums, nsmall, nlarge) plot_chi2_contrast(axss[1], axss[2], s, files, nums, nsmall, nlarge)
@ -269,16 +271,16 @@ if __name__ == '__main__':
# model 1%: # model 1%:
axss = axs[2] axss = axs[2]
data_files = sims_path.glob(f'chi2-noisen-{cell_name[0]}-{1000*contrast:03.0f}-*.npz') data_files = sims_path.glob(f'chi2-noisen-{example_cell[0]}-{1000*contrast:03.0f}-*.npz')
files, nums = sort_files(cell_name[0], data_files, 2) files, nums = sort_files(example_cell[0], data_files, 2)
plot_chi2_contrast(axss[1], axss[2], s, files, nums, nsmall, nlarge) plot_chi2_contrast(axss[1], axss[2], s, files, nums, nsmall, nlarge)
axr2 = plot_noise_split(axss[0], contrast, 0, 1, wtime, wnoise) axr2 = plot_noise_split(axss[0], contrast, 0, 1, wtime, wnoise)
plot_overn(axss[3], s, files, nmax=1e6) plot_overn(axss[3], s, files, nmax=1e6)
# model noise split: # model noise split:
axss = axs[3] axss = axs[3]
data_files = sims_path.glob(f'chi2-split-{cell_name[0]}-*.npz') data_files = sims_path.glob(f'chi2-split-{example_cell[0]}-*.npz')
files, nums = sort_files(cell_name[0], data_files, 1) files, nums = sort_files(example_cell[0], data_files, 1)
axss[1].text(xt, yt, 'P-unit model', fontsize='large', axss[1].text(xt, yt, 'P-unit model', fontsize='large',
transform=axss[1].transAxes, color=s.model_color1) transform=axss[1].transAxes, color=s.model_color1)
axss[1].text(xt + 0.9, yt, f'(noise split)', fontsize='large', axss[1].text(xt + 0.9, yt, f'(noise split)', fontsize='large',

View File

@ -7,9 +7,8 @@ from spectral import diag_projection, peakedness
from plotstyle import plot_style from plotstyle import plot_style
cell_name = '2020-10-27-ag-invivo-1' example_cell = [['2020-10-27-ag-invivo-1', 0],
run1 = 0 ['2020-10-27-ag-invivo-1', 1]]
run2 = 1
example_cells = [ example_cells = [
['2021-06-18-ae-invivo-1', 3], # 98Hz, 1%, ok ['2021-06-18-ae-invivo-1', 3], # 98Hz, 1%, ok
@ -22,10 +21,12 @@ example_cells = [
##['2020-10-27-ae-invivo-1', 4], # 375Hz, 0.5%, 4.3, nice, additional low freq line ##['2020-10-27-ae-invivo-1', 4], # 375Hz, 0.5%, 4.3, nice, additional low freq line
###['2020-10-27-ag-invivo-1', 2], # 405Hz, 5%, 3.9, strong, is already the example ###['2020-10-27-ag-invivo-1', 2], # 405Hz, 5%, 3.9, strong, is already the example
##['2021-08-03-ab-invivo-1', 1], # 140Hz, 0.5%, ok ##['2021-08-03-ab-invivo-1', 1], # 140Hz, 0.5%, ok
['2020-10-29-ag-invivo-1', 2], # 164Hz, 5%, 1.6, no diagonal #['2020-10-29-ag-invivo-1', 2], # 164Hz, 5%, 1.6, no diagonal
##['2010-08-31-ag', 1], # 269Hz, 5%, no diagonal
['2018-08-24-ak', 1], # 145Hz, 5%, no diagonal ['2018-08-24-ak', 1], # 145Hz, 5%, no diagonal
['2018-08-14-ac', 0], # 239Hz, 10%, no diagonal
##['2010-08-31-ag', 1], # 269Hz, 5%, no diagonal
##['2018-08-29-af', 1], # 383Hz, 5%, no diagonal ##['2018-08-29-af', 1], # 383Hz, 5%, no diagonal
] ]
data_path = Path('data') / 'cells' data_path = Path('data') / 'cells'
@ -95,9 +96,10 @@ def plot_isih2(ax, s, rate, cv, isis, pdf):
ax.show_spines('b') ax.show_spines('b')
ax.fill_between(1000*isis, pdf, facecolor=s.cell_color1) ax.fill_between(1000*isis, pdf, facecolor=s.cell_color1)
ax.set_xlim(0, 20) ax.set_xlim(0, 20)
ax.set_xticks_delta(5) #ax.set_xticks_delta(5)
ax.set_xticks_blank() #ax.set_xticks_blank()
#ax.set_xlabel('ISI', 'ms') #ax.set_xticks_fixed([0, 5, 10, 15, 20], ['0', '', '', '', '20\\,ms'])
ax.set_xticks_fixed([0, 5, 10, 15, 20], ['0', '5', '10', '15', '20\\,ms'])
ax.text(1, 1.1, f'CV$_{{\\rm base}}$={cv:.2f}', ha='right', ax.text(1, 1.1, f'CV$_{{\\rm base}}$={cv:.2f}', ha='right',
transform=ax.transAxes) transform=ax.transAxes)
ax.text(1, 0.6, f'$r={rate:.0f}$Hz', ha='right', transform=ax.transAxes) ax.text(1, 0.6, f'$r={rate:.0f}$Hz', ha='right', transform=ax.transAxes)
@ -218,14 +220,36 @@ def plot_diagonals(ax, s, fbase, contrast1, freqs1, chi21, contrast2, freqs2, ch
if __name__ == '__main__': if __name__ == '__main__':
"""
from thunderlab.tabledata import TableData
data = TableData('data/Apteronotus_leptorhynchus-Punit-data.csv')
data = data[(data['nli'] > 0) & (data['nli'] <= 1.2), :]
data = data[(data['respmod2'] > 150) & (data['respmod2'] < 200), :]
data = data[(data['cvbase'] > 0.4) & (data['cvbase'] < 0.8), :]
data = data[(data['ratebase'] > 300) & (data['ratebase'] < 400), :]
for k in range(data.rows()):
print(f'{data[k, "cell"]:<22s} s{data[k, "stimindex"]:02.0f}: '
f'{100*data[k, "contrast"]:3g}%, r={data[k, "ratebase"]:3.0f}Hz, '
f'CV={data[k, "cvbase"]:4.2f}, '
f'rmod={data[k, "respmod2"]:3.0f}Hz, '
f'nli={data[k, "nli"]:5.2f}')
print()
#exit()
"""
cell_name = example_cell[0][0]
print('Example P-unit:', cell_name) print('Example P-unit:', cell_name)
eodf, rate, cv, isis, pdf, freqs, prr = load_baseline(data_path, cell_name) eodf, rate, cv, isis, pdf, freqs, prr = load_baseline(data_path, cell_name)
print(f' baseline firing rate: {rate:.0f}Hz') print(f' baseline firing rate: {rate:.0f}Hz')
print(f' baseline firing CV : {cv:.2f}') print(f' baseline firing CV : {cv:.2f}')
contrast1, time1, stimulus1, spikes1 = load_noise(data_path, cell_name, run1) contrast1, time1, stimulus1, spikes1 = load_noise(data_path,
contrast2, time2, stimulus2, spikes2 = load_noise(data_path, cell_name, run2) *example_cell[0])
fcutoff1, contrast1, freqs1, gain1, chi21 = load_spectra(data_path, cell_name, run1) contrast2, time2, stimulus2, spikes2 = load_noise(data_path,
fcutoff2, contrast2, freqs2, gain2, chi22 = load_spectra(data_path, cell_name, run2) *example_cell[1])
fcutoff1, contrast1, freqs1, gain1, chi21 = load_spectra(data_path,
*example_cell[0])
fcutoff2, contrast2, freqs2, gain2, chi22 = load_spectra(data_path,
*example_cell[1])
s = plot_style() s = plot_style()
s.cell_color1 = s.punit_color1 s.cell_color1 = s.punit_color1
@ -235,7 +259,7 @@ if __name__ == '__main__':
s.psC1 = s.psP1 s.psC1 = s.psP1
s.psC2 = s.psP2 s.psC2 = s.psP2
fig, (ax1, ax2, ax3) = \ fig, (ax1, ax2, ax3) = \
plt.subplots(3, 1, height_ratios=[3, 0, 3, 0.2, 4.5], plt.subplots(3, 1, height_ratios=[3, 0, 3, 0.2, 4.7],
cmsize=(s.plot_width, 0.85*s.plot_width)) cmsize=(s.plot_width, 0.85*s.plot_width))
fig.subplots_adjust(leftm=8, rightm=9, topm=2, bottomm=4, fig.subplots_adjust(leftm=8, rightm=9, topm=2, bottomm=4,
wspace=0.4, hspace=0.4) wspace=0.4, hspace=0.4)
@ -243,7 +267,7 @@ if __name__ == '__main__':
axg, axc1, axc2, axd = ax2.subplots(1, 4, wspace=0.4) axg, axc1, axc2, axd = ax2.subplots(1, 4, wspace=0.4)
axg = axg.subplots(1, 1, width_ratios=[1, 0.1]) axg = axg.subplots(1, 1, width_ratios=[1, 0.1])
axd = axd.subplots(1, 1, width_ratios=[0.2, 1]) axd = axd.subplots(1, 1, width_ratios=[0.2, 1])
axs = ax3.subplots(2, 4, wspace=0.4, hspace=0.2, height_ratios=[1, 4]) axs = ax3.subplots(2, 4, wspace=0.4, hspace=0.35, height_ratios=[1, 4])
plot_isih(axi, s, rate, cv, isis, pdf) plot_isih(axi, s, rate, cv, isis, pdf)
plot_response_spectrum(axp, s, eodf, rate, freqs, prr) plot_response_spectrum(axp, s, eodf, rate, freqs, prr)
@ -276,16 +300,15 @@ if __name__ == '__main__':
nli, nlif = peakedness(dfreqs, diag, rate, median=False) nli, nlif = peakedness(dfreqs, diag, rate, median=False)
print(f' {cell:<22s}: run={run:2d}, fbase={rate:3.0f}Hz, CV={cv:.2f}, SI={nli:3.1f}') print(f' {cell:<22s}: run={run:2d}, fbase={rate:3.0f}Hz, CV={cv:.2f}, SI={nli:3.1f}')
plot_isih2(axs[0, k], s, rate, cv, isis, pdf) plot_isih2(axs[0, k], s, rate, cv, isis, pdf)
pc = plot_chi2(axs[1, k], s, contrast, freqs, chi2, fcutoff, 1.3) pc = plot_chi2(axs[1, k], s, contrast, freqs, chi2, fcutoff, 1.0)
#axs[k].set_title(f'$r={rate:.0f}$Hz, CV$_{{\\rm base}}$={cv:.2f}', fontsize='medium')
axs[1, k].text(0.95, 0.9, f'SI($r$)={nli:.1f}', ha='right', zorder=50, axs[1, k].text(0.95, 0.9, f'SI($r$)={nli:.1f}', ha='right', zorder=50,
color='white', fontsize='medium', color='white', fontsize='medium',
transform=axs[1, k].transAxes) transform=axs[1, k].transAxes)
axs[0, 0].text(0, 1.6, 'P-units:', transform=axs[0, 0].transAxes, axs[0, 0].text(0, 1.6, 'P-units:', transform=axs[0, 0].transAxes,
color=s.cell_color1, color=s.cell_color1,
fontsize='large') fontsize='large')
axs[0, -1].text(0.97, -0.45, '5\\,ms', ha='right', #axs[0, -1].text(0.97, -0.45, '5\\,ms', ha='right',
transform=axs[0, -1].transAxes) # transform=axs[0, -1].transAxes)
plot_colorbar(axs[1, -1], pc) plot_colorbar(axs[1, -1], pc)
fig.common_yticks(axs[1, :]) fig.common_yticks(axs[1, :])
fig.tag([axs[0, :]], xoffs=-3, yoffs=1) fig.tag([axs[0, :]], xoffs=-3, yoffs=1)

View File

@ -3,10 +3,11 @@ import matplotlib.pyplot as plt
from pathlib import Path from pathlib import Path
from scipy.stats import linregress from scipy.stats import linregress
from numba import jit from numba import jit
from thunderlab.tabledata import TableData
from plotstyle import plot_style, lighter, darker from plotstyle import plot_style, lighter, darker
model_cell = '2018-05-08-ad-invivo-1' # 228Hz, CV=0.67
data_path = Path('data') data_path = Path('data')
sims_path = data_path / 'simulations' sims_path = data_path / 'simulations'
@ -200,14 +201,13 @@ def plot_psd(ax, s, path, contrast, spikes, nfft, dt, beatf1, beatf2):
def plot_example(axs, axr, axp, s, path, cell, alpha, beatf1, beatf2, def plot_example(axs, axr, axp, s, path, cell, alpha, beatf1, beatf2,
nfft, trials): nfft, trials):
sim_path = path / f'{cell_name}-contrastspectrum-{1000*alpha:03.0f}.npz'
dt = 0.0001 dt = 0.0001
tmax = nfft*dt tmax = nfft*dt
t1 = 0.1 t1 = 0.1
spikes = punit_spikes(cell, alpha, beatf1, beatf2, tmax, trials) spikes = punit_spikes(cell, alpha, beatf1, beatf2, tmax, trials)
plot_am(axs, s, alpha, beatf1, beatf2, t1) plot_am(axs, s, alpha, beatf1, beatf2, t1)
plot_raster(axr, s, spikes, t1) plot_raster(axr, s, spikes, t1)
plot_psd(axp, s, sim_path, alpha, spikes, nfft, dt, beatf1, beatf2) plot_psd(axp, s, path, alpha, spikes, nfft, dt, beatf1, beatf2)
def peak_ampl(freqs, psd, f): def peak_ampl(freqs, psd, f):
@ -285,17 +285,16 @@ def plot_peaks(ax, s, alphas, contrasts, powerf1, powerf2, powerfsum,
if __name__ == '__main__': if __name__ == '__main__':
cell_name = '2018-05-08-ad-invivo-1' # 228Hz, CV=0.67
ratebase, cvbase, beatf1, beatf2, \ ratebase, cvbase, beatf1, beatf2, \
contrasts, powerf1, powerf2, powerfsum, powerfdiff = \ contrasts, powerf1, powerf2, powerfsum, powerfdiff = \
load_data(sims_path / f'{cell_name}-contrastpeaks.npz') load_data(sims_path / f'{model_cell}-contrastpeaks.npz')
alphas = [0.002, 0.01, 0.03, 0.06] alphas = [0.002, 0.01, 0.03, 0.06]
parameters = load_models(data_path / 'punitmodels.csv') parameters = load_models(data_path / 'punitmodels.csv')
cell = cell_parameters(parameters, cell_name) cell = cell_parameters(parameters, model_cell)
nfft = 2**18 nfft = 2**18
print(f'Loaded data for cell {cell_name}: ' print(f'Loaded data for cell {model_cell}: '
f'baseline rate = {ratebase:.0f}Hz, CV = {cvbase:.2f}') f'baseline rate = {ratebase:.0f}Hz, CV = {cvbase:.2f}')
s = plot_style() s = plot_style()
@ -308,7 +307,8 @@ if __name__ == '__main__':
# example power spectra: # example power spectra:
for c, alpha in enumerate(alphas): for c, alpha in enumerate(alphas):
plot_example(axe[0, c], axe[1, c], axe[2, c], s, sims_path, path = sims_path / f'{model_cell}-contrastspectrum-{1000*alpha:03.0f}.npz'
plot_example(axe[0, c], axe[1, c], axe[2, c], s, path,
cell, alpha, beatf1, beatf2, nfft, 100) cell, alpha, beatf1, beatf2, nfft, 100)
axe[1, 0].xscalebar(1, -0.1, 20, 'ms', ha='right') axe[1, 0].xscalebar(1, -0.1, 20, 'ms', ha='right')
axe[2, 0].legend(loc='center left', bbox_to_anchor=(0, -0.8), axe[2, 0].legend(loc='center left', bbox_to_anchor=(0, -0.8),
@ -322,3 +322,4 @@ if __name__ == '__main__':
powerfsum, powerfdiff) powerfsum, powerfdiff)
fig.tag(axa, yoffs=2) fig.tag(axa, yoffs=2)
fig.savefig() fig.savefig()
print()