diff --git a/ampullaryexamplecell.py b/ampullaryexamplecell.py new file mode 100644 index 0000000..1c1a613 --- /dev/null +++ b/ampullaryexamplecell.py @@ -0,0 +1,225 @@ +import sys +sys.path.insert(0, 'ephys') # for analysing data +import numpy as np +import matplotlib.pyplot as plt +from pathlib import Path +from spectral import diag_projection, peakedness +from plotstyle import plot_style +from punitexamplecell import load_baseline, load_noise, load_spectra +from punitexamplecell import plot_colorbar + + +cell_name = '2012-05-15-ac' +run1 = 3 # 4 +run2 = 1 + +base_path = Path('ephys') +data_path = base_path / 'data' +results_path = base_path / 'results' + + +def plot_isih(ax, s, rate, cv, isis, pdf): + ax.show_spines('b') + ax.fill_between(1000*isis, pdf, facecolor=s.cell_color1) + ax.set_xlim(0, 12) + ax.set_xticks_delta(4) + ax.set_xlabel('ISI', 'ms') + ax.text(0, 1.08, 'Ampullary:', transform=ax.transAxes, color=s.cell_color1, + fontsize='large') + ax.text(0.95, 1.08, f'$r={rate:.0f}$Hz, CV$_{{\\rm base}}$={cv:.2f}', + transform=ax.transAxes) + + +def plot_response_spectrum(ax, s, eodf, rate, freqs, prr): + rate_i = np.argmax(prr[freqs < 0.7*eodf]) + eod_i = np.argmax(prr[freqs > 500]) + np.argmax(freqs > 500) + power_db = 10*np.log10(prr/np.max(prr)) + ax.show_spines('b') + mask = (freqs > 30) & (freqs < 890) + ax.plot(freqs[mask], power_db[mask], **s.lsC1) + ax.plot(freqs[eod_i], power_db[eod_i], **s.psA1c) + ax.plot(freqs[rate_i], power_db[rate_i], **s.psA2c) + ax.set_xlim(0, 900) + ax.set_ylim(-25, 5) + ax.set_xticks_delta(300) + ax.set_xlabel('$f$', 'Hz') + ax.text(freqs[eod_i], power_db[eod_i] + 2, '$f_{\\rm EOD}$', + ha='center') + ax.text(freqs[rate_i], power_db[rate_i] + 2, '$r$', + ha='center') + ax.yscalebar(1.05, 0, 10, 'dB', ha='right') + + +def plot_response(ax, s, eodf, time1, stimulus1, contrast1, spikes1, contrast2, spikes2): + t0 = 0.3 + t1 = 0.4 + maxtrials = 8 + trials = np.arange(maxtrials) + ax.show_spines('') + ax.eventplot(spikes1[2:2+maxtrials], lineoffsets=trials - maxtrials + 1, + linelength=0.8, linewidths=1, color=s.cell_color1) + ax.eventplot(spikes2[2:2+maxtrials], lineoffsets=trials - 2*maxtrials, + linelength=0.8, linewidths=1, color=s.cell_color2) + am = contrast1*stimulus1 + eod = np.sin(2*np.pi*eodf*time1) + am + ax.plot(time1, 4*eod + 7, **s.lsEOD) + ax.plot(time1, 4*(1 + am) + 7, **s.lsAM) + ax.set_xlim(t0, t1) + ax.set_ylim(-2*maxtrials - 0.5, 14) + ax.xscalebar(1, -0.05, 0.01, None, '10\\,ms', ha='right') + ax.text(t1 + 0.003, -0.5*maxtrials, f'${100*contrast1:.0f}$\\,\\%', + va='center', color=s.cell_color1) + ax.text(t1 + 0.003, -1.55*maxtrials, f'${100*contrast2:.0f}$\\,\\%', + va='center', color=s.cell_color2) + + +def plot_gain(ax, s, fbase, contrast1, freqs1, gain1, + contrast2, freqs2, gain2, fcutoff): + ax.axvline(fbase, **s.lsGrid) + ax.plot(freqs2, gain2, label=f'{100*contrast2:.0f}', **s.lsC2) + ax.plot(freqs1, gain1, label=f'{100*contrast1:.0f}', **s.lsC1) + ax.set_xlim(0, fcutoff) + ax.set_ylim(0, 1500) + ax.set_xticks_delta(50) + ax.set_xlabel('$f$', 'Hz') + ax.set_ylabel(r'$|\chi_1|$', 'Hz') + ax.text(fbase, 1550, '$r$', ha='center') + + +def plot_chi2(ax, s, contrast, freqs, chi2, fcutoff, vmax): + ax.set_aspect('equal') + if vmax is None: + vmax = np.quantile(1e-3*chi2, 0.99) + pc = ax.pcolormesh(freqs, freqs, 1e-3*chi2, vmin=0, vmax=vmax, + cmap='viridis', rasterized=True, zorder=10) + ax.set_xlim(0, fcutoff) + ax.set_ylim(0, fcutoff) + df = 100 if fcutoff == 300 else 50 + ax.set_xticks_delta(df) + ax.set_yticks_delta(df) + ax.set_xlabel('$f_1$', 'Hz') + ax.set_ylabel('$f_2$', 'Hz') + return pc + + +def plot_diagonals(ax, s, fbase, contrast1, freqs1, chi21, contrast2, freqs2, chi22, fcutoff): + diags = [] + nlis = [] + nlips = [] + nlifs = [] + for contrast, freqs, chi2 in [[contrast1, freqs1, chi21], [contrast2, freqs2, chi22]]: + dfreqs, diag = diag_projection(freqs, chi2, 2*fcutoff) + diags.append([dfreqs, diag]) + nli, nlif = peakedness(dfreqs, diag, fbase, median=False) + nlip = diag[np.argmin(np.abs(dfreqs - nlif))] + nlis.append(nli) + nlips.append(nlip) + nlifs.append(nlif) + print(f' SI at {100*contrast:.1f}% contrast: {nli:.2f}') + ax.axvline(fbase, **s.lsGrid) + ax.plot(diags[1][0], 1e-3*diags[1][1], **s.lsC2) + ax.plot(diags[0][0], 1e-3*diags[0][1], **s.lsC1) + ax.plot(nlifs[1], 1e-3*nlips[1], **s.psC2) + ax.plot(nlifs[0], 1e-3*nlips[0], **s.psC1) + ax.set_xlim(0, 2*fcutoff) + ax.set_ylim(0, 1.7) + ax.set_xticks_delta(100) + ax.set_yticks_delta(1) + ax.set_xlabel('$f_1 + f_2$', 'Hz') + #ax.set_ylabel(r'$|\chi_2|$', 'kHz') + ax.text(nlifs[1] - 25, 1e-3*nlips[1], f'{100*contrast2:.0f}\\%', + ha='right') + ax.text(nlifs[1] + 35, 1e-3*nlips[1], f'SI={nlis[1]:.1f}') + ax.text(nlifs[0] - 25, 1e-3*nlips[0], f'{100*contrast1:.0f}\\%', + ha='right') + ax.text(nlifs[0] + 35, 1e-3*nlips[0], f'SI={nlis[0]:.1f}') + ax.text(fbase, 1.75, '$r$', ha='center') + + +if __name__ == '__main__': + """ + from thunderlab.tabledata import TableData + data = TableData('Apteronotus_leptorhynchus-Ampullary-data.csv') + data = data[(data('fcutoff') > 140) & (data('fcutoff') < 160), :] + data = data[(data('nli') > 2) & (data('nli') < 2.5), :] + data = data[(data('respmod2') > 20) & (data('respmod2') < 100), :] + data = data[(data('cvbase') > 0.05) & (data('cvbase') < 0.2), :] + data = data[(data('ratebase') > 100) & (data('ratebase') < 180), :] + for k in range(data.rows()): + print(f'{data[k, "cell"]:<22s} s{data[k, "stimindex"]:02.0f}: {100*data[k, "contrast"]:3g}%, {data[k, "respmod2"]:3.0f}Hz, nli={data[k, "nli"]:5.2f}') + print() + #exit() + """ + + print('Example Ampullary cell:', cell_name) + eodf, rate, cv, isis, pdf, freqs, prr = load_baseline(results_path, cell_name) + print(f' baseline firing rate: {rate:.0f}Hz') + print(f' baseline firing CV : {cv:.2f}') + contrast1, time1, stimulus1, spikes1 = load_noise(data_path, cell_name, run1) + contrast2, time2, stimulus2, spikes2 = load_noise(data_path, cell_name, run2) + fcutoff1, contrast1, freqs1, gain1, chi21 = load_spectra(results_path, cell_name, run1) + fcutoff2, contrast2, freqs2, gain2, chi22 = load_spectra(results_path, cell_name, run2) + + s = plot_style() + s.cell_color1 = s.ampul_color1 + s.cell_color2 = s.ampul_color2 + s.lsC1 = s.lsA1 + s.lsC2 = s.lsA2 + s.psC1 = s.psA1 + s.psC2 = s.psA2 + fig, (ax1, ax2, ax3) = plt.subplots(3, 1, height_ratios=[3, 0, 3, 0.5, 3], + cmsize=(s.plot_width, 0.8*s.plot_width)) + fig.subplots_adjust(leftm=8, rightm=9, topm=2, bottomm=4, + wspace=0.4, hspace=0.5) + axi, axp, axr = ax1.subplots(1, 3, width_ratios=[2, 3, 0, 10]) + axg, axc1, axc2, axd = ax2.subplots(1, 4, wspace=0.4) + axg = axg.subplots(1, 1, width_ratios=[1, 0.1]) + axd = axd.subplots(1, 1, width_ratios=[0.2, 1]) + axs = ax3.subplots(1, 4, wspace=0.4) + + plot_isih(axi, s, rate, cv, isis, pdf) + plot_response_spectrum(axp, s, eodf, rate, freqs, prr) + plot_response(axr, s, eodf, time1, stimulus1, contrast1, spikes1, + contrast2, spikes2) + + plot_gain(axg, s, rate, contrast1, freqs1, gain1, + contrast2, freqs2, gain2, fcutoff1) + pc = plot_chi2(axc1, s, contrast2, freqs2, chi22, fcutoff2, 1.7) + axc1.plot([0, fcutoff2], [0, fcutoff2], zorder=20, **s.lsDiag) + axc1.set_title(f'$c$={100*contrast2:g}\\,\\%', + fontsize='medium', color=s.cell_color2) + pc = plot_chi2(axc2, s, contrast1, freqs1, chi21, fcutoff1, 1.7) + axc2.set_title(f'$c$={100*contrast1:g}\\,\\%', + fontsize='medium', color=s.cell_color1) + axc2.plot([0, fcutoff1], [0, fcutoff1], zorder=20, **s.lsDiag) + plot_colorbar(axc2, pc, 1) + plot_diagonals(axd, s, rate, contrast1, freqs1, chi21, + contrast2, freqs2, chi22, fcutoff1) + + fig.common_yticks(axc1, axc2) + fig.tag([axi, axp, axr], xoffs=-3, yoffs=-1) + fig.tag([axg, axc1, axc2, axd], xoffs=-3, yoffs=2) + + print('Additional example cells:') + example_cells = [ + ['2010-11-26-an', 0], + ['2011-10-25-ac', 0], + ['2011-02-18-ab', 1], + ['2014-01-16-aj', 5], + ] + for k, (cell, run) in enumerate(example_cells): + eodf, rate, cv, _, _, _, _ = load_baseline(results_path, cell) + fcutoff, contrast, freqs, gain, chi2 = load_spectra(results_path, cell, run) + dfreqs, diag = diag_projection(freqs, chi2, 2*fcutoff) + nli, nlif = peakedness(dfreqs, diag, rate, median=False) + print(f' {cell:<22s}: run={run:2d}, fbase={rate:3.0f}Hz, CV={cv:.2f}, SI={nli:3.1f}') + pc = plot_chi2(axs[k], s, contrast, freqs, chi2, fcutoff, 1.2) + axs[k].set_title(f'$r={rate:.0f}$Hz, CV$_{{\\rm base}}$={cv:.2f}', fontsize='medium') + axs[k].text(0.95, 0.9, f'SI($r$)={nli:.1f}', ha='right', zorder=50, + color='white', fontsize='medium', + transform=axs[k].transAxes) + plot_colorbar(axs[-1], pc, 0.4) + fig.common_yticks(axs) + fig.tag(axs, xoffs=-3, yoffs=2) + + fig.savefig() diff --git a/dataoverview.py b/dataoverview.py new file mode 100644 index 0000000..403b7f9 --- /dev/null +++ b/dataoverview.py @@ -0,0 +1,234 @@ +import numpy as np +import matplotlib.pyplot as plt +from scipy.stats import pearsonr, gaussian_kde +from scipy.stats import mannwhitneyu +from thunderlab.tabledata import TableData +from plotstyle import plot_style, lighter, significance_str + + +def plot_corr(ax, data, xcol, ycol, zcol, zmin, zmax, xpdfmax, cmap, color, + nli_thresh): + ax.axhline(nli_thresh, color='k', ls=':', lw=0.5) + """ + for c in np.unique(data('cell')): + xdata = data[data('cell') == c, xcol] + ydata = data[data('cell') == c, ycol] + contrasts = data[data('cell') == c, 'contrast'] + idx = np.argsort(contrasts) + if len(idx) > 1: + ax.plot(xdata[idx], ydata[idx], '-k', alpha=0.2, zorder=10) + """ + xmax = ax.get_xlim()[1] + ymax = ax.get_ylim()[1] + mask = (data(xcol) < xmax) & (data(ycol) < ymax) + sc = ax.scatter(data[mask, xcol], data[mask, ycol], c=data[mask, zcol], + s=3, clip_on=False, cmap=cmap, vmin=zmin, vmax=zmax, zorder=20) + # color bar: + fig = ax.get_figure() + cax = ax.inset_axes([1.3, 0, 0.04, 1]) + cax.set_spines_outward('lrbt', 0) + cb = fig.colorbar(sc, cax=cax) + cb.outline.set_color('none') + cb.outline.set_linewidth(0) + # pdf x-axis: + kde = gaussian_kde(data(xcol), 0.02*xmax/np.std(data(xcol), ddof=1)) + xx = np.linspace(0, ax.get_xlim()[1], 400) + pdf = kde(xx) + xax = ax.inset_axes([0, 1.05, 1, 0.2]) + xax.show_spines('') + xax.fill_between(xx, pdf, facecolor=color, edgecolor='none') + #xax.plot(xx, np.zeros(len(xx)), clip_on=False, color=color, lw=0.5) + xax.set_ylim(bottom=0) + xax.set_ylim(0, xpdfmax) + # pdf y-axis: + kde = gaussian_kde(data(ycol), 0.02*ymax/np.std(data(ycol), ddof=1)) + xx = np.linspace(0, ax.get_ylim()[1], 400) + pdf = kde(xx) + yax = ax.inset_axes([1.05, 0, 0.2, 1]) + yax.show_spines('') + yax.fill_betweenx(xx, pdf, facecolor=color, edgecolor='none') + #yax.plot(np.zeros(len(xx)), xx, clip_on=False, color=color, lw=0.5) + yax.set_xlim(left=0) + # threshold: + if 'cvbase' in xcol: + ax.text(xmax, 0.4*ymax, f'{100*np.sum(data(ycol) > nli_thresh)/data.rows():.0f}\\%', + ha='right', va='bottom', fontsize='small') + ax.text(xmax, 0.3, f'{100*np.sum(data(ycol) < nli_thresh)/data.rows():.0f}\\%', + ha='right', va='center', fontsize='small') + # statistics: + r, p = pearsonr(data(xcol), data(ycol)) + ax.text(1, 0.9, f'$R={r:.2f}$ **', ha='right', + transform=ax.transAxes, fontsize='small') + #ax.text(1, 0.77, f'{significance_str(p)}', ha='right', + # transform=ax.transAxes, fontsize='small') + if 'cvbase' in xcol: + ax.text(1, 0.77, f'$n={data.rows()}$', ha='right', + transform=ax.transAxes, fontsize='small') + print(f' correlation {xcol:<8s} - {ycol}: r={r:5.2f}, p={p:.2g}') + return cax + + +def nli_stats(title, data, column, nli_thresh): + print(title) + print(f' nli threshold: {nli_thresh:.1f}') + nrecs = data.rows() + ncells = len(np.unique(data('cell'))) + print(f' cells: {ncells}') + print(f' recordings: {nrecs}') + hcells = np.unique(data[data(column) > nli_thresh, 'cell']) + print(f' high nli cells: n={len(hcells):3d}, {100*len(hcells)/ncells:4.1f}%') + print(f' high nli recordings: n={np.sum(data(column) > nli_thresh):3d}, ' + f'{100*np.sum(data(column) > nli_thresh)/nrecs:4.1f}%') + nsegs = data('nsegs') + print(f' number of segments: {np.min(nsegs):4.0f} - {np.max(nsegs):4.0f}, median={np.median(nsegs):4.0f}, mean={np.mean(nsegs):4.0f}, std={np.std(nsegs):4.0f}') + + +def plot_cvbase_nli_punit(ax, data, ycol, nli_thresh, color): + ax.set_xlabel('CV$_{\\rm base}$') + ax.set_ylabel('SI($r$)') + ax.set_xlim(0, 1.5) + ax.set_ylim(0, 6) + ax.set_yticks_delta(2) + cax = plot_corr(ax, data, 'cvbase', ycol, 'respmod2', 0, 250, 3, + 'coolwarm', color, nli_thresh) + cax.set_ylabel('Response mod.', 'Hz') + + +def plot_cvstim_nli_punit(ax, data, ycol, nli_thresh, color): + ax.set_xlabel('CV$_{\\rm stim}$') + ax.set_ylabel('SI($r$)') + ax.set_xlim(0, 1.6) + ax.set_ylim(0, 6) + ax.set_xticks_delta(0.5) + ax.set_yticks_delta(2) + #cax = plot_corr(ax, data, 'cvstim', ycol, 'respmod2', 0, 250, 2, + # 'coolwarm', color, nli_thresh) + #cax.set_ylabel('Response mod.', 'Hz') + #cax = plot_corr(ax, data, 'cvstim', ycol, 'cvbase', 0, 1.5, 2, + # 'coolwarm', color, nli_thresh) + #cax.set_ylabel('CV$_{\\rm base}$') + cax = plot_corr(ax, data, 'cvstim', ycol, 'ratebase', 50, 450, 2, + 'coolwarm', color, nli_thresh) + cax.set_ylabel('$r$', 'Hz') + + +def plot_mod_nli_punit(ax, data, ycol, nli_thresh, color): + ax.set_xlabel('Response modulation', 'Hz') + ax.set_ylabel('SI($r$)') + ax.set_xlim(0, 250) + ax.set_ylim(0, 6) + ax.set_yticks_delta(2) + cax = plot_corr(ax, data, 'respmod2', ycol, 'cvbase', 0, 1.5, 0.016, + 'coolwarm', color, nli_thresh) + cax.set_ylabel('CV$_{\\rm base}$') + + +def plot_cvbase_nli_ampul(ax, data, ycol, nli_thresh, color): + ax.set_xlabel('CV$_{\\rm base}$') + ax.set_ylabel('SI($r$)') + ax.set_xlim(0, 0.2) + ax.set_ylim(0, 15) + ax.set_xticks_delta(0.1) + ax.set_yticks_delta(5) + cax = plot_corr(ax, data, 'cvbase', ycol, 'respmod2', 0, 80, 20, + 'coolwarm', color, nli_thresh) + cax.set_ylabel('Response mod.', 'Hz') + + +def plot_cvstim_nli_ampul(ax, data, ycol, nli_thresh, color): + ax.set_xlabel('CV$_{\\rm stim}$') + ax.set_ylabel('SI($r$)') + ax.set_xlim(0, 0.85) + ax.set_ylim(0, 15) + ax.set_xticks_delta(0.2) + ax.set_yticks_delta(5) + #cax = plot_corr(ax, data, 'cvstim', ycol, 'respmod2', 0, 80, 6, + # 'coolwarm', color, nli_thresh) + #cax.set_ylabel('Response mod.', 'Hz') + #cax = plot_corr(ax, data, 'cvstim', ycol, 'cvbase', 0, 0.2, 6, + # 'coolwarm', color, nli_thresh) + #cax.set_ylabel('CV$_{\\rm base}$') + #cax.set_yticks_delta(0.1) + cax = plot_corr(ax, data, 'cvstim', ycol, 'ratebase', 90, 180, 6, + 'coolwarm', color, nli_thresh) + cax.set_ylabel('$r$', 'Hz') + cax.set_yticks_delta(30) + + +def plot_mod_nli_ampul(ax, data, ycol, nli_thresh, color): + ax.set_xlabel('Response modulation', 'Hz') + ax.set_ylabel('SI($r$)') + ax.set_xlim(0, 80) + ax.set_ylim(0, 15) + ax.set_xticks_delta(20) + ax.set_yticks_delta(5) + cax = plot_corr(ax, data, 'respmod2', ycol, 'cvbase', 0, 0.2, 0.06, + 'coolwarm', color, nli_thresh) + cax.set_ylabel('CV$_{\\rm base}$') + cax.set_yticks_delta(0.1) + + +if __name__ == '__main__': + punit_model = TableData('summarychi2noise.csv', sep=';') + punit_model = punit_model[punit_model('contrast') > 1e-6, :] + punit_data = TableData('Apteronotus_leptorhynchus-Punit-data.csv', sep=';') + ampul_data = TableData('Apteronotus_leptorhynchus-Ampullary-data.csv') + nli_thresh = 1.8 + + u, p = mannwhitneyu(punit_model('cvbase'), punit_data('cvbase')) + print('CV differs between P-unit models and data:') + print(f' U={u:g}, p={p:g}') + print(f' median model: {np.median(punit_model("cvbase")):.2f}') + print(f' median data: {np.median(punit_data("cvbase")):.2f}') + print() + u, p = mannwhitneyu(punit_model('respmod2'), punit_data('respmod2')) + print('Response modulation differs between P-unit models and data:') + print(f' U={u:g}, p={p:g}') + print(f' median model: {np.median(punit_model("respmod2")):.2f}') + print(f' median data: {np.median(punit_data("respmod2")):.2f}') + print() + u, p = mannwhitneyu(punit_model('dnli100'), punit_data('nli')) + print('NLI does not differ between P-unit models and data:') + print(f' U={u:g}, p={p:g}') + print(f' median model: {np.median(punit_model("dnli100")):.1f}') + print(f' median data: {np.median(punit_data("nli")):.1f}') + print() + + s = plot_style() + fig, axs = plt.subplots(3, 3, cmsize=(s.plot_width, 0.75*s.plot_width), + height_ratios=[1, 0, 1, 0.3, 1]) + fig.subplots_adjust(leftm=6.5, rightm=13.5, topm=4.5, bottomm=4, + wspace=1.1, hspace=0.6) + + nli_stats('P-unit model:', punit_model, 'dnli100', nli_thresh) + axs[0, 0].text(0, 1.35, 'P-unit models', + transform=axs[0, 0].transAxes, color=s.model_color1) + plot_cvbase_nli_punit(axs[0, 0], punit_model, 'dnli100', nli_thresh, s.model_color2) + plot_mod_nli_punit(axs[0, 1], punit_model, 'dnli100', nli_thresh, s.model_color2) + plot_cvstim_nli_punit(axs[0, 2], punit_model, 'dnli100', nli_thresh, s.model_color2) + print() + + nli_stats('P-unit data:', punit_data, 'nli', nli_thresh) + axs[1, 0].text(0, 1.35, 'P-unit data', + transform=axs[1, 0].transAxes, color=s.punit_color1) + plot_cvbase_nli_punit(axs[1, 0], punit_data, 'nli', nli_thresh, s.punit_color2) + plot_mod_nli_punit(axs[1, 1], punit_data, 'nli', nli_thresh, s.punit_color2) + plot_cvstim_nli_punit(axs[1, 2], punit_data, 'nli', nli_thresh, s.punit_color2) + print() + + nli_stats('Ampullary data:', ampul_data, 'nli', nli_thresh) + axs[2, 0].text(0, 1.35, 'Ampullary data', + transform=axs[2, 0].transAxes, color=s.ampul_color1) + plot_cvbase_nli_ampul(axs[2, 0], ampul_data, 'nli', nli_thresh, s.ampul_color2) + plot_mod_nli_ampul(axs[2, 1], ampul_data, 'nli', nli_thresh, s.ampul_color2) + plot_cvstim_nli_ampul(axs[2, 2], ampul_data, 'nli', nli_thresh, s.ampul_color2) + print() + + fig.common_xticks(axs[:2, 0]) + fig.common_xticks(axs[:2, 1]) + fig.common_xticks(axs[:2, 2]) + fig.common_yticks(axs[0, :]) + fig.common_yticks(axs[1, :]) + fig.common_yticks(axs[2, :]) + fig.tag(xoffs=-3.5, yoffs=2) + fig.savefig() diff --git a/modelsusceptcontrasts.py b/modelsusceptcontrasts.py new file mode 100644 index 0000000..9667ba1 --- /dev/null +++ b/modelsusceptcontrasts.py @@ -0,0 +1,183 @@ +import numpy as np +import matplotlib.pyplot as plt +from scipy.stats import pearsonr, linregress, gaussian_kde +from thunderlab.tabledata import TableData +from pathlib import Path +from plotstyle import plot_style, labels_params, significance_str + + +data_path = Path('newdata3') + + +def sort_files(cell_name, all_files, n): + files = [fn for fn in all_files if '-'.join(fn.stem.split('-')[2:-n]) == cell_name] + if len(files) == 0: + return None, 0 + nums = [int(fn.stem.split('-')[-1]) for fn in files] + idxs = np.argsort(nums) + files = [files[i] for i in idxs] + nums = [nums[i] for i in idxs] + return files, nums + + +def plot_chi2(ax, s, data_file): + data = np.load(data_file) + n = data['n'] + alpha = data['alpha'] + freqs = data['freqs'] + pss = data['pss'] + dt_fix = 1 # 0.0005 + prss = np.abs(data['prss'])/dt_fix*0.5/np.sqrt(pss.reshape(1, -1)*pss.reshape(-1, 1)) + ax.set_visible(True) + ax.set_aspect('equal') + i0 = np.argmin(freqs < -300) + i0 = np.argmin(freqs < 0) + i1 = np.argmax(freqs > 300) + if i1 == 0: + i1 = len(freqs) + freqs = freqs[i0:i1] + prss = prss[i0:i1, i0:i1] + vmax = np.quantile(prss, 0.996) + ten = 10**np.floor(np.log10(vmax)) + for fac, delta in zip([1, 2, 3, 4, 6, 8, 10], + [0.5, 1, 1, 2, 3, 4, 5]): + if fac*ten >= vmax: + vmax = fac*ten + ten *= delta + break + pc = ax.pcolormesh(freqs, freqs, prss, vmin=0, vmax=vmax, + cmap='viridis', rasterized=True) + if 'noise_frac' in data: + ax.set_title('$c$=0\\,\\%', fontsize='medium') + else: + ax.set_title(f'$c$={100*alpha:g}\\,\\%', fontsize='medium') + ax.set_xlim(0, 300) + ax.set_ylim(0, 300) + ax.set_xticks_delta(100) + ax.set_yticks_delta(100) + ax.set_xlabel('$f_1$', 'Hz') + ax.set_ylabel('$f_2$', 'Hz') + cax = ax.inset_axes([1.04, 0, 0.05, 1]) + cax.set_spines_outward('lrbt', 0) + if alpha == 0.1: + cb = fig.colorbar(pc, cax=cax, label=r'$|\chi_2|$ [Hz]') + else: + cb = fig.colorbar(pc, cax=cax) + cb.outline.set_color('none') + cb.outline.set_linewidth(0) + cax.set_yticks_delta(ten) + + +def plot_chi2_contrasts(axs, s, cell_name): + print(cell_name) + files, nums = sort_files(cell_name, + data_path.glob(f'chi2-split-{cell_name}-*.npz'), 1) + plot_chi2(axs[0], s, files[-1]) + for k, alphastr in enumerate(['010', '030', '100']): + files, nums = sort_files(cell_name, + data_path.glob(f'chi2-noisen-{cell_name}-{alphastr}-*.npz'), 2) + plot_chi2(axs[k + 1], s, files[-1]) + + +def plot_nli_cv(ax, s, data, alpha, cells): + data = data[data('contrast') == alpha, :] + r, p = pearsonr(data('cvbase'), data[:, 'dnli']) + l = linregress(data('cvbase'), data[:, 'dnli']) + x = np.linspace(0, 1, 10) + ax.set_visible(True) + ax.set_title(f'$c$={100*alpha:g}\\,\\%', fontsize='medium') + ax.axhline(1, **s.lsLine) + ax.plot(x, l.slope*x + l.intercept, **s.lsGrid) + mask = data('triangle') > 0.5 + ax.plot(data[mask, 'cvbase'], data[mask, 'dnli'], + clip_on=False, zorder=30, label='strong', **s.psA1m) + mask = data[:, 'border'] > 0.5 + ax.plot(data[mask, 'cvbase'], data[mask, 'dnli'], + zorder=20, label='weak', **s.psA2m) + ax.plot(data[:, 'cvbase'], data[:, 'dnli'], clip_on=False, + zorder=10, label='none', **s.psB1m) + + for cell_name in cells: + mask = data[:, 'cell'] == cell_name + color = s.psB1m['color'] + if data[mask, 'border']: + color = s.psA2m['color'] + elif data[mask, 'triangle']: + color = s.psA1m['color'] + ax.plot(data[mask, 'cvbase'], data[mask, 'dnli'], + zorder=40, marker='o', ms=s.psB1m['markersize'], + mfc=color, mec='k', mew=0.8) + + ax.set_ylim(0, 8) + ax.set_xlim(0, 1) + ax.set_minor_yticks_delta(1) + ax.set_xlabel('CV$_{\\rm base}$') + ax.set_ylabel('SI') + ax.set_yticks_delta(4) + ax.text(1, 0.9, f'$r={r:.2f}$', transform=ax.transAxes, + ha='right', fontsize='small') + ax.text(1, 0.7, significance_str(p), transform=ax.transAxes, + ha='right', fontsize='small') + if alpha == 0: + ax.legend(loc='upper left', bbox_to_anchor=(1.15, 1.05), + title='triangle', handlelength=0.5, + handletextpad=0.5, labelspacing=0.2) + + kde = gaussian_kde(data('dnli'), 0.15/np.std(data('dnli'), ddof=1)) + nli = np.linspace(0, 8, 100) + pdf = kde(nli) + dax = ax.inset_axes([1.04, 0, 0.3, 1]) + dax.show_spines('') + dax.fill_betweenx(nli, pdf, **s.fsB1a) + dax.plot(pdf, nli, clip_on=False, **s.lsB1m) + + +def plot_summary_contrasts(axs, s, cells): + nli_thresh = 1.2 + data = TableData('summarychi2noise.csv') + plot_nli_cv(axs[0], s, data, 0, cells) + print('split:') + nli_split = data[data('contrast') == 0, 'dnli'] + print(f' mean NLI = {np.mean(nli_split):.2f}, stdev = {np.std(nli_split):.2f}') + n = np.sum(nli_split > nli_thresh) + print(f' {n} cells ({100*n/len(nli_split):.1f}%) have NLI > {nli_thresh:.1f}') + print(f' triangle cells have nli >= {np.min(nli_split[data[data("contrast") == 0, "triangle"] > 0.5])}') + print() + for i, a in enumerate([0.01, 0.03, 0.1]): + plot_nli_cv(axs[1 + i], s, data, a, cells) + print(f'contrast {100*a:2g}%:') + cdata = data[data('contrast') == a, :] + nli = cdata('dnli') + r, p = pearsonr(nli_split, nli) + print(f' correlation with split: r={r:.2f}, p={p:.1e}') + print(f' mean NLI = {np.mean(nli):.2f}, stdev = {np.std(nli):.2f}') + n = np.sum(nli > nli_thresh) + print(f' {n} cells ({100*n/len(nli):.1f}%) have NLI > {nli_thresh:.1f}') + print( ' CVs:', cdata[nli > nli_thresh, 'cvbase']) + print( ' names:', cdata[nli > nli_thresh, 'cell']) + print() + print('lowest baseline CV:', np.unique(data('cvbase'))[:3]) + + +if __name__ == '__main__': + cells = ['2017-07-18-ai-invivo-1', # strong triangle + '2012-12-13-ao-invivo-1', # triangle + '2012-12-20-ac-invivo-1', # weak border triangle + '2013-01-08-ab-invivo-1'] # no triangle + s = plot_style() + #labels_params(xlabelloc='right', ylabelloc='top') + fig, axs = plt.subplots(6, 4, cmsize=(s.plot_width, 0.95*s.plot_width), + height_ratios=[1, 1, 1, 1, 0, 1]) + fig.subplots_adjust(leftm=7, rightm=8, topm=2, bottomm=3.5, + wspace=1, hspace=0.7) + for ax in axs.flat: + ax.set_visible(False) + for k in range(len(cells)): + plot_chi2_contrasts(axs[k], s, cells[k]) + for k in range(4): + fig.common_yticks(axs[k, :]) + fig.common_xticks(axs[:4, k]) + plot_summary_contrasts(axs[5], s, cells) + fig.common_yticks(axs[5, :]) + fig.tag(axs, xoffs=-4.5, yoffs=1.8) + fig.savefig() diff --git a/modelsusceptlown.py b/modelsusceptlown.py new file mode 100644 index 0000000..3f053e8 --- /dev/null +++ b/modelsusceptlown.py @@ -0,0 +1,199 @@ +import numpy as np +import matplotlib.pyplot as plt +from scipy.stats import pearsonr, linregress, gaussian_kde +from thunderlab.tabledata import TableData +from pathlib import Path +from plotstyle import plot_style, labels_params, significance_str + + +data_path = Path('newdata3') + + +def sort_files(cell_name, all_files, n): + files = [fn for fn in all_files if '-'.join(fn.stem.split('-')[2:-n]) == cell_name] + if len(files) == 0: + return None, 0 + nums = [int(fn.stem.split('-')[-1]) for fn in files] + idxs = np.argsort(nums) + files = [files[i] for i in idxs] + nums = [nums[i] for i in idxs] + return files, nums + + +def plot_chi2(ax, s, data_file): + data = np.load(data_file) + n = data['n'] + alpha = data['alpha'] + freqs = data['freqs'] + pss = data['pss'] + dt_fix = 1 # 0.0005 + prss = np.abs(data['prss'])/dt_fix*0.5/np.sqrt(pss.reshape(1, -1)*pss.reshape(-1, 1)) + ax.set_visible(True) + ax.set_aspect('equal') + i0 = np.argmin(freqs < -300) + i0 = np.argmin(freqs < 0) + i1 = np.argmax(freqs > 300) + if i1 == 0: + i1 = len(freqs) + freqs = freqs[i0:i1] + prss = prss[i0:i1, i0:i1] + vmax = np.quantile(prss, 0.996) + ten = 10**np.floor(np.log10(vmax)) + for fac, delta in zip([1, 2, 3, 4, 6, 8, 10], + [0.5, 1, 1, 2, 3, 4, 5]): + if fac*ten >= vmax: + vmax = fac*ten + ten *= delta + break + pc = ax.pcolormesh(freqs, freqs, prss, vmin=0, vmax=vmax, + cmap='viridis', rasterized=True) + ns = f'$N={n}$' if n <= 100 else f'$N=10^{np.log10(n):.0f}$' + if 'noise_frac' in data: + ax.set_title(f'$c$=0\\,\\%, {ns}', fontsize='medium') + else: + ax.set_title(f'$c$={100*alpha:g}\\,\\%, {ns}', fontsize='medium') + ax.set_xlim(0, 300) + ax.set_ylim(0, 300) + ax.set_xticks_delta(100) + ax.set_yticks_delta(100) + ax.set_xlabel('$f_1$', 'Hz') + ax.set_ylabel('$f_2$', 'Hz') + cax = ax.inset_axes([1.04, 0, 0.05, 1]) + cax.set_spines_outward('lrbt', 0) + if alpha == 0.1: + cb = fig.colorbar(pc, cax=cax, label=r'$|\chi_2|$ [Hz]') + else: + cb = fig.colorbar(pc, cax=cax) + cb.outline.set_color('none') + cb.outline.set_linewidth(0) + cax.set_yticks_delta(ten) + + +def plot_chi2_contrasts(axs, s, cell_name, n=None): + print(cell_name) + files, nums = sort_files(cell_name, + data_path.glob(f'chi2-split-{cell_name}-*.npz'), 1) + idx = -1 if n is None else nums.index(n) + plot_chi2(axs[0], s, files[idx]) + for k, alphastr in enumerate(['010', '030', '100']): + files, nums = sort_files(cell_name, + data_path.glob(f'chi2-noisen-{cell_name}-{alphastr}-*.npz'), 2) + idx = -1 if n is None else nums.index(n) + plot_chi2(axs[k + 1], s, files[idx]) + + +def plot_nli_diags(ax, s, data, alphax, alphay, xthresh, ythresh, cell_name): + datax = data[data('contrast') == alphax, :] + datay = data[data('contrast') == alphay, :] + nlix = datax('dnli') + nliy = datay('dnli100') + nfp = np.sum((nliy > ythresh) & (nlix < xthresh)) + ntp = np.sum((nliy > ythresh) & (nlix > xthresh)) + ntn = np.sum((nliy < ythresh) & (nlix < xthresh)) + nfn = np.sum((nliy < ythresh) & (nlix > xthresh)) + print(f' {ntp:2d} ({100*ntp/len(nlix):2.0f}%) true positive') + print(f' {nfp:2d} ({100*nfp/len(nlix):2.0f}%) false positive') + print(f' {ntn:2d} ({100*ntn/len(nlix):2.0f}%) true negative') + print(f' {nfn:2d} ({100*nfn/len(nlix):2.0f}%) false negative') + r, p = pearsonr(nlix, nliy) + l = linregress(nlix, nliy) + x = np.linspace(0, 10, 10) + ax.set_visible(True) + ax.set_title(f'$c$={100*alphay:g}\\,\\%', fontsize='medium') + ax.plot(x, x, **s.lsLine) + ax.plot(x, l.slope*x + l.intercept, **s.lsGrid) + ax.axhline(ythresh, **s.lsLine) + ax.axvline(xthresh, 0, 0.5, **s.lsLine) + if alphax == 0: + mask = datax('triangle') > 0.5 + ax.plot(nlix[mask], nliy[mask], zorder=30, label='strong', **s.psA1m) + mask = datax('border') > 0.5 + ax.plot(nliy[mask], nliy[mask], zorder=20, label='weak', **s.psA2m) + ax.plot(nlix, nliy, zorder=10, label='none', **s.psB1m) + # mark cell: + mask = datax('cell') == cell_name + color = s.psB1m['color'] + if alphax == 0: + if datax[mask, 'border']: + color = s.psA2m['color'] + elif datax[mask, 'triangle']: + color = s.psA1m['color'] + ax.plot(nlix[mask], nliy[mask], zorder=40, marker='o', + ms=s.psB1m['markersize'], mfc=color, mec='k', mew=0.8) + + box = dict(boxstyle='square,pad=0.1', fc='white', ec='none') + ax.text(1.0, 0.0, f'{ntn}', ha='right', fontsize='small', bbox=box) + ax.text(7.5, 0.0, f'{nfn}', ha='right', fontsize='small', bbox=box) + ax.text(1.0, 3.7, f'{nfp}', ha='right', fontsize='small', bbox=box) + ax.text(7.5, 3.7, f'{ntp}', ha='right', fontsize='small', bbox=box) + ax.set_ylim(0, 9) + ax.set_xlim(0, 9) + n = datax[0, 'nsegs'] + if alphax == 0: + ax.set_xlabel(f'SI, $c=0$, $N=10^{np.log10(n):.0f}$') + else: + ax.set_xlabel(f'SI, $N=10^{np.log10(n):.0f}$') + ax.set_ylabel('SI, $N=100$') + ax.set_xticks_delta(4) + ax.set_yticks_delta(4) + ax.set_minor_xticks_delta(1) + ax.set_minor_yticks_delta(1) + ax.text(0, 0.9, f'$r={r:.2f}$', transform=ax.transAxes, fontsize='small') + ax.text(0, 0.7, significance_str(p), transform=ax.transAxes, + fontsize='small') + if alphax == 0 and alphay == 0.01: + ax.legend(loc='upper left', bbox_to_anchor=(-1.5, 1), + title='triangle', handlelength=0.5, + handletextpad=0.5, labelspacing=0.2) + + kde = gaussian_kde(nliy, 0.15/np.std(nliy, ddof=1)) + nli = np.linspace(0, 8, 100) + pdf = kde(nli) + dax = ax.inset_axes([1.04, 0, 0.3, 1]) + dax.show_spines('') + dax.fill_betweenx(nli, pdf, **s.fsB1a) + dax.plot(pdf, nli, clip_on=False, **s.lsB1m) + + +def plot_summary_contrasts(axs, s, xthresh, ythresh, cell_name): + print(f'against contrast with thresholds: x={xthresh} and y={ythresh}') + data = TableData('summarychi2noise.csv') + for i, a in enumerate([0.01, 0.03, 0.1]): + print(f'contrast {100*a:2g}%:') + plot_nli_diags(axs[1 + i], s, data, a, a, xthresh, ythresh, cell_name) + print() + + +def plot_summary_diags(axs, s, xthresh, ythresh, cell_name): + print(f'against split with thresholds: x={xthresh} and y={ythresh}') + data = TableData('summarychi2noise.csv') + for i, a in enumerate([0.01, 0.03, 0.1]): + print(f'contrast {100*a:2g}%:') + plot_nli_diags(axs[1 + i], s, data, 0, a, xthresh, ythresh, cell_name) + + +if __name__ == '__main__': + xthresh = 1.2 + ythresh = 1.8 + s = plot_style() + fig, axs = plt.subplots(6, 4, cmsize=(s.plot_width, 0.85*s.plot_width), + height_ratios=[1, 1, 0, 1, 0, 1]) + fig.subplots_adjust(leftm=7, rightm=8, topm=2, bottomm=3.5, + wspace=1, hspace=1) + for ax in axs.flat: + ax.set_visible(False) + cell_name = '2012-12-21-ak-invivo-1' + plot_chi2_contrasts(axs[0], s, cell_name) + plot_chi2_contrasts(axs[1], s, cell_name, 10) + for k in range(2): + fig.common_yticks(axs[k, :]) + for k in range(4): + fig.common_xticks(axs[:2, k]) + plot_summary_contrasts(axs[3], s, xthresh, ythresh, cell_name) + plot_summary_diags(axs[5], s, xthresh, ythresh, cell_name) + fig.common_yticks(axs[3, 1:]) + fig.common_yticks(axs[5, 1:]) + fig.tag(axs, xoffs=-4.5, yoffs=1.8) + axs[1, 0].set_visible(False) + #plt.show() + fig.savefig() diff --git a/modelsusceptovern.py b/modelsusceptovern.py new file mode 100644 index 0000000..7a4d842 --- /dev/null +++ b/modelsusceptovern.py @@ -0,0 +1,165 @@ +import numpy as np +from scipy.stats import linregress +import matplotlib.pyplot as plt +from pathlib import Path +from plotstyle import plot_style, labels_params + + +data_path = Path('newdata3') + + +def sort_files(cell_name, all_files, n): + files = [fn for fn in all_files if '-'.join(fn.stem.split('-')[2:-n]) == cell_name] + if len(files) == 0: + return None, 0 + nums = [int(fn.stem.split('-')[-1]) for fn in files] + idxs = np.argsort(nums) + files = [files[i] for i in idxs] + nums = [nums[i] for i in idxs] + return files, nums + + +def plot_chi2(ax, s, data_file): + data = np.load(data_file) + n = data['n'] + alpha = data['alpha'] + freqs = data['freqs'] + pss = data['pss'] + dt_fix = 1 # 0.0005 + prss = np.abs(data['prss'])/dt_fix*0.5/np.sqrt(pss.reshape(1, -1)*pss.reshape(-1, 1)) + ax.set_visible(True) + ax.set_aspect('equal') + i0 = np.argmin(freqs < -300) + i0 = np.argmin(freqs < 0) + i1 = np.argmax(freqs > 300) + if i1 == 0: + i1 = len(freqs) + freqs = freqs[i0:i1] + prss = prss[i0:i1, i0:i1] + vmax = np.quantile(prss, 0.996) + ten = 10**np.floor(np.log10(vmax)) + for fac, delta in zip([1, 2, 3, 4, 6, 8, 10], + [0.5, 1, 1, 2, 3, 4, 5]): + if fac*ten >= vmax: + vmax = fac*ten + ten *= delta + break + pc = ax.pcolormesh(freqs, freqs, prss, vmin=0, vmax=vmax, + cmap='viridis', rasterized=True) + ax.set_title(f'$N=10^{np.log10(n):.0f}$', fontsize='medium') + ax.set_xlim(0, 300) + ax.set_ylim(0, 300) + ax.set_xticks_delta(300) + ax.set_minor_xticks(3) + ax.set_yticks_delta(300) + ax.set_minor_yticks(3) + ax.set_xlabel('$f_1$', 'Hz') + ax.set_ylabel('$f_2$', 'Hz') + cax = ax.inset_axes([1.04, 0, 0.05, 1]) + cax.set_spines_outward('lrbt', 0) + cb = fig.colorbar(pc, cax=cax) + cb.outline.set_color('none') + cb.outline.set_linewidth(0) + cax.set_yticks_delta(ten) + + +def plot_overn(ax, s, files, nmax=1e6, title=False): + ns = [] + stats = [] + for fname in files: + data = np.load(fname) + if not 'n' in data: + return + n = data['n'] + if nmax is not None and n > nmax: + continue + alpha = data['alpha'] + freqs = data['freqs'] + pss = data['pss'] + dt_fix = 1 # 0.0005 + chi2 = np.abs(data['prss'])/dt_fix*0.5/np.sqrt(pss.reshape(1, -1)*pss.reshape(-1, 1)) + ns.append(n) + i0 = np.argmin(freqs < 0) + i1 = np.argmax(freqs > 300) + if i1 == 0: + i1 = len(freqs) + chi2 = chi2[i0:i1, i0:i1] + stats.append(np.quantile(chi2, [0, 0.001, 0.05, 0.25, 0.5, + 0.75, 0.95, 0.998, 1.0])) + ns = np.array(ns) + stats = np.array(stats) + indx = np.argsort(ns) + ns = ns[indx] + stats = stats[indx] + ax.set_visible(True) + ax.plot(ns, stats[:, 7], '0.5', lw=1, zorder=50, label='99.8\\%') + ax.fill_between(ns, stats[:, 2], stats[:, 6], fc='0.85', zorder=40, label='5--95\\%') + ax.fill_between(ns, stats[:, 3], stats[:, 5], fc='0.5', zorder=45, label='25-75\\%') + ax.plot(ns, stats[:, 4], zorder=50, label='median', **s.lsSpine) + #ax.plot(ns, stats[:, 8], '0.0') + if title: + if 'noise_frac' in data: + ax.set_title('$c$=0\\,\\%', fontsize='medium') + else: + ax.set_title(f'$c$={100*alpha:g}\\,\\%', fontsize='medium') + ax.set_xlim(1e1, nmax) + ax.set_xscale('log') + ax.set_yscale('log') + ax.set_yticks_log(numticks=3) + if nmax > 1e6: + ax.set_ylim(3e-1, 5e3) + ax.set_minor_yticks_log(numticks=5) + ax.set_xticks_log(numticks=4) + ax.set_minor_xticks_log(numticks=8) + else: + ax.set_ylim(5e0, 1e4) + ax.set_minor_yticks_log(numticks=5) + ax.set_xticks_log(numticks=3) + ax.set_minor_xticks_log(numticks=6) + ax.set_xlabel('segments') + ax.set_ylabel('$|\\chi_2|$ [Hz]') + if alpha == 0.10: + ax.legend(loc='upper left', bbox_to_anchor=(1.4, 1.3), + markerfirst=False, title='$|\\chi_2|$ percentiles') + + +def plot_chi2_overn(axs, s, cell_name): + print(cell_name) + files, nums = sort_files(cell_name, + data_path.glob(f'chi2-split-{cell_name}-*.npz'), 1) + for k, n in enumerate([1e1, 1e2, 1e3, 1e6]): + plot_chi2(axs[k], s, files[nums.index(int(n))]) + plot_overn(axs[-1], s, files) + + +if __name__ == '__main__': + cells = ['2017-07-18-ai-invivo-1', # strong triangle + '2012-12-13-ao-invivo-1', # triangle + '2012-12-20-ac-invivo-1', # weak border triangle + '2013-01-08-ab-invivo-1'] # no triangle + s = plot_style() + fig, axs = plt.subplots(6, 6, cmsize=(s.plot_width, 0.9*s.plot_width), + width_ratios=[1, 1, 1, 1, 0, 1], + height_ratios=[1, 1, 1, 1, 0, 1]) + fig.subplots_adjust(leftm=8, rightm=0.5, topm=2, bottomm=3.5, + wspace=1, hspace=0.8) + for ax in axs.flat: + ax.set_visible(False) + for k in range(len(cells)): + plot_chi2_overn(axs[k], s, cells[k]) + cell_name = cells[0] + files, nums = sort_files(cell_name, + data_path.glob(f'chi2-split-{cell_name}-*.npz'), 1) + plot_overn(axs[-1, 0], s, files, 1e7, True) + for k, alphastr in enumerate(['010', '030', '100']): + files, nums = sort_files(cell_name, + data_path.glob(f'chi2-noisen-{cell_name}-{alphastr}-*.npz'), 2) + plot_overn(axs[-1, k + 1], s, files, 1e7, True) + for k in range(4): + fig.common_yticks(axs[k, :4]) + fig.common_xticks(axs[:4, k]) + fig.common_xticks(axs[:4, -1]) + fig.align_ylabels(axs[:4, -1], dist=12) + fig.common_yticks(axs[-1, :4]) + fig.tag(axs, xoffs=-2.5, yoffs=1.8) + fig.savefig() diff --git a/plotstyle.py b/plotstyle.py new file mode 100644 index 0000000..ec05ad3 --- /dev/null +++ b/plotstyle.py @@ -0,0 +1,126 @@ +import matplotlib as mpl +import plottools.plottools as pt +from plottools.spines import spines_params +from plottools.labels import labels_params +from plottools.colors import lighter, darker + + +def significance_str(p): + if p > 0.05: + return f'$p={p:.2f}$' + elif p > 0.01: + return '$p<0.05$' + elif p > 0.001: + return '$p<0.01$' + else: + return '$p<0.001$' + + +def plot_style(): + palette = pt.palettes['muted'] + lwthick = 1.0 + lwthin = 0.5 + lwspines = 1.0 + names = ['A1', 'A2', 'A3', + 'B1', 'B2', 'B3', 'B4', + 'C1', 'C2', 'C3', 'C4'] + colors = [palette['red'], palette['orange'], palette['yellow'], + palette['blue'], palette['purple'], palette['magenta'], palette['lightblue'], + palette['lightgreen'], palette['green'], palette['darkgreen'], palette['cyan']] + dashes = ['-', '-', '-', + '-', '-', '-', '-.', + '-', '-', '-', '-'] + markers = [('o', 1.0), ('p', 1.1), ('h', 1.1), + ((3, 1, 60), 1.25), ((3, 1, 0), 1.25), ((3, 1, 90), 1.25), ((3, 1, 30), 1.25), + ('s', 0.9), ('D', 0.85), ('*', 1.6), ((4, 1, 45), 1.4)] + class ns: pass + ns.colors = palette + ns.lwthick = lwthick + ns.lwthin = lwthin + ns.plot_width = 16.5 + pt.make_linepointfill_styles(ns, names, colors, dashes, markers, + lwthick=lwthick, lwthin=lwthin, + markerlarge=5, markersmall=4, + mec=0.0, mew=0.5, fillalpha=0.4) + pt.make_line_styles(ns, 'ls', 'Spine', '', palette['black'], '-', + lwspines, clip_on=False) + pt.make_line_styles(ns, 'ls', 'Grid', '', palette['gray'], '--', + 0.7*lwthin) + pt.make_line_styles(ns, 'ls', 'Dotted', '', palette['gray'], ':', + 0.7*lwthin) + pt.make_line_styles(ns, 'ls', 'Marker', '', palette['black'], '-', + lwthick, clip_on=False) + pt.make_line_styles(ns, 'ls', 'Line', '', palette['black'], '-', + lwthin) + pt.make_line_styles(ns, 'ls', 'EOD', '', palette['gray'], '-', lwthin) + pt.make_line_styles(ns, 'ls', 'AM', '', palette['red'], '-', lwthick) + ns.model_color1 = palette['purple'] + ns.model_color2 = lighter(ns.model_color1, 0.6) + ns.punit_color1 = palette['blue'] + ns.punit_color2 = lighter(ns.punit_color1, 0.6) + ns.ampul_color1 = palette['green'] + ns.ampul_color2 = lighter(ns.ampul_color1, 0.6) + pt.make_line_styles(ns, 'ls', 'M%d', '', [ns.model_color1, ns.model_color2], '-', lwthick) + pt.make_point_styles(ns, 'ps', 'M%d', '', [ns.model_color1, ns.model_color2], '-', markers=('o', 1), markersizes=4) + pt.make_line_styles(ns, 'ls', 'P%d', '', [ns.punit_color1, ns.punit_color2], '-', lwthick) + pt.make_point_styles(ns, 'ps', 'P%d', '', [ns.punit_color1, ns.punit_color2], '-', markers=('o', 1), markersizes=4) + pt.make_line_styles(ns, 'ls', 'A%d', '', [ns.ampul_color1, ns.ampul_color2], '-', lwthick) + pt.make_point_styles(ns, 'ps', 'A%d', '', [ns.ampul_color1, ns.ampul_color2], '-', markers=('o', 1), markersizes=4) + pt.make_line_styles(ns, 'ls', 'Diag', '', palette['white'], '--', lwthin) + pt.arrow_style(ns, 'Line', dist=3.0, style='>', shrink=0, lw=0.6, + color=palette['black'], head_length=4, head_width=4, + bbox=dict(boxstyle='round,pad=0.1', facecolor='white', + edgecolor='none', alpha=1.0)) + pt.arrow_style(ns, 'Hertz', dist=3.0, style='>', shrink=0, lw=0.6, + color=palette['black'], head_length=3, head_width=3, + heads='<>', text='%.0f\u2009Hz', rotation='vertical', + fontsize='x-small', + bbox=dict(boxstyle='round,pad=0.1', facecolor='white', + edgecolor='none', alpha=0.6)) + pt.arrow_style(ns, 'Point', dist=3.0, style='>>', shrink=0, + lw=1.1, color=palette['black'], head_length=5, + head_width=4) + pt.arrow_style(ns, 'PointSmall', dist=3.0, style='>>', shrink=0, + lw=1, color=palette['black'], head_length=5, + head_width=5, fontsize='small') + pt.arrow_style(ns, 'Marker', dist=3.0, style='>>', shrink=0, + lw=0.9, color=palette['black'], head_length=5, + head_width=5, fontsize='small', ha='center', + va='center', + bbox=dict(boxstyle='round, pad=0.1, rounding_size=0.4', + facecolor=palette['white'], + edgecolor='none', alpha=0.6)) + + # rc settings: + mpl.rcdefaults() + pt.axes_params(0.0, 0.0, 0.0, color='none') + cmcolors = [pt.lighter(palette['yellow'], 0.2), + pt.lighter(palette['orange'], 0.5), palette['orange'], + palette['red'], palette['red']] + cmvalues = [0.0, 0.3, 0.7, 0.95, 1.0] + pt.colormap('YR', cmcolors, cmvalues) + cycle_colors = ['blue', 'red', 'orange', 'lightgreen', 'magenta', + 'yellow', 'cyan', 'pink'] + pt.colors_params(palette, cycle_colors, 'RdYlBu') + pt.figure_params(palette['white'], format='pdf', compression=6, + fonttype=3, stripfonts=True) + pt.labels_params('{label} [{unit}]', labelsize='medium', labelpad=6) + pt.legend_params(fontsize='small', frameon=False, borderpad=0.0, + handlelength=1.5, handletextpad=1, + numpoints=1, scatterpoints=1, + labelspacing=0.5, columnspacing=0.5) + pt.scalebars_params(format_large='%.0f', format_small='%.1f', + lw=2.2, color=palette['black'], capsize=0, + clw=0.5, font=dict(fontsize='medium', + fontstyle='normal')) + pt.spines_params(spines='lb', spines_offsets={'lrtb': 3}, + spines_bounds={'lrtb': 'full'}) + pt.tag_params(xoffs='auto', yoffs='auto', label='%A', + minor_label=r'%A$_{\text{%mi}}$', + font=dict(fontsize='x-large', fontstyle='normal', + fontweight='normal')) + pt.text_params(font_size=7, font_family='sans-serif', latex=True, + preamble=('p:sfmath', 'p:marvosym', 'p:xfrac', + 'p:SIunits')) + pt.ticks_params(xtick_dir='out', xtick_size=3) + return ns diff --git a/punitexamplecell.py b/punitexamplecell.py new file mode 100644 index 0000000..fab6905 --- /dev/null +++ b/punitexamplecell.py @@ -0,0 +1,275 @@ +import sys +sys.path.insert(0, 'ephys') # for analysing data +import numpy as np +import matplotlib.pyplot as plt +from pathlib import Path +from spectral import diag_projection, peakedness +from plotstyle import plot_style + + +cell_name = '2020-10-27-ag-invivo-1' +run1 = 0 +run2 = 1 + +base_path = Path('ephys') +data_path = base_path / 'data' +results_path = base_path / 'results' + + +def load_baseline(path, cell_name): + d = path / f'{cell_name}-baseline.npz' + data = np.load(d) + ['eodf', 'isis', 'isih', 'lags', 'corrs', 'freqs', 'prr'] + eodf = float(data['eodf']) + rate = float(data['ratebase/Hz']) + cv = float(data['cvbase']) + isis = data['isis'] + pdf = data['isih'] + freqs = data['freqs'] + prr = data['prr'] + return eodf, rate, cv, isis, pdf, freqs, prr + + +def load_noise(path, cell_name, run): + data = np.load(path / f'{cell_name}-spectral-data-s{run:02d}.npz') + contrast = data['contrast'] + time = data['time'] + stimulus = data['stimulus'] + name = str(data['stimulus_name']) + fcutoff = float(name.lower().replace('blwn', '').replace('inputarr_', '').replace('gwn', '').split('h')[0]) + spikes = [] + for k in range(1000): + key = f'spikes_{k:03d}' + if not key in data.keys(): + break + spikes.append(data[key]) + return contrast, time, stimulus, spikes + + +def load_spectra(path, cell_name, run=None): + if run is None: + data = np.load(cell_name) + else: + d = list(path.glob(f'{cell_name}-spectral*-s{run:02d}.npz')) + data = np.load(d[0]) + contrast = float(data['alpha']) + fcutoff = float(data['fcutoff']) + freqs = data['freqs'] + pss = data['pss'] + prs = data['prs'] + prss = data['prss'] + nsegs = int(data['n']) + gain = np.abs(prs)/pss + chi2 = np.abs(prss)*0.5/np.sqrt(pss.reshape(1, -1)*pss.reshape(-1, 1)) + return fcutoff, contrast, freqs, gain, chi2 + + +def plot_isih(ax, s, rate, cv, isis, pdf): + ax.show_spines('b') + ax.fill_between(1000*isis, pdf, facecolor=s.cell_color1) + ax.set_xlim(0, 8) + ax.set_xticks_delta(2) + ax.set_xlabel('ISI', 'ms') + ax.text(0, 1.08, 'P-unit:', transform=ax.transAxes, color=s.cell_color1, + fontsize='large') + ax.text(0.6, 1.08, f'$r={rate:.0f}$Hz, CV$_{{\\rm base}}$={cv:.2f}', + transform=ax.transAxes) + + +def plot_response_spectrum(ax, s, eodf, rate, freqs, prr): + rate_i = np.argmax(prr[freqs < 0.7*eodf]) + eod_i = np.argmax(prr[freqs > 500]) + np.argmax(freqs > 500) + power_db = 10*np.log10(prr/np.max(prr)) + ax.show_spines('b') + mask = freqs < 890 + ax.plot(freqs[mask], power_db[mask], **s.lsC1) + ax.plot(freqs[eod_i], power_db[eod_i], **s.psA1c) + ax.plot(freqs[rate_i], power_db[rate_i], **s.psA2c) + ax.set_xlim(0, 900) + ax.set_ylim(-25, 5) + ax.set_xticks_delta(300) + ax.set_xlabel('$f$', 'Hz') + ax.text(freqs[eod_i], power_db[eod_i] + 2, '$f_{\\rm EOD}$', + ha='center') + ax.text(freqs[rate_i], power_db[rate_i] + 2, '$r$', + ha='center') + ax.yscalebar(1.05, 0, 10, 'dB', ha='right') + + +def plot_response(ax, s, eodf, time1, stimulus1, contrast1, spikes1, contrast2, spikes2): + t0 = 0.3 + t1 = 0.4 + #print(len(spikes1), len(spikes2)) + maxtrials = 8 + trials = np.arange(maxtrials) + ax.show_spines('') + ax.eventplot(spikes1[2:2+maxtrials], lineoffsets=trials - maxtrials + 1, + linelength=0.8, linewidths=1, color=s.cell_color1) + ax.eventplot(spikes2[2:2+maxtrials], lineoffsets=trials - 2*maxtrials, + linelength=0.8, linewidths=1, color=s.cell_color2) + am = 1 + contrast1*stimulus1 + eod = np.sin(2*np.pi*eodf*time1) * am + ax.plot(time1, 4*eod + 7, **s.lsEOD) + ax.plot(time1, 4*am + 7, **s.lsAM) + ax.set_xlim(t0, t1) + ax.set_ylim(-2*maxtrials - 0.5, 14) + ax.xscalebar(1, -0.05, 0.01, None, '10\\,ms', ha='right') + ax.text(t1 + 0.003, -0.5*maxtrials, f'${100*contrast1:.0f}$\\,\\%', + va='center', color=s.cell_color1) + ax.text(t1 + 0.003, -1.55*maxtrials, f'${100*contrast2:.0f}$\\,\\%', + va='center', color=s.cell_color2) + + +def plot_gain(ax, s, contrast1, freqs1, gain1, contrast2, freqs2, gain2, fcutoff): + ax.plot(freqs2, gain2, label=f'{100*contrast2:.0f}', **s.lsC2) + ax.plot(freqs1, gain1, label=f'{100*contrast1:.0f}', **s.lsC1) + ax.set_xlim(0, fcutoff) + ax.set_ylim(0, 800) + ax.set_xticks_delta(100) + ax.set_xlabel('$f$', 'Hz') + ax.set_ylabel(r'$|\chi_1|$', 'Hz') + + +def plot_colorbar(ax, pc, dc=None): + cax = ax.inset_axes([1.04, 0, 0.05, 1]) + cax.set_spines_outward('lrbt', 0) + cb = cax.get_figure().colorbar(pc, cax=cax, label=r'$|\chi_2|$ [kHz]') + cb.outline.set_color('none') + cb.outline.set_linewidth(0) + if dc is not None: + cax.set_yticks_delta(dc) + + +def plot_chi2(ax, s, contrast, freqs, chi2, fcutoff, vmax): + ax.set_aspect('equal') + if vmax is None: + vmax = np.quantile(1e-3*chi2, 0.99) + pc = ax.pcolormesh(freqs, freqs, 1e-3*chi2, vmin=0, vmax=vmax, + cmap='viridis', rasterized=True, zorder=10) + ax.set_xlim(0, fcutoff) + ax.set_ylim(0, fcutoff) + df = 100 if fcutoff == 300 else 50 + ax.set_xticks_delta(df) + ax.set_yticks_delta(df) + ax.set_xlabel('$f_1$', 'Hz') + ax.set_ylabel('$f_2$', 'Hz') + return pc + + +def plot_diagonals(ax, s, fbase, contrast1, freqs1, chi21, contrast2, freqs2, chi22, fcutoff): + diags = [] + nlis = [] + nlips = [] + nlifs = [] + for contrast, freqs, chi2 in [[contrast1, freqs1, chi21], [contrast2, freqs2, chi22]]: + dfreqs, diag = diag_projection(freqs, chi2, 2*fcutoff) + diags.append([dfreqs, diag]) + nli, nlif = peakedness(dfreqs, diag, fbase, median=False) + nlip = diag[np.argmin(np.abs(dfreqs - nlif))] + nlis.append(nli) + nlips.append(nlip) + nlifs.append(nlif) + print(f' SI at {100*contrast:.1f}% contrast: {nli:.2f}') + ax.axvline(fbase, **s.lsGrid) + ax.plot(diags[1][0], 1e-3*diags[1][1], **s.lsC2) + ax.plot(diags[0][0], 1e-3*diags[0][1], **s.lsC1) + ax.plot(nlifs[1], 1e-3*nlips[1], **s.psC2) + ax.plot(nlifs[0], 1e-3*nlips[0], **s.psC1) + ax.set_xlim(0, 2*fcutoff) + ax.set_ylim(0, 4.2) + ax.set_xticks_delta(300) + ax.set_yticks_delta(1) + ax.set_xlabel('$f_1 + f_2$', 'Hz') + #ax.set_ylabel(r'$|\chi_2|$', 'kHz') + ax.text(nlifs[1] - 50, 1e-3*nlips[1], f'{100*contrast2:.0f}\\%', + ha='right') + ax.text(nlifs[1] + 70, 1e-3*nlips[1], f'SI={nlis[1]:.1f}') + ax.text(nlifs[0] - 50, 1e-3*nlips[0], f'{100*contrast1:.0f}\\%', + ha='right') + ax.text(nlifs[0] + 70, 1e-3*nlips[0], f'SI={nlis[0]:.1f}') + ax.text(fbase, 4.3, '$r$', ha='center') + + +if __name__ == '__main__': + print('Example P-unit:', cell_name) + eodf, rate, cv, isis, pdf, freqs, prr = load_baseline(results_path, cell_name) + print(f' baseline firing rate: {rate:.0f}Hz') + print(f' baseline firing CV : {cv:.2f}') + contrast1, time1, stimulus1, spikes1 = load_noise(data_path, cell_name, run1) + contrast2, time2, stimulus2, spikes2 = load_noise(data_path, cell_name, run2) + fcutoff1, contrast1, freqs1, gain1, chi21 = load_spectra(results_path, cell_name, run1) + fcutoff2, contrast2, freqs2, gain2, chi22 = load_spectra(results_path, cell_name, run2) + + s = plot_style() + s.cell_color1 = s.punit_color1 + s.cell_color2 = s.punit_color2 + s.lsC1 = s.lsP1 + s.lsC2 = s.lsP2 + s.psC1 = s.psP1 + s.psC2 = s.psP2 + fig, (ax1, ax2, ax3) = plt.subplots(3, 1, height_ratios=[3, 0, 3, 0.5, 3], + cmsize=(s.plot_width, 0.8*s.plot_width)) + fig.subplots_adjust(leftm=8, rightm=9, topm=2, bottomm=4, + wspace=0.4, hspace=0.5) + axi, axp, axr = ax1.subplots(1, 3, width_ratios=[2, 3, 0, 10]) + axg, axc1, axc2, axd = ax2.subplots(1, 4, wspace=0.4) + axg = axg.subplots(1, 1, width_ratios=[1, 0.1]) + axd = axd.subplots(1, 1, width_ratios=[0.2, 1]) + axs = ax3.subplots(1, 4, wspace=0.4) + + plot_isih(axi, s, rate, cv, isis, pdf) + plot_response_spectrum(axp, s, eodf, rate, freqs, prr) + plot_response(axr, s, eodf, time1, stimulus1, contrast1, spikes1, + contrast2, spikes2) + + plot_gain(axg, s, contrast1, freqs1, gain1, + contrast2, freqs2, gain2, fcutoff1) + pc = plot_chi2(axc1, s, contrast2, freqs2, chi22, fcutoff2, 4) + axc1.plot([0, fcutoff2], [0, fcutoff2], zorder=20, **s.lsDiag) + axc1.set_title(f'$c$={100*contrast2:g}\\,\\%', + fontsize='medium', color=s.cell_color2) + pc = plot_chi2(axc2, s, contrast1, freqs1, chi21, fcutoff1, 4) + axc2.set_title(f'$c$={100*contrast1:g}\\,\\%', + fontsize='medium', color=s.cell_color1) + axc2.plot([0, fcutoff1], [0, fcutoff1], zorder=20, **s.lsDiag) + plot_colorbar(axc2, pc) + plot_diagonals(axd, s, rate, contrast1, freqs1, chi21, + contrast2, freqs2, chi22, fcutoff1) + + fig.common_yticks(axc1, axc2) + fig.tag([axi, axp, axr], xoffs=-3, yoffs=-1) + fig.tag([axg, axc1, axc2, axd], xoffs=-3, yoffs=2) + + print('Additional example cells:') + example_cells = [ + ['2021-06-18-ae-invivo-1', 3], # 98Hz, 1%, ok + ['2012-03-30-ah', 2], # 177Hz, 2.5%, 2.0, nice + ##['2012-07-03-ak', 0], # 120Hz, 2.5%, 1.8, broader + ##['2012-12-20-ac', 0], # 213Hz, 2.5%, 2.1, ok + #['2017-07-18-ai-invivo-1', 1], # 78Hz, 5%, 2.3, weak + ##['2019-06-28-ae', 0], # 477Hz, 10%, 2.6, weak + ##['2020-10-27-aa-invivo-1', 4], # 259Hz, 0.5%, 2.0, ok + ##['2020-10-27-ae-invivo-1', 4], # 375Hz, 0.5%, 4.3, nice, additional low freq line + ###['2020-10-27-ag-invivo-1', 2], # 405Hz, 5%, 3.9, strong, is already the example + ##['2021-08-03-ab-invivo-1', 1], # 140Hz, 0.5%, ok + ['2020-10-29-ag-invivo-1', 2], # 164Hz, 5%, 1.6, no diagonal + ##['2010-08-31-ag', 1], # 269Hz, 5%, no diagonal + ['2018-08-24-ak', 1], # 145Hz, 5%, no diagonal + ##['2018-08-29-af', 1], # 383Hz, 5%, no diagonal + ] + for k, (cell, run) in enumerate(example_cells): + eodf, rate, cv, _, _, _, _ = load_baseline(results_path, cell) + fcutoff, contrast, freqs, gain, chi2 = load_spectra(results_path, cell, run) + dfreqs, diag = diag_projection(freqs, chi2, 2*fcutoff) + nli, nlif = peakedness(dfreqs, diag, rate, median=False) + print(f' {cell:<22s}: run={run:2d}, fbase={rate:3.0f}Hz, CV={cv:.2f}, SI={nli:3.1f}') + pc = plot_chi2(axs[k], s, contrast, freqs, chi2, fcutoff, 1.3) + axs[k].set_title(f'$r={rate:.0f}$Hz, CV$_{{\\rm base}}$={cv:.2f}', fontsize='medium') + axs[k].text(0.95, 0.9, f'SI($r$)={nli:.1f}', ha='right', zorder=50, + color='white', fontsize='medium', + transform=axs[k].transAxes) + plot_colorbar(axs[-1], pc) + fig.common_yticks(axs) + fig.tag(axs, xoffs=-3, yoffs=2) + + fig.savefig() diff --git a/regimes.py b/regimes.py new file mode 100644 index 0000000..d0b8ad9 --- /dev/null +++ b/regimes.py @@ -0,0 +1,349 @@ +import os +import numpy as np +from scipy.stats import linregress +import matplotlib.pyplot as plt +from numba import jit +from thunderlab.tabledata import TableData +from plotstyle import plot_style, lighter, darker + + +def load_models(file): + """ Load model parameter from csv file. + + Parameters + ---------- + file: string + Name of file with model parameters. + + Returns + ------- + parameters: list of dict + For each cell a dictionary with model parameters. + """ + parameters = [] + with open(file, 'r') as file: + header_line = file.readline() + header_parts = header_line.strip().split(",") + keys = header_parts + for line in file: + line_parts = line.strip().split(",") + parameter = {} + for i in range(len(keys)): + parameter[keys[i]] = float(line_parts[i]) if i > 0 else line_parts[i] + parameters.append(parameter) + return parameters + + +def cell_parameters(parameters, cell_name): + for params in parameters: + if params['cell'] == cell_name: + return params + print('cell', cell_name, 'not found!') + exit() + return None + + +@jit(nopython=True) +def simulate(stimulus, deltat=0.00005, v_zero=0.0, a_zero=2.0, + threshold=1.0, v_base=0.0, delta_a=0.08, tau_a=0.1, + v_offset=-10.0, mem_tau=0.015, noise_strength=0.05, + input_scaling=60.0, dend_tau=0.001, ref_period=0.001): + """ Simulate a P-unit. + + Returns + ------- + spike_times: 1-D array + Simulated spike times in seconds. + """ + # initial conditions: + v_dend = stimulus[0] + v_mem = v_zero + adapt = a_zero + + # prepare noise: + noise = np.random.randn(len(stimulus)) + noise *= noise_strength / np.sqrt(deltat) + + # rectify stimulus array: + stimulus = stimulus.copy() + stimulus[stimulus < 0.0] = 0.0 + + # integrate: + spike_times = [] + for i in range(len(stimulus)): + v_dend += (-v_dend + stimulus[i]) / dend_tau * deltat + v_mem += (v_base - v_mem + v_offset + ( + v_dend * input_scaling) - adapt + noise[i]) / mem_tau * deltat + adapt += -adapt / tau_a * deltat + + # refractory period: + if len(spike_times) > 0 and (deltat * i) - spike_times[-1] < ref_period + deltat/2: + v_mem = v_base + + # threshold crossing: + if v_mem > threshold: + v_mem = v_base + spike_times.append(i * deltat) + adapt += delta_a / tau_a + + return np.array(spike_times) + + +def punit_spikes(parameter, alpha, beatf1, beatf2, tmax, trials): + tini = 0.2 + model_params = dict(parameter) + cell = model_params.pop('cell') + eodf0 = model_params.pop('EODf') + time = np.arange(-tini, tmax, model_params['deltat']) + stimulus = np.sin(2*np.pi*eodf0*time) + stimulus += alpha*np.sin(2*np.pi*(eodf0 + beatf1)*time) + stimulus += alpha*np.sin(2*np.pi*(eodf0 + beatf2)*time) + spikes = [] + for i in range(trials): + model_params['v_zero'] = np.random.rand() + model_params['a_zero'] += 0.02*parameter['a_zero']*np.random.randn() + spiket = simulate(stimulus, **model_params) + spikes.append(spiket[spiket > tini] - tini) + return spikes + + +def plot_am(ax, s, alpha, beatf1, beatf2, tmax): + time = np.arange(0, tmax, 0.0001) + am = alpha*np.sin(2*np.pi*beatf1*time) + am += alpha*np.sin(2*np.pi*beatf2*time) + ax.show_spines('l') + ax.plot(1000*time, -100*am, **s.lsStim) + ax.set_xlim(0, 1000*tmax) + ax.set_ylim(-50, 50) + #ax.set_xlabel('Time', 'ms') + ax.set_ylabel('AM', r'\%') + ax.text(1, 1.2, f'Contrast = {100*alpha:g}\\,\\%', + transform=ax.transAxes, ha='right') + + +def plot_raster(ax, s, spikes, tmax): + spikes_ms = [1000*s[s<tmax] for s in spikes[:16]] + ax.show_spines('') + ax.eventplot(spikes_ms, linelengths=0.9, **s.lsRaster) + ax.set_xlim(0, 1000*tmax) + #ax.set_xlabel('Time', 'ms') + #ax.set_ylabel('Trials') + + +def compute_power(spikes, nfft, dt): + psds = [] + time = np.arange(nfft + 1)*dt + tmax = nfft*dt + rates = [] + cvs = [] + for s in spikes: + rates.append(len(s)/tmax) + isis = np.diff(s) + cvs.append(np.std(isis)/np.mean(isis)) + b, _ = np.histogram(s, time) + fourier = np.fft.rfft(b - np.mean(b)) + psds.append(np.abs(fourier)**2) + #psds.append(fourier) + freqs = np.fft.rfftfreq(nfft, dt) + #print('mean rate', np.mean(rates)) + #print('CV', np.mean(cvs)) + return freqs, np.mean(psds, 0) + #return freqs, np.abs(np.mean(psds, 0))**2/dt + + +def decibel(x): + return 10*np.log10(x/1e8) + +def plot_psd(ax, s, spikes, nfft, dt, beatf1, beatf2): + offs = 3 + freqs, psd = compute_power(spikes, nfft, dt) + psd /= freqs[1] + ax.plot(freqs, decibel(psd), **s.lsPower) + ax.plot(beatf2, decibel(peak_ampl(freqs, psd, beatf2)) + offs, + label=r'$f_{\rm base}$', clip_on=False, **s.psF0) + ax.plot(beatf1, decibel(peak_ampl(freqs, psd, beatf1)) + offs, + label=r'$\Delta f_1$', clip_on=False, **s.psF01) + ax.plot(beatf2, decibel(peak_ampl(freqs, psd, beatf2)) + offs + 4.5, + label=r'$\Delta f_2$', clip_on=False, **s.psF02) + ax.plot(beatf2 - beatf1, decibel(peak_ampl(freqs, psd, beatf2 - beatf1)) + offs, + label=r'$\Delta f_2 - \Delta f_1$', clip_on=False, **s.psF01_2) + ax.plot(beatf1 + beatf2, decibel(peak_ampl(freqs, psd, beatf1 + beatf2)) + offs, + label=r'$\Delta f_1 + \Delta f_2$', clip_on=False, **s.psF012) + ax.set_xlim(0, 300) + ax.set_ylim(-40, 0) + ax.set_xlabel('Frequency', 'Hz') + ax.set_ylabel('Power [dB]') + + +def plot_example(axs, axr, axp, s, cell, alpha, beatf1, beatf2, nfft, trials): + dt = 0.0001 + tmax = nfft*dt + t1 = 0.1 + spikes = punit_spikes(cell, alpha, beatf1, beatf2, tmax, trials) + plot_am(axs, s, alpha, beatf1, beatf2, t1) + plot_raster(axr, s, spikes, t1) + plot_psd(axp, s, spikes, nfft, dt, beatf1, beatf2) + + +def peak_ampl(freqs, psd, f): + df = 2 + psd_snippet = psd[(freqs > f - df) & (freqs < f + df)] + return np.max(psd_snippet) + + +def compute_peaks(name, cell, alpha_max, beatf1, beatf2, nfft, trials): + file_name = f'{name}-contrastpeaks.csv' + if os.path.exists(file_name): + data = TableData(file_name) + return data + dt = 0.0001 + tmax = nfft*dt + alphas = np.linspace(0, alpha_max, 200) + ampl_f1 = np.zeros(len(alphas)) + ampl_f2 = np.zeros(len(alphas)) + ampl_sum = np.zeros(len(alphas)) + ampl_diff = np.zeros(len(alphas)) + for k, alpha in enumerate(alphas): + print(alpha) + spikes = punit_spikes(cell, alpha, beatf1, beatf2, tmax, trials) + freqs, psd = compute_power(spikes, nfft, dt) + ampl_f1[k] = peak_ampl(freqs, psd, beatf1) + ampl_f2[k] = peak_ampl(freqs, psd, beatf2) + ampl_sum[k] = peak_ampl(freqs, psd, beatf1 + beatf2) + ampl_diff[k] = peak_ampl(freqs, psd, beatf2 - beatf1) + data = TableData() + data.append('contrast', '%', '%.1f', 100*alphas) + data.append('f1', 'Hz', '%g', ampl_f1) + data.append('f2', 'Hz', '%g', ampl_f2) + data.append('f1+f2', 'Hz', '%g', ampl_sum) + data.append('f2-f1', 'Hz', '%g', ampl_diff) + data.write(file_name) + return data + + +def amplitude(power): + power -= power[0] + power[power<0] = 0 + return np.sqrt(power) + + +def amplitude_linearfit(contrast, power, max_contrast): + power -= power[0] + power[power<0] = 0 + ampl = np.sqrt(power) + a = ampl[contrast <= max_contrast] + c = contrast[contrast <= max_contrast] + r = linregress(c, a) + return r.intercept + r.slope*contrast + + +def amplitude_squarefit(contrast, power, max_contrast): + power -= power[0] + power[power<0] = 0 + ampl = np.sqrt(power) + a = np.sqrt(ampl[contrast <= max_contrast]) + c = contrast[contrast <= max_contrast] + r = linregress(c, a) + return (r.intercept + r.slope*contrast)**2 + + +def plot_peaks(ax, s, data, alphas): + contrast = data[:, 'contrast'] + ax.plot(contrast, amplitude_linearfit(contrast, data[:, 'f1'], 4), **s.lsF01m) + ax.plot(contrast, amplitude_linearfit(contrast, data[:, 'f2'], 2), **s.lsF02m) + ax.plot(contrast, amplitude_squarefit(contrast, data[:, 'f1+f2'], 4), **s.lsF012m) + ax.plot(contrast, amplitude_squarefit(contrast, data[:, 'f2-f1'], 4), **s.lsF01_2m) + ax.plot(contrast, amplitude(data[:, 'f1']), **s.lsF01) + ax.plot(contrast, amplitude(data[:, 'f2']), **s.lsF02) + ax.plot(contrast, amplitude(data[:, 'f1+f2']), **s.lsF012) + ax.plot(contrast, amplitude(data[:, 'f2-f1']), **s.lsF01_2) + for alpha, tag in zip(alphas, ['A', 'B', 'C', 'D']): + contrast = 100*alpha + ax.plot(contrast, 630, 'vk', ms=4, clip_on=False) + ax.text(contrast, 660, tag, ha='center') + #ax.axvline(contrast, **s.lsGrid) + #ax.text(contrast, 630, tag, ha='center') + ax.axvline(1.5, **s.lsLine) + ax.axvline(4, **s.lsLine) + yoffs = 340 + ax.text(1.5/2, yoffs, 'linear\nregime', + ha='center', va='center') + ax.text((1.5 + 4)/2, yoffs, 'weakly\nnonlinear\nregime', + ha='center', va='center') + ax.text(10, yoffs, 'strongly\nnonlinear\nregime', + ha='center', va='center') + ax.set_xlim(0, 16.5) + ax.set_ylim(0, 600) + ax.set_xticks_delta(5) + ax.set_yticks_delta(300) + ax.set_xlabel('Contrast', r'\%') + ax.set_ylabel('Amplitude', 'Hz') + + +if __name__ == '__main__': + parameters = load_models('models.csv') + cell_name = '2013-01-08-aa-invivo-1' # 138Hz, CV=0.26: perfect! + beatf1 = 40 + beatf2 = 138 + # cell_name = '2012-07-03-ak-invivo-1' # 128Hz, CV=0.24 + # cell_name = '2018-05-08-ae-invivo-1' # 142Hz, CV=0.48 + + """ + parameters = load_models('models_big_fit_d_right.csv') + cell_name = '2013-01-08-aa-invivo-1' # 131Hz, CV=0.04: wrong! + beatf1 = 30 + beatf2 = 132 + """ + + cell = cell_parameters(parameters, cell_name) + for k in cell: + print(k, cell[k]) + + s = plot_style() + s.lwmid = 1.0 + s.lwthick = 1.6 + s.lsStim = dict(color='gray', lw=s.lwmid) + s.lsRaster = dict(color='black', lw=s.lwthin) + s.lsPower = dict(color='gray', lw=s.lwmid) + s.lsF0 = dict(color='blue', lw=s.lwthick) + s.lsF01 = dict(color='green', lw=s.lwthick) + s.lsF02 = dict(color='purple', lw=s.lwthick) + s.lsF012 = dict(color='orange', lw=s.lwthick) + s.lsF01_2 = dict(color='red', lw=s.lwthick) + s.lsF0m = dict(color=lighter('blue', 0.5), lw=s.lwthin) + s.lsF01m = dict(color=lighter('green', 0.6), lw=s.lwthin) + s.lsF02m = dict(color=lighter('purple', 0.5), lw=s.lwthin) + s.lsF012m = dict(color=darker('orange', 0.9), lw=s.lwthin) + s.lsF01_2m = dict(color=darker('red', 0.9), lw=s.lwthin) + + s.psF0 = dict(color='blue', marker='o', linestyle='none', markersize=5, mec='none', mew=0) + s.psF01 = dict(color='green', marker='o', linestyle='none', markersize=5, mec='none', mew=0) + s.psF02 = dict(color='purple', marker='o', linestyle='none', markersize=5, mec='none', mew=0) + s.psF012 = dict(color='orange', marker='o', linestyle='none', markersize=5, mec='none', mew=0) + s.psF01_2 = dict(color='red', marker='o', linestyle='none', markersize=5, mec='none', mew=0) + + nfft = 2**18 + fig, axs = plt.subplots(5, 4, cmsize=(s.plot_width, 0.8*s.plot_width), + height_ratios=[1, 1.5, 2, 1.5, 4]) + fig.subplots_adjust(leftm=8, rightm=2, topm=2, bottomm=3.5, + wspace=0.3, hspace=0.3) + ax0 = fig.merge(axs[3, :]) + ax0.set_visible(False) + axa = fig.merge(axs[4, :]) + fig.show_spines('lb') + alphas = [0.01, 0.03, 0.05, 0.16] + #alphas = [0.002, 0.01, 0.05, 0.1] + for c, alpha in enumerate(alphas): + plot_example(axs[0, c], axs[1, c], axs[2, c], s, cell, + alpha, beatf1, beatf2, nfft, 100) + axs[1, 0].xscalebar(1, -0.1, 30, 'ms', ha='right') + axs[2, 0].legend(loc='center left', bbox_to_anchor=(0, -0.8), + ncol=5, columnspacing=2) + data = compute_peaks(cell_name, cell, 0.2, beatf1, beatf2, nfft, 1000) + plot_peaks(axa, s, data, alphas) + fig.common_yspines(axs[0, :]) + fig.common_yticks(axs[2, :]) + #fig.common_xlabels(axs[2, :]) + fig.tag(axs[0, :], xoffs=-2, yoffs=1.6) + fig.tag(axa) + fig.savefig()