nonlinearbaseline2025/dataoverview.py

276 lines
12 KiB
Python

import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
from scipy.stats import pearsonr, gaussian_kde
from scipy.stats import mannwhitneyu
from thunderlab.tabledata import TableData
from plotstyle import plot_style, lighter, significance_str
data_path = Path('data')
from punitexamplecell import example_cell as punit_example
from punitexamplecell import example_cells as punit_examples
from ampullaryexamplecell import example_cell as ampul_example
from ampullaryexamplecell import example_cells as ampul_examples
def _mark_examples(ax, data, pairs, xcol, ycol, zcol, size, marker,
                   cmap, zmin, zmax):
    """Overlay black-edged markers for the given (cell, stimindex) pairs."""
    for cell, run in pairs:
        sel = (data['cell'] == cell) & (data['stimindex'] == run)
        ax.scatter(data[sel, xcol], data[sel, ycol], c=data[sel, zcol],
                   s=size, marker=marker, linewidth=0.5, edgecolors='black',
                   clip_on=False, cmap=cmap, vmin=zmin, vmax=zmax,
                   zorder=20)


def plot_corr(ax, data, xcol, ycol, zcol, zmin, zmax, xpdfmax, cmap, color,
              nli_thresh, example=(), examples=()):
    """Color-coded scatter plot of two table columns with marginal densities.

    Draws `ycol` against `xcol`, colored by `zcol`, adds a color bar and
    kernel density estimates of both columns along the top and right
    edge of the axes, annotates the Pearson correlation coefficient, and
    prints the correlation to the console.

    The x- and y-limits of `ax` need to be set before calling this
    function: their upper bounds are read back to select the plotted
    points and to scale the density estimates.

    Parameters
    ----------
    ax: matplotlib axes
        Axes to draw into (with the project's axes extensions such as
        `show_spines()` and `set_spines_outward()` installed).
    data: table
        Table supporting `data[column]`, `data[mask, column]`,
        `len(data)`, `data.rows()`, and `column in data`
        (presumably a thunderlab `TableData` - confirm against callers).
    xcol, ycol, zcol: str
        Names of the columns used for x-position, y-position, and color.
    zmin, zmax: float
        Range of the color code.
    xpdfmax: float
        Upper limit of the density axis of the marginal on top.
    cmap: str or colormap
        Color map for the scatter points.
    color: matplotlib color
        Face color of the marginal densities.
    nli_thresh: float
        Position of the dotted horizontal line; for `xcol` containing
        'cvbase' also the threshold for the annotated percentages.
    example: iterable of (cell, stimindex)
        Recordings highlighted by black-edged triangles.
    examples: iterable of (cell, stimindex)
        Recordings highlighted by black-edged circles.
        Both are only drawn if `data` has a 'stimindex' column.

    Returns
    -------
    cax: matplotlib axes
        The inset axes holding the color bar.
    """
    ax.axhline(nli_thresh, color='k', ls=':', lw=0.5)
    xmax = ax.get_xlim()[1]
    ymax = ax.get_ylim()[1]
    xvals = data[xcol]
    yvals = data[ycol]
    # only points below the upper axis limits are drawn; the densities
    # and the statistics further down are based on all rows:
    inside = (xvals < xmax) & (yvals < ymax)
    sc = ax.scatter(data[inside, xcol], data[inside, ycol],
                    c=data[inside, zcol],
                    s=4, marker='o', linewidth=0, edgecolors='none',
                    clip_on=False, cmap=cmap, vmin=zmin, vmax=zmax, zorder=20)
    if 'stimindex' in data:
        _mark_examples(ax, data, example, xcol, ycol, zcol, 6, '^',
                       cmap, zmin, zmax)
        _mark_examples(ax, data, examples, xcol, ycol, zcol, 5, 'o',
                       cmap, zmin, zmax)
    # color bar:
    fig = ax.get_figure()
    cax = ax.inset_axes([1.3, 0, 0.04, 1])
    cax.set_spines_outward('lrbt', 0)
    cb = fig.colorbar(sc, cax=cax)
    cb.outline.set_color('none')
    cb.outline.set_linewidth(0)
    # pdf x-axis:
    # a scalar bw_method is used as the kde factor that scales the
    # standard deviation of the data, i.e. the kernel width is 2% of xmax:
    kde = gaussian_kde(xvals, 0.02*xmax/np.std(xvals, ddof=1))
    grid = np.linspace(0, xmax, 400)
    xax = ax.inset_axes([0, 1.05, 1, 0.2])
    xax.show_spines('')
    xax.fill_between(grid, kde(grid), facecolor=color, edgecolor='none')
    xax.set_ylim(0, xpdfmax)
    # pdf y-axis:
    kde = gaussian_kde(yvals, 0.02*ymax/np.std(yvals, ddof=1))
    grid = np.linspace(0, ymax, 400)
    yax = ax.inset_axes([1.05, 0, 0.2, 1])
    yax.show_spines('')
    yax.fill_betweenx(grid, kde(grid), facecolor=color, edgecolor='none')
    yax.set_xlim(left=0)
    # threshold:
    if 'cvbase' in xcol:
        ax.text(xmax, 0.4*ymax,
                f'{100*np.sum(yvals > nli_thresh)/len(data):.0f}\\%',
                ha='right', va='bottom', fontsize='small')
        ax.text(xmax, 0.3,
                f'{100*np.sum(yvals < nli_thresh)/len(data):.0f}\\%',
                ha='right', va='center', fontsize='small')
    # statistics:
    r, p = pearsonr(xvals, yvals)
    # NOTE(review): the '**' marker is hard-coded and not derived from p;
    # check it against the p-values printed below.
    ax.text(1, 0.9, f'$R={r:.2f}$ **', ha='right',
            transform=ax.transAxes, fontsize='small')
    if 'cvbase' in xcol:
        ax.text(1, 0.77, f'$n={data.rows()}$', ha='right',
                transform=ax.transAxes, fontsize='small')
    print(f' correlation {xcol:<8s} - {ycol}: r={r:5.2f}, p={p:.2g}')
    return cax
def nli_stats(title, data, column, nli_thresh):
    """Print summary statistics of a table of recordings.

    Reports the number of cells and recordings, how many of them have
    values in `column` above `nli_thresh`, and the distributions of the
    number of segments and of the number of recordings per cell.

    Parameters
    ----------
    title: str
        Heading printed before the statistics.
    data: table
        Table with the columns 'cell', 'nsegs', and `column`
        (presumably a thunderlab `TableData` - confirm against callers).
    column: str
        Name of the column compared against `nli_thresh`.
    nli_thresh: float
        Threshold on the values in `column`.
    """

    def spread(values):
        # range, median, mean, and standard deviation as a single string
        return (f'{np.min(values):4.0f} - {np.max(values):4.0f}, '
                f'median={np.median(values):4.0f}, '
                f'mean={np.mean(values):4.0f}, '
                f'std={np.std(values):4.0f}')

    print(title)
    print(f' nli threshold: {nli_thresh:.1f}')
    cell_ids = np.unique(data['cell'])
    n_cells = len(cell_ids)
    n_recordings = len(data)
    print(f' cells: {n_cells}')
    print(f' recordings: {n_recordings}')
    # rows exceeding the threshold:
    above = data(column) > nli_thresh
    high_cells = np.unique(data[above, 'cell'])
    n_high = np.sum(above)
    print(f' high nli cells: n={len(high_cells):3d}, '
          f'{100*len(high_cells)/n_cells:4.1f}%')
    print(f' high nli recordings: n={n_high:3d}, '
          f'{100*n_high/n_recordings:4.1f}%')
    print(f' number of segments: {spread(data["nsegs"])}')
    per_cell = [len(data[data["cell"] == cell, :]) for cell in cell_ids]
    print(f' number of recordings per cell: {spread(per_cell)}')
def plot_cvbase_nli_punit(ax, data, ycol, nli_thresh, color):
    """P-unit panel: `ycol` versus baseline CV, colored by response modulation.

    Parameters
    ----------
    ax: matplotlib axes
        Axes to draw into.
    data: table
        Table with the columns 'cvbase', 'respmod2', and `ycol`.
    ycol: str
        Name of the column plotted on the y-axis.
    nli_thresh: float
        Threshold passed on to `plot_corr()`.
    color: matplotlib color
        Color of the marginal densities.
    """
    # plot_corr() reads the axis limits back, so they are set first:
    ax.set_xlim(0, 1.5)
    ax.set_ylim(0, 7.2)
    ax.set_yticks_delta(2)
    ax.set_xlabel('CV$_{\\rm base}$')
    ax.set_ylabel('SI($r$)')
    colorbar_ax = plot_corr(ax, data, xcol='cvbase', ycol=ycol,
                            zcol='respmod2', zmin=0, zmax=250, xpdfmax=3,
                            cmap='coolwarm', color=color,
                            nli_thresh=nli_thresh,
                            example=punit_example, examples=punit_examples)
    colorbar_ax.set_ylabel('Response mod.', 'Hz')
def plot_cvstim_nli_punit(ax, data, ycol, nli_thresh, color):
    """P-unit panel: `ycol` versus CV during stimulation, colored by baseline CV.

    Parameters
    ----------
    ax: matplotlib axes
        Axes to draw into.
    data: table
        Table with the columns 'cvstim', 'cvbase', and `ycol`.
    ycol: str
        Name of the column plotted on the y-axis.
    nli_thresh: float
        Threshold passed on to `plot_corr()`.
    color: matplotlib color
        Color of the marginal densities.
    """
    # plot_corr() reads the axis limits back, so they are set first:
    ax.set_xlim(0, 1.6)
    ax.set_ylim(0, 7.2)
    ax.set_xticks_delta(0.5)
    ax.set_yticks_delta(2)
    ax.set_xlabel('CV$_{\\rm stim}$')
    ax.set_ylabel('SI($r$)')
    # Alternative color codes that were tried here (zcol, zmin, zmax, label):
    #   'respmod2', 0, 250, 'Response mod.' in Hz
    #   'ratebase', 50, 450, '$r$' in Hz
    #   'serialcorr1', -0.6, 0, '$\\rho_1$'
    colorbar_ax = plot_corr(ax, data, xcol='cvstim', ycol=ycol,
                            zcol='cvbase', zmin=0, zmax=1.5, xpdfmax=2,
                            cmap='coolwarm', color=color,
                            nli_thresh=nli_thresh,
                            example=punit_example, examples=punit_examples)
    colorbar_ax.set_ylabel('CV$_{\\rm base}$')
def plot_mod_nli_punit(ax, data, ycol, nli_thresh, color):
    """P-unit panel: `ycol` versus response modulation, colored by baseline CV.

    Parameters
    ----------
    ax: matplotlib axes
        Axes to draw into.
    data: table
        Table with the columns 'respmod2', 'cvbase', and `ycol`.
    ycol: str
        Name of the column plotted on the y-axis.
    nli_thresh: float
        Threshold passed on to `plot_corr()`.
    color: matplotlib color
        Color of the marginal densities.
    """
    # plot_corr() reads the axis limits back, so they are set first:
    ax.set_xlim(0, 250)
    ax.set_ylim(0, 7.2)
    ax.set_yticks_delta(2)
    ax.set_xlabel('Response modulation', 'Hz')
    ax.set_ylabel('SI($r$)')
    colorbar_ax = plot_corr(ax, data, xcol='respmod2', ycol=ycol,
                            zcol='cvbase', zmin=0, zmax=1.5, xpdfmax=0.016,
                            cmap='coolwarm', color=color,
                            nli_thresh=nli_thresh,
                            example=punit_example, examples=punit_examples)
    colorbar_ax.set_ylabel('CV$_{\\rm base}$')
def plot_cvbase_nli_ampul(ax, data, ycol, nli_thresh, color):
    """Ampullary panel: `ycol` versus baseline CV, colored by response modulation.

    Parameters
    ----------
    ax: matplotlib axes
        Axes to draw into.
    data: table
        Table with the columns 'cvbase', 'respmod2', and `ycol`.
    ycol: str
        Name of the column plotted on the y-axis.
    nli_thresh: float
        Threshold passed on to `plot_corr()`.
    color: matplotlib color
        Color of the marginal densities.
    """
    # plot_corr() reads the axis limits back, so they are set first:
    ax.set_xlim(0, 0.2)
    ax.set_ylim(0, 15)
    ax.set_xticks_delta(0.1)
    ax.set_yticks_delta(5)
    ax.set_xlabel('CV$_{\\rm base}$')
    ax.set_ylabel('SI($r$)')
    colorbar_ax = plot_corr(ax, data, xcol='cvbase', ycol=ycol,
                            zcol='respmod2', zmin=0, zmax=80, xpdfmax=20,
                            cmap='coolwarm', color=color,
                            nli_thresh=nli_thresh,
                            example=ampul_example, examples=ampul_examples)
    colorbar_ax.set_ylabel('Response mod.', 'Hz')
def plot_cvstim_nli_ampul(ax, data, ycol, nli_thresh, color):
    """Ampullary panel: `ycol` versus CV during stimulation, colored by baseline CV.

    Parameters
    ----------
    ax: matplotlib axes
        Axes to draw into.
    data: table
        Table with the columns 'cvstim', 'cvbase', and `ycol`.
    ycol: str
        Name of the column plotted on the y-axis.
    nli_thresh: float
        Threshold passed on to `plot_corr()`.
    color: matplotlib color
        Color of the marginal densities.
    """
    # plot_corr() reads the axis limits back, so they are set first:
    ax.set_xlim(0, 0.85)
    ax.set_ylim(0, 15)
    ax.set_xticks_delta(0.2)
    ax.set_yticks_delta(5)
    ax.set_xlabel('CV$_{\\rm stim}$')
    ax.set_ylabel('SI($r$)')
    # Alternative color codes that were tried here (zcol, zmin, zmax, label):
    #   'respmod2', 0, 80, 'Response mod.' in Hz
    #   'ratebase', 90, 180, '$r$' in Hz with color bar ticks every 30
    colorbar_ax = plot_corr(ax, data, xcol='cvstim', ycol=ycol,
                            zcol='cvbase', zmin=0, zmax=0.2, xpdfmax=6,
                            cmap='coolwarm', color=color,
                            nli_thresh=nli_thresh,
                            example=ampul_example, examples=ampul_examples)
    colorbar_ax.set_ylabel('CV$_{\\rm base}$')
    colorbar_ax.set_yticks_delta(0.1)
def plot_mod_nli_ampul(ax, data, ycol, nli_thresh, color):
    """Ampullary panel: `ycol` versus response modulation, colored by baseline CV.

    Parameters
    ----------
    ax: matplotlib axes
        Axes to draw into.
    data: table
        Table with the columns 'respmod2', 'cvbase', and `ycol`.
    ycol: str
        Name of the column plotted on the y-axis.
    nli_thresh: float
        Threshold passed on to `plot_corr()`.
    color: matplotlib color
        Color of the marginal densities.
    """
    # plot_corr() reads the axis limits back, so they are set first:
    ax.set_xlim(0, 80)
    ax.set_ylim(0, 15)
    ax.set_xticks_delta(20)
    ax.set_yticks_delta(5)
    ax.set_xlabel('Response modulation', 'Hz')
    ax.set_ylabel('SI($r$)')
    colorbar_ax = plot_corr(ax, data, xcol='respmod2', ycol=ycol,
                            zcol='cvbase', zmin=0, zmax=0.2, xpdfmax=0.06,
                            cmap='coolwarm', color=color,
                            nli_thresh=nli_thresh,
                            example=ampul_example, examples=ampul_examples)
    colorbar_ax.set_ylabel('CV$_{\\rm base}$')
    colorbar_ax.set_yticks_delta(0.1)
def _compare_model_data(heading, model_values, data_values, decimals):
    """Print a Mann-Whitney U test and the medians of two samples."""
    u, p = mannwhitneyu(model_values, data_values)
    print(heading)
    print(f' U={u:g}, p={p:.2g}')
    print(f' median model: {np.median(model_values):.{decimals}f}')
    print(f' median data: {np.median(data_values):.{decimals}f}')
    print()


if __name__ == '__main__':
    # load the tables:
    punit_model = TableData(data_path /
                            'Apteronotus_leptorhynchus-Punit-models.csv',
                            sep=';')
    # keep only model rows with a contrast above 1e-6:
    punit_model = punit_model[punit_model['contrast'] > 1e-6, :]
    punit_data = TableData(data_path /
                           'Apteronotus_leptorhynchus-Punit-data.csv',
                           sep=';')
    ampul_data = TableData(data_path /
                           'Apteronotus_leptorhynchus-Ampullary-data.csv',
                           sep=';')

    nli_thresh = 1.8

    # compare P-unit models with P-unit data:
    _compare_model_data('CV differs between P-unit models and data:',
                        punit_model['cvbase'], punit_data['cvbase'], 2)
    _compare_model_data('Response modulation differs between P-unit models and data:',
                        punit_model['respmod2'], punit_data['respmod2'], 2)
    _compare_model_data('NLI does not differ between P-unit models and data:',
                        punit_model['dnli100'], punit_data['nli'], 1)

    s = plot_style()
    # The former height_ratios=[1, 0, 1, 0.3, 1] had five entries for a
    # grid of three rows, which matplotlib's GridSpec rejects with a
    # ValueError. All three rows are used with equal heights instead.
    fig, axs = plt.subplots(3, 3, cmsize=(s.plot_width, 0.75*s.plot_width))
    fig.subplots_adjust(leftm=6.5, rightm=13.5, topm=4.5, bottomm=4,
                        wspace=1.1, hspace=0.6)

    # one row of panels per data set:
    # (stats heading, row label, table, nli column, label color,
    #  density color, plot functions for the three columns)
    punit_panels = (plot_cvbase_nli_punit, plot_mod_nli_punit,
                    plot_cvstim_nli_punit)
    ampul_panels = (plot_cvbase_nli_ampul, plot_mod_nli_ampul,
                    plot_cvstim_nli_ampul)
    rows = [
        ('P-unit model:', 'P-unit models', punit_model, 'dnli100',
         s.model_color1, s.model_color2, punit_panels),
        ('P-unit data:', 'P-unit data', punit_data, 'nli',
         s.punit_color1, s.punit_color2, punit_panels),
        ('Ampullary data:', 'Ampullary data', ampul_data, 'nli',
         s.ampul_color1, s.ampul_color2, ampul_panels),
    ]
    for row_axs, row in zip(axs, rows):
        heading, label, table, nli_col, label_color, pdf_color, panels = row
        nli_stats(heading, table, nli_col, nli_thresh)
        row_axs[0].text(0, 1.35, label,
                        transform=row_axs[0].transAxes, color=label_color)
        for ax, plot_panel in zip(row_axs, panels):
            plot_panel(ax, table, nli_col, nli_thresh, pdf_color)
        print()

    # the two P-unit rows share their x-ticks, each row shares its y-ticks:
    for col in range(3):
        fig.common_xticks(axs[:2, col])
    for row in range(3):
        fig.common_yticks(axs[row, :])
    fig.tag(xoffs=-3.5, yoffs=2)
    fig.savefig()
    print()