import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import pearsonr, gaussian_kde
from scipy.stats import mannwhitneyu
from thunderlab.tabledata import TableData
from plotstyle import plot_style, lighter, significance_str


def _significance_marker(p):
    """Return a significance marker for the p-value `p`.

    '**' for p < 0.01, '*' for p < 0.05, and 'n.s.' otherwise.
    """
    if p < 0.01:
        return '**'
    if p < 0.05:
        return '*'
    return 'n.s.'


def _marginal_pdf(values, vmax):
    """Evaluate a Gaussian kernel density of `values` on 400 points in [0, vmax].

    A scalar `bw_method` of gaussian_kde is multiplied by the standard
    deviation of the data, so dividing by that standard deviation makes
    the kernel width 2% of `vmax`.

    Returns
    -------
    xx: ndarray
        The evaluation points.
    pdf: ndarray
        The density at `xx`.
    """
    kde = gaussian_kde(values, 0.02*vmax/np.std(values, ddof=1))
    xx = np.linspace(0, vmax, 400)
    return xx, kde(xx)


def plot_corr(ax, data, xcol, ycol, zcol, zmin, zmax, xpdfmax, cmap,
              color, nli_thresh):
    """Scatter `ycol` against `xcol` colored by `zcol`, with marginal densities.

    The x- and y-limits of `ax` must be set before calling. Only points
    inside these limits are scattered, and the limits set the range of
    the marginal densities. Pearson's correlation is computed on all rows
    of `data`, annotated in the axes, and printed.

    Parameters
    ----------
    ax: matplotlib axes
        Axes with x- and y-limits already set.
    data: TableData
        Table with the columns `xcol`, `ycol`, and `zcol`.
    xcol, ycol, zcol: str
        Columns plotted on the x-axis, on the y-axis, and as scatter color.
    zmin, zmax: float
        Range of the color map.
    xpdfmax: float
        Upper limit of the marginal density of `xcol` drawn on top.
    cmap: str or Colormap
        Color map for the scatter points.
    color: matplotlib color
        Fill color of both marginal densities.
    nli_thresh: float
        Value of `ycol` marked by a dotted horizontal line. If `xcol`
        contains 'cvbase', the percentages of rows above and below it
        and the number of rows are annotated as well.

    Returns
    -------
    cax: matplotlib axes
        Axes of the color bar.
    """
    ax.axhline(nli_thresh, color='k', ls=':', lw=0.5)
    # Disabled: connect the recordings of each cell in order of contrast.
    # for c in np.unique(data('cell')):
    #     xdata = data[data('cell') == c, xcol]
    #     ydata = data[data('cell') == c, ycol]
    #     contrasts = data[data('cell') == c, 'contrast']
    #     idx = np.argsort(contrasts)
    #     if len(idx) > 1:
    #         ax.plot(xdata[idx], ydata[idx], '-k', alpha=0.2, zorder=10)
    xmax = ax.get_xlim()[1]
    ymax = ax.get_ylim()[1]
    # only scatter points within the axes limits (clip_on=False below):
    mask = (data(xcol) < xmax) & (data(ycol) < ymax)
    sc = ax.scatter(data[mask, xcol], data[mask, ycol], c=data[mask, zcol],
                    s=3, clip_on=False, cmap=cmap, vmin=zmin, vmax=zmax,
                    zorder=20)
    # color bar:
    fig = ax.get_figure()
    cax = ax.inset_axes([1.3, 0, 0.04, 1])
    cax.set_spines_outward('lrbt', 0)
    cb = fig.colorbar(sc, cax=cax)
    cb.outline.set_color('none')
    cb.outline.set_linewidth(0)
    # pdf x-axis:
    xx, pdf = _marginal_pdf(data(xcol), xmax)
    xax = ax.inset_axes([0, 1.05, 1, 0.2])
    xax.show_spines('')
    xax.fill_between(xx, pdf, facecolor=color, edgecolor='none')
    #xax.plot(xx, np.zeros(len(xx)), clip_on=False, color=color, lw=0.5)
    xax.set_ylim(0, xpdfmax)
    # pdf y-axis:
    yy, pdf = _marginal_pdf(data(ycol), ymax)
    yax = ax.inset_axes([1.05, 0, 0.2, 1])
    yax.show_spines('')
    yax.fill_betweenx(yy, pdf, facecolor=color, edgecolor='none')
    #yax.plot(np.zeros(len(yy)), yy, clip_on=False, color=color, lw=0.5)
    yax.set_xlim(left=0)
    # threshold:
    if 'cvbase' in xcol:
        nrows = data.rows()
        nabove = np.sum(data(ycol) > nli_thresh)
        nbelow = np.sum(data(ycol) < nli_thresh)
        ax.text(xmax, 0.4*ymax, f'{100*nabove/nrows:.0f}\\%',
                ha='right', va='bottom', fontsize='small')
        ax.text(xmax, 0.3, f'{100*nbelow/nrows:.0f}\\%',
                ha='right', va='center', fontsize='small')
    # statistics:
    r, p = pearsonr(data(xcol), data(ycol))
    # the marker reflects the actual p-value instead of a hard-coded '**':
    ax.text(1, 0.9, f'$R={r:.2f}$ {_significance_marker(p)}', ha='right',
            transform=ax.transAxes, fontsize='small')
    #ax.text(1, 0.77, f'{significance_str(p)}', ha='right',
    #        transform=ax.transAxes, fontsize='small')
    if 'cvbase' in xcol:
        ax.text(1, 0.77, f'$n={data.rows()}$', ha='right',
                transform=ax.transAxes, fontsize='small')
    print(f' correlation {xcol:<8s} - {ycol}: r={r:5.2f}, p={p:.2g}')
    return cax


def nli_stats(title, data, column, nli_thresh):
    """Print counts of cells and recordings, the share above an NLI threshold,
    and summary statistics of the 'nsegs' column.

    Parameters
    ----------
    title: str
        Heading printed first.
    data: TableData
        Table with the columns 'cell', 'nsegs', and `column`.
    column: str
        Column holding the nonlinearity index.
    nli_thresh: float
        Values of `column` above this count as high nonlinearity.
    """
    print(title)
    print(f' nli threshold: {nli_thresh:.1f}')
    nrecs = data.rows()
    ncells = len(np.unique(data('cell')))
    print(f' cells: {ncells}')
    print(f' recordings: {nrecs}')
    high = data(column) > nli_thresh
    hcells = np.unique(data[high, 'cell'])
    print(f' high nli cells: n={len(hcells):3d}, {100*len(hcells)/ncells:4.1f}%')
    nhigh = np.sum(high)
    print(f' high nli recordings: n={nhigh:3d}, '
          f'{100*nhigh/nrecs:4.1f}%')
    nsegs = data('nsegs')
    print(f' number of segments: {np.min(nsegs):4.0f} - {np.max(nsegs):4.0f}, median={np.median(nsegs):4.0f}, mean={np.mean(nsegs):4.0f}, std={np.std(nsegs):4.0f}')


def plot_cvbase_nli_punit(ax, data, ycol, nli_thresh, color):
    """Plot the P-unit NLI in `ycol` against baseline CV, colored by response modulation."""
    ax.set_xlabel('CV$_{\\rm base}$')
    ax.set_ylabel('SI($r$)')
    ax.set_xlim(0, 1.5)
    ax.set_ylim(0, 6)
    ax.set_yticks_delta(2)
    cax = plot_corr(ax, data, 'cvbase', ycol, 'respmod2', 0, 250, 3,
                    'coolwarm', color, nli_thresh)
    cax.set_ylabel('Response mod.', 'Hz')


def plot_cvstim_nli_punit(ax, data, ycol, nli_thresh, color):
    """Plot the P-unit NLI in `ycol` against CV during stimulation, colored by baseline rate."""
    ax.set_xlabel('CV$_{\\rm stim}$')
    ax.set_ylabel('SI($r$)')
    ax.set_xlim(0, 1.6)
    ax.set_ylim(0, 6)
    ax.set_xticks_delta(0.5)
    ax.set_yticks_delta(2)
    # alternative color codings:
    #cax = plot_corr(ax, data, 'cvstim', ycol, 'respmod2', 0, 250, 2,
    #                'coolwarm', color, nli_thresh)
    #cax.set_ylabel('Response mod.', 'Hz')
    #cax = plot_corr(ax, data, 'cvstim', ycol, 'cvbase', 0, 1.5, 2,
    #                'coolwarm', color, nli_thresh)
    #cax.set_ylabel('CV$_{\\rm base}$')
    cax = plot_corr(ax, data, 'cvstim', ycol, 'ratebase', 50, 450, 2,
                    'coolwarm', color, nli_thresh)
    cax.set_ylabel('$r$', 'Hz')


def plot_mod_nli_punit(ax, data, ycol, nli_thresh, color):
    """Plot the P-unit NLI in `ycol` against response modulation, colored by baseline CV."""
    ax.set_xlabel('Response modulation', 'Hz')
    ax.set_ylabel('SI($r$)')
    ax.set_xlim(0, 250)
    ax.set_ylim(0, 6)
    ax.set_yticks_delta(2)
    cax = plot_corr(ax, data, 'respmod2', ycol, 'cvbase', 0, 1.5, 0.016,
                    'coolwarm', color, nli_thresh)
    cax.set_ylabel('CV$_{\\rm base}$')


def plot_cvbase_nli_ampul(ax, data, ycol, nli_thresh, color):
    """Plot the ampullary NLI in `ycol` against baseline CV, colored by response modulation."""
    ax.set_xlabel('CV$_{\\rm base}$')
    ax.set_ylabel('SI($r$)')
    ax.set_xlim(0, 0.2)
    ax.set_ylim(0, 15)
    ax.set_xticks_delta(0.1)
    ax.set_yticks_delta(5)
    cax = plot_corr(ax, data, 'cvbase', ycol, 'respmod2', 0, 80, 20,
                    'coolwarm', color, nli_thresh)
    cax.set_ylabel('Response mod.', 'Hz')


def plot_cvstim_nli_ampul(ax, data, ycol, nli_thresh, color):
    """Plot the ampullary NLI in `ycol` against CV during stimulation, colored by baseline rate."""
    ax.set_xlabel('CV$_{\\rm stim}$')
    ax.set_ylabel('SI($r$)')
    ax.set_xlim(0, 0.85)
    ax.set_ylim(0, 15)
    ax.set_xticks_delta(0.2)
    ax.set_yticks_delta(5)
    # alternative color codings:
    #cax = plot_corr(ax, data, 'cvstim', ycol, 'respmod2', 0, 80, 6,
    #                'coolwarm', color, nli_thresh)
    #cax.set_ylabel('Response mod.', 'Hz')
    #cax = plot_corr(ax, data, 'cvstim', ycol, 'cvbase', 0, 0.2, 6,
    #                'coolwarm', color, nli_thresh)
    #cax.set_ylabel('CV$_{\\rm base}$')
    #cax.set_yticks_delta(0.1)
    cax = plot_corr(ax, data, 'cvstim', ycol, 'ratebase', 90, 180, 6,
                    'coolwarm', color, nli_thresh)
    cax.set_ylabel('$r$', 'Hz')
    cax.set_yticks_delta(30)


def plot_mod_nli_ampul(ax, data, ycol, nli_thresh, color):
    """Plot the ampullary NLI in `ycol` against response modulation, colored by baseline CV."""
    ax.set_xlabel('Response modulation', 'Hz')
    ax.set_ylabel('SI($r$)')
    ax.set_xlim(0, 80)
    ax.set_ylim(0, 15)
    ax.set_xticks_delta(20)
    ax.set_yticks_delta(5)
    cax = plot_corr(ax, data, 'respmod2', ycol, 'cvbase', 0, 0.2, 0.06,
                    'coolwarm', color, nli_thresh)
    cax.set_ylabel('CV$_{\\rm base}$')
    cax.set_yticks_delta(0.1)


def compare_model_data(title, model_values, data_values, fmt='.2f'):
    """Print a Mann-Whitney U test and the medians of model versus data values.

    Parameters
    ----------
    title: str
        Heading printed first.
    model_values, data_values: array-like
        The two samples to compare.
    fmt: str
        Format spec used for printing the medians.
    """
    u, p = mannwhitneyu(model_values, data_values)
    print(title)
    print(f' U={u:g}, p={p:g}')
    print(f' median model: {np.median(model_values):{fmt}}')
    print(f' median data: {np.median(data_values):{fmt}}')
    print()


if __name__ == '__main__':
    punit_model = TableData('summarychi2noise.csv', sep=';')
    punit_model = punit_model[punit_model('contrast') > 1e-6, :]
    punit_data = TableData('Apteronotus_leptorhynchus-Punit-data.csv', sep=';')
    ampul_data = TableData('Apteronotus_leptorhynchus-Ampullary-data.csv')
    nli_thresh = 1.8

    compare_model_data('CV differs between P-unit models and data:',
                       punit_model('cvbase'), punit_data('cvbase'))
    compare_model_data('Response modulation differs between P-unit models and data:',
                       punit_model('respmod2'), punit_data('respmod2'))
    compare_model_data('NLI does not differ between P-unit models and data:',
                       punit_model('dnli100'), punit_data('nli'), '.1f')

    s = plot_style()
    # Three rows: P-unit models, P-unit data, ampullary data.
    # height_ratios needs exactly one entry per row (the former 5-entry
    # list included spacer rows that no longer exist).
    fig, axs = plt.subplots(3, 3, cmsize=(s.plot_width, 0.75*s.plot_width),
                            height_ratios=[1, 1, 1])
    fig.subplots_adjust(leftm=6.5, rightm=13.5, topm=4.5, bottomm=4,
                        wspace=1.1, hspace=0.6)

    nli_stats('P-unit model:', punit_model, 'dnli100', nli_thresh)
    axs[0, 0].text(0, 1.35, 'P-unit models', transform=axs[0, 0].transAxes,
                   color=s.model_color1)
    plot_cvbase_nli_punit(axs[0, 0], punit_model, 'dnli100', nli_thresh,
                          s.model_color2)
    plot_mod_nli_punit(axs[0, 1], punit_model, 'dnli100', nli_thresh,
                       s.model_color2)
    plot_cvstim_nli_punit(axs[0, 2], punit_model, 'dnli100', nli_thresh,
                          s.model_color2)
    print()

    nli_stats('P-unit data:', punit_data, 'nli', nli_thresh)
    axs[1, 0].text(0, 1.35, 'P-unit data', transform=axs[1, 0].transAxes,
                   color=s.punit_color1)
    plot_cvbase_nli_punit(axs[1, 0], punit_data, 'nli', nli_thresh,
                          s.punit_color2)
    plot_mod_nli_punit(axs[1, 1], punit_data, 'nli', nli_thresh,
                       s.punit_color2)
    plot_cvstim_nli_punit(axs[1, 2], punit_data, 'nli', nli_thresh,
                          s.punit_color2)
    print()

    nli_stats('Ampullary data:', ampul_data, 'nli', nli_thresh)
    axs[2, 0].text(0, 1.35, 'Ampullary data', transform=axs[2, 0].transAxes,
                   color=s.ampul_color1)
    plot_cvbase_nli_ampul(axs[2, 0], ampul_data, 'nli', nli_thresh,
                          s.ampul_color2)
    plot_mod_nli_ampul(axs[2, 1], ampul_data, 'nli', nli_thresh,
                       s.ampul_color2)
    plot_cvstim_nli_ampul(axs[2, 2], ampul_data, 'nli', nli_thresh,
                          s.ampul_color2)
    print()

    # P-unit models and data share x-axis ranges, ampullary data do not:
    fig.common_xticks(axs[:2, 0])
    fig.common_xticks(axs[:2, 1])
    fig.common_xticks(axs[:2, 2])
    fig.common_yticks(axs[0, :])
    fig.common_yticks(axs[1, :])
    fig.common_yticks(axs[2, :])
    fig.tag(xoffs=-3.5, yoffs=2)
    fig.savefig()