"""Correlate SI values with baseline and response properties.

Loads P-unit model, P-unit recording, and ampullary recording tables,
prints summary statistics and model-versus-data comparisons, and saves a
3x3 figure of SI scatter plots against CV_base, response modulation, and
baseline rate.
"""

from pathlib import Path

import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import pearsonr, gaussian_kde
from scipy.stats import mannwhitneyu
from thunderlab.tabledata import TableData
from plotstyle import plot_style, lighter, significance_str

from noisesplit import model_cell as model_split_example
from modelsusceptcontrasts import model_cells as model_contrast_examples
from modelsusceptlown import model_cell as model_lown_example
from punitexamplecell import example_cell as punit_example
from punitexamplecell import example_cells as punit_example_cells
from noisesplit import example_cell as punit_split_example
from ampullaryexamplecell import example_cell as ampul_example
from ampullaryexamplecell import example_cells as ampul_example_cells


data_path = Path('data')

# Data points highlighted in the scatter plots, as a tuple of three lists
# (example, split example, further examples) that plot_corr() draws with
# triangles, squares, and circles, respectively. Entries are (cell, run)
# pairs; for simulations `run` is a contrast, for recordings a stimulus index.
model_examples = ([[model_lown_example, a] for a in (0.01, 0.03, 0.1)],
                  [[model_split_example, 0.01]],
                  [[m, a] for m in model_contrast_examples
                   for a in [0.01, 0.03, 0.1]])
punit_examples = (punit_example, [punit_split_example], punit_example_cells)
ampul_examples = (ampul_example, [], ampul_example_cells)


def plot_corr(ax, data, xcol, ycol, zcol, zmin, zmax, xpdfmax, cmap, color,
              si_thresh, example=(), split_example=(), examples=()):
    """Scatter `ycol` against `xcol`, colored by `zcol`, with marginal pdfs.

    Parameters
    ----------
    ax: matplotlib axes
        Axes to plot into. Its x- and y-limits must already be set:
        regular points at or beyond the upper limits are not drawn, and the
        upper limits set the range and kernel width of the marginal pdfs.
    data: TableData
        Table with the columns `cell`, `xcol`, `ycol`, `zcol`, and either
        `stimindex` (recordings) or `contrast` (simulations).
    xcol, ycol, zcol: str
        Columns plotted on the x-axis, on the y-axis, and as marker color.
    zmin, zmax: float
        Range of the color map.
    xpdfmax: float
        Upper y-limit of the pdf of the x-data drawn above the scatter plot.
    cmap: str
        Name of the color map.
    color: color
        Fill color of the marginal pdfs.
    si_thresh: float
        Threshold marked by a dotted horizontal line.
    example, split_example, examples: sequences of (cell, run) pairs
        Data points highlighted with black-edged triangles, squares, and
        circles, respectively. `run` is matched against the `stimindex`
        column if the table has one, otherwise against `contrast`.

    Returns
    -------
    cax: matplotlib axes
        Inset axes holding the color bar.
    """
    xdata = data[xcol]
    ydata = data[ycol]
    ax.axhline(si_thresh, color='k', ls=':', lw=0.5)
    xmax = ax.get_xlim()[1]
    ymax = ax.get_ylim()[1]
    # recordings are identified by stimulus index, simulations by contrast:
    runcol = 'stimindex' if 'stimindex' in data else 'contrast'
    highlights = [(example, '^'), (split_example, 's'), (examples, 'o')]
    # all data within the axes limits except the highlighted examples:
    mask = (xdata < xmax) & (ydata < ymax)
    for cell, run in [*example, *split_example, *examples]:
        mask &= ~((data['cell'] == cell) & (data[runcol] == run))
    sc = ax.scatter(xdata[mask], ydata[mask], c=data[mask, zcol], s=4,
                    marker='o', linewidth=0, edgecolors='none', clip_on=False,
                    cmap=cmap, vmin=zmin, vmax=zmax, zorder=20)
    # highlighted examples with black edges (not restricted to the limits):
    for pairs, marker in highlights:
        for cell, run in pairs:
            emask = (data['cell'] == cell) & (data[runcol] == run)
            ax.scatter(xdata[emask], ydata[emask], c=data[emask, zcol], s=6,
                       marker=marker, linewidth=0.3, edgecolors='black',
                       clip_on=False, cmap=cmap, vmin=zmin, vmax=zmax,
                       zorder=20)
    # color bar:
    fig = ax.get_figure()
    cax = ax.inset_axes([1.3, 0, 0.04, 1])
    cax.set_spines_outward('lrbt', 0)
    cb = fig.colorbar(sc, cax=cax)
    cb.outline.set_color('none')
    cb.outline.set_linewidth(0)
    # pdf of the x-data above the scatter plot. gaussian_kde multiplies a
    # scalar bw_method by the data's standard deviation, so the kernel
    # width is 2% of the upper x-limit:
    kde = gaussian_kde(xdata, 0.02*xmax/np.std(xdata, ddof=1))
    xx = np.linspace(0, ax.get_xlim()[1], 400)
    pdf = kde(xx)
    xax = ax.inset_axes([0, 1.05, 1, 0.2])
    xax.show_spines('')
    xax.fill_between(xx, pdf, facecolor=color, edgecolor='none')
    xax.set_ylim(0, xpdfmax)
    # pdf of the y-data to the right of the scatter plot:
    kde = gaussian_kde(ydata, 0.02*ymax/np.std(ydata, ddof=1))
    yy = np.linspace(0, ax.get_ylim()[1], 400)
    pdf = kde(yy)
    yax = ax.inset_axes([1.05, 0, 0.2, 1])
    yax.show_spines('')
    yax.fill_betweenx(yy, pdf, facecolor=color, edgecolor='none')
    yax.set_xlim(left=0)
    # percentages of data points above and below the threshold.
    # Use the row count; TableData's len() may count columns instead:
    nrows = data.rows()
    if 'cvbase' in xcol:
        ax.text(xmax, 0.4*ymax,
                f'{100*np.sum(ydata > si_thresh)/nrows:.0f}\\%',
                ha='right', va='bottom', fontsize='small')
        ax.text(xmax, 0.3,
                f'{100*np.sum(ydata < si_thresh)/nrows:.0f}\\%',
                ha='right', va='center', fontsize='small')
    # statistics:
    r, p = pearsonr(xdata, ydata)
    ax.text(1, 0.9, f'$R={r:.2f}$', ha='right', transform=ax.transAxes,
            fontsize='small')
    ax.text(1, 0.77, f'{significance_str(p)}', ha='right',
            transform=ax.transAxes, fontsize='small')
    if 'cvbase' in xcol:
        ax.text(1, 0.64, f'$n={nrows}$', ha='right', transform=ax.transAxes,
                fontsize='small')
    print(f' correlation {xcol:<8s} - {ycol}: r={r:5.2f}, p={p:.2g}')
    return cax


def _range_str(x, fmt, unit=''):
    """Format minimum, maximum, median, mean, and std of `x`.

    Each value is formatted with the format spec `fmt` followed by `unit`.
    """
    values = (np.min(x), np.max(x), np.median(x), np.mean(x), np.std(x))
    vmin, vmax, vmedian, vmean, vstd = (f'{v:{fmt}}{unit}' for v in values)
    return f'{vmin} - {vmax}, median={vmedian}, mean={vmean}, std={vstd}'


def si_stats(title, data, sicol, si_thresh, nsegscol):
    """Print summary statistics of a data table and its SI values.

    Parameters
    ----------
    title: str
        Printed as a header.
    data: TableData
        Table with the columns `cell`, `sicol`, `fcutoff`, `contrast`,
        `nsegscol`, `nsegs`, `trials`, and `duration`. It is not modified.
    sicol: str
        Column with the SI values.
    si_thresh: float
        SI values above this threshold count as high.
    nsegscol: str
        Column with the number of segments used.
    """
    print(title)
    sidata = data[sicol]
    cells = np.unique(data['cell'])
    ncells = len(cells)
    nrecs = data.rows()
    print(f' cells: {ncells}')
    print(f' recordings: {nrecs}')
    print(f' SI threshold: {si_thresh:.1f}')
    high = sidata > si_thresh
    hcells = np.unique(data[high, 'cell'])
    nhigh = np.sum(high)
    print(f' high SI cells: n={len(hcells):3d}, {100*len(hcells)/ncells:4.1f}%')
    print(f' high SI recordings: n={nhigh:3d}, '
          f'{100*nhigh/nrecs:4.1f}%')
    recs_per_cell = [np.sum(data['cell'] == cell) for cell in cells]
    print(f' number of recordings per cell: {_range_str(recs_per_cell, "4.0f")}')
    fcutoff = data['fcutoff']
    fcutoffs = np.unique(fcutoff)
    print(' cutoff frequencies:',
          ' '.join([f'{f:3.0f}Hz' for f in fcutoffs]))
    print(' cutoff frequencies:',
          ' '.join([f'{np.sum(fcutoff == f):3d}' for f in fcutoffs]))
    print(f' cutoff frequencies: {_range_str(fcutoff, ".0f", "Hz")}')
    print(' contrasts:')
    contrasts = data['contrast']
    for c in np.unique(contrasts):
        sel = contrasts == c
        nc = np.sum(sel)
        nsi = np.sum(sidata[sel] > si_thresh)
        print(f' {100*c:3.2g}% n={nc:3d} ({100*nc/len(contrasts):4.1f}%):'
              f' n={nsi:3d} ({100*nsi/nc:5.1f}%) have SI > {si_thresh:.1f}')
    # scale to percent out of place, so that the table column is not
    # modified (later code matches examples by their contrast):
    print(f' {_range_str(100*contrasts, ".2g", "%")}')
    print(f' number of segments: {_range_str(data[nsegscol], "4.0f")}')
    print(f' available segments: {_range_str(data["nsegs"], "4.0f")}')
    trials = data['trials']
    print(f' trials: {_range_str(trials, ".0f")}')
    duration = data['duration']
    print(f' duration: {_range_str(duration, ".1f", "s")}')
    print(f' total duration: {_range_str(duration*trials, ".1f", "s")}')
    cols = ['cvbase', 'respmod2', 'ratebase', 'vsbase', 'serialcorr1',
            'burstfrac', 'ratestim', 'cvstim']
    for i, xcol in enumerate(cols):
        for ycol in cols[i + 1:]:
            if xcol not in data or ycol not in data:
                continue
            r, p = pearsonr(data[xcol], data[ycol])
            print(f' correlation {xcol:<11s} - {ycol:<11s}: r={r:5.2f}, p={p:.5f}')


def _setup_si_axes(ax, xlabel, xunit, xlim, xdelta, ymax):
    """Label and scale the axes of an SI scatter plot.

    `xunit` is passed as second argument to `ax.set_xlabel()` unless None.
    The y-axis shows SI from zero to `ymax` with ticks every 2.
    """
    if xunit is None:
        ax.set_xlabel(xlabel)
    else:
        ax.set_xlabel(xlabel, xunit)
    ax.set_xlim(*xlim)
    ax.set_xticks_delta(xdelta)
    ax.set_ylabel('SI($r$)')
    ax.set_ylim(0, ymax)
    ax.set_yticks_delta(2)


def _punit_highlights(data):
    """Examples to highlight: recordings have a `stimindex` column, models not."""
    return punit_examples if 'stimindex' in data else model_examples


def plot_cvbase_si_punit(ax, data, ycol, si_thresh, color):
    """Plot P-unit SI (`ycol`) versus `cvbase`, colored by `respmod2`."""
    _setup_si_axes(ax, 'CV$_{\\rm base}$', None, (0, 1.5), 0.5, 6.5)
    cax = plot_corr(ax, data, 'cvbase', ycol, 'respmod2', 0, 250, 3,
                    'coolwarm', color, si_thresh, *_punit_highlights(data))
    cax.set_ylabel('Response mod.', 'Hz')


def plot_cvstim_si_punit(ax, data, ycol, si_thresh, color):
    """Plot P-unit SI (`ycol`) versus `cvstim`, colored by `cvbase`."""
    _setup_si_axes(ax, 'CV$_{\\rm stim}$', None, (0, 1.6), 0.5, 6.5)
    # alternative color codes used before: 'respmod2' (0-250 Hz),
    # 'ratebase' (50-450 Hz), 'serialcorr1' (-0.6-0).
    cax = plot_corr(ax, data, 'cvstim', ycol, 'cvbase', 0, 1.5, 2,
                    'coolwarm', color, si_thresh, *_punit_highlights(data))
    cax.set_ylabel('CV$_{\\rm base}$')


def plot_rmod_si_punit(ax, data, ycol, si_thresh, color):
    """Plot P-unit SI (`ycol`) versus `respmod2`, colored by `cvbase`."""
    _setup_si_axes(ax, 'Response modulation', 'Hz', (0, 250), 100, 6.5)
    cax = plot_corr(ax, data, 'respmod2', ycol, 'cvbase', 0, 1.5, 0.016,
                    'coolwarm', color, si_thresh, *_punit_highlights(data))
    cax.set_ylabel('CV$_{\\rm base}$')


def plot_rate_si_punit(ax, data, ycol, si_thresh, color):
    """Plot P-unit SI (`ycol`) versus `ratebase`, colored by `cvbase`."""
    _setup_si_axes(ax, 'Baseline rate $r$', 'Hz', (0, 700), 200, 6.5)
    cax = plot_corr(ax, data, 'ratebase', ycol, 'cvbase', 0, 1.5, 0.016,
                    'coolwarm', color, si_thresh, *_punit_highlights(data))
    cax.set_ylabel('CV$_{\\rm base}$')


def plot_cvbase_si_ampul(ax, data, ycol, si_thresh, color):
    """Plot ampullary SI (`ycol`) versus `cvbase`, colored by `respmod2`."""
    _setup_si_axes(ax, 'CV$_{\\rm base}$', None, (0, 0.2), 0.1, 10)
    cax = plot_corr(ax, data, 'cvbase', ycol, 'respmod2', 0, 80, 20,
                    'coolwarm', color, si_thresh, *ampul_examples)
    cax.set_ylabel('Response mod.', 'Hz')


def plot_cvstim_si_ampul(ax, data, ycol, si_thresh, color):
    """Plot ampullary SI (`ycol`) versus `cvstim`, colored by `cvbase`."""
    _setup_si_axes(ax, 'CV$_{\\rm stim}$', None, (0, 0.85), 0.2, 10)
    # alternative color codes used before: 'respmod2' (0-80 Hz),
    # 'ratebase' (90-180 Hz).
    cax = plot_corr(ax, data, 'cvstim', ycol, 'cvbase', 0, 0.2, 6,
                    'coolwarm', color, si_thresh, *ampul_examples)
    cax.set_ylabel('CV$_{\\rm base}$')
    cax.set_yticks_delta(0.1)


def plot_rmod_si_ampul(ax, data, ycol, si_thresh, color):
    """Plot ampullary SI (`ycol`) versus `respmod2`, colored by `cvbase`."""
    _setup_si_axes(ax, 'Response modulation', 'Hz', (0, 80), 20, 10)
    cax = plot_corr(ax, data, 'respmod2', ycol, 'cvbase', 0, 0.2, 0.06,
                    'coolwarm', color, si_thresh, *ampul_examples)
    cax.set_ylabel('CV$_{\\rm base}$')
    cax.set_yticks_delta(0.1)


def plot_rate_si_ampul(ax, data, ycol, si_thresh, color):
    """Plot ampullary SI (`ycol`) versus `ratebase`, colored by `cvbase`."""
    _setup_si_axes(ax, 'Baseline rate $r$', 'Hz', (50, 200), 50, 10)
    cax = plot_corr(ax, data, 'ratebase', ycol, 'cvbase', 0, 0.2, 0.06,
                    'coolwarm', color, si_thresh, *ampul_examples)
    cax.set_ylabel('CV$_{\\rm base}$')
    cax.set_yticks_delta(0.1)


def _compare_model_data(title, label, model, data, fmt, unit='', alpha=0.05):
    """Print a Mann-Whitney U test of `model` versus `data` and their ranges.

    The header states whether the samples differ based on p < `alpha`
    instead of a fixed claim. `fmt` and `unit` format min, max, and median.
    """
    u, p = mannwhitneyu(model, data)
    verdict = 'differs' if p < alpha else 'does not differ'
    print(f'{title} {verdict} between P-unit models and data:')
    print(f' U={u:g}, p={p:.2g}')
    for name, x in (('model', model), ('data', data)):
        print(f' {label} {name}: min={np.min(x):{fmt}}{unit} '
              f'max={np.max(x):{fmt}}{unit} median={np.median(x):{fmt}}{unit}')
    print()


if __name__ == '__main__':
    punit_model = TableData(data_path /
                            'Apteronotus_leptorhynchus-Punit-models.csv',
                            sep=';')
    # keep only simulations with a contrast above zero:
    punit_model = punit_model[punit_model['contrast'] > 1e-6, :]
    punit_data = TableData(data_path /
                           'Apteronotus_leptorhynchus-Punit-data.csv',
                           sep=';')
    ampul_data = TableData(data_path /
                           'Apteronotus_leptorhynchus-Ampullary-data.csv',
                           sep=';')
    # suffix selecting the SI and segment columns of the recordings
    # ('' selects the plain 'sinorm' column instead):
    si = '_nmax'
    si_thresh = 1.8

    cells = np.unique(list(punit_model['cell']) + list(punit_data['cell']) +
                      list(ampul_data['cell']))
    # the first three '-'-separated fields of a cell name identify the specimen:
    specimen = np.unique(['-'.join(c.split('-')[:3]) for c in cells])
    print(f'Total of {len(cells)} cells recorded in {len(specimen)} specimens')
    print()

    _compare_model_data('CV', 'CV', punit_model['cvbase'],
                        punit_data['cvbase'], '4.2f')
    _compare_model_data('Response modulation', 'response modulation',
                        punit_model['respmod2'], punit_data['respmod2'],
                        '3.0f', 'Hz')
    _compare_model_data('SI', 'SI', punit_model['dsinorm100'],
                        punit_data['sinorm' + si], '4.1f')

    s = plot_style()
    # NOTE(review): five height ratios are given for three rows; stock
    # matplotlib's GridSpec requires one ratio per row - confirm that the
    # plotstyle/plottools subplots() accepts this.
    fig, axs = plt.subplots(3, 3, cmsize=(s.plot_width, 0.75*s.plot_width),
                            height_ratios=[1, 0, 1, 0.3, 1])
    fig.subplots_adjust(leftm=6.5, rightm=13.5, topm=4.5, bottomm=4,
                        wspace=1.1, hspace=0.6)

    # The plot_cvstim_si_*() functions can replace plot_rate_si_*()
    # in the third column.

    # row 0: P-unit models:
    si_stats('P-unit model:', punit_model, 'dsinorm100', si_thresh,
             'nsegs100')
    axs[0, 0].text(0, 1.35, 'P-unit models', transform=axs[0, 0].transAxes,
                   color=s.model_color1)
    plot_cvbase_si_punit(axs[0, 0], punit_model, 'dsinorm100', si_thresh,
                         s.model_color2)
    plot_rmod_si_punit(axs[0, 1], punit_model, 'dsinorm100', si_thresh,
                       s.model_color2)
    plot_rate_si_punit(axs[0, 2], punit_model, 'dsinorm100', si_thresh,
                       s.model_color2)
    print()

    # row 1: P-unit recordings:
    si_stats('P-unit data:', punit_data, 'sinorm' + si, si_thresh,
             'nsegs' + si)
    axs[1, 0].text(0, 1.35, 'P-unit data', transform=axs[1, 0].transAxes,
                   color=s.punit_color1)
    plot_cvbase_si_punit(axs[1, 0], punit_data, 'sinorm' + si, si_thresh,
                         s.punit_color2)
    plot_rmod_si_punit(axs[1, 1], punit_data, 'sinorm' + si, si_thresh,
                       s.punit_color2)
    plot_rate_si_punit(axs[1, 2], punit_data, 'sinorm' + si, si_thresh,
                       s.punit_color2)
    print()

    # row 2: ampullary recordings:
    si_stats('Ampullary data:', ampul_data, 'sinorm' + si, si_thresh,
             'nsegs' + si)
    axs[2, 0].text(0, 1.35, 'Ampullary data', transform=axs[2, 0].transAxes,
                   color=s.ampul_color1)
    plot_cvbase_si_ampul(axs[2, 0], ampul_data, 'sinorm' + si, si_thresh,
                         s.ampul_color2)
    plot_rmod_si_ampul(axs[2, 1], ampul_data, 'sinorm' + si, si_thresh,
                       s.ampul_color2)
    plot_rate_si_ampul(axs[2, 2], ampul_data, 'sinorm' + si, si_thresh,
                       s.ampul_color2)
    print()

    fig.common_xticks(axs[:2, 0])
    fig.common_xticks(axs[:2, 1])
    fig.common_xticks(axs[:2, 2])
    fig.common_yticks(axs[0, :])
    fig.common_yticks(axs[1, :])
    fig.common_yticks(axs[2, :])
    fig.tag(axs, xoffs=-3.5, yoffs=2)
    fig.savefig()