"""Figure of P-unit responses to noise stimuli.

Top part: baseline ISI histogram, baseline response spectrum, spike rasters,
gain, chi2 matrices and their diagonal projections of one example P-unit
recorded at two contrasts.
Bottom part: ISI histograms and chi2 matrices of additional example P-units.
"""

import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
from spectral import diag_projection, peak_size
from plotstyle import plot_style
from plotstyle import plot_chi2


# Example P-unit of the top part: [cell name, run index] of the two
# stimulus runs (contrasts) that are compared.
example_cell = [['2020-10-27-ag-invivo-1', 0],
                ['2020-10-27-ag-invivo-1', 1]]

# Additional example P-units: [cell name, run index].
# The commented-out entries are alternative candidates kept for reference.
example_cells = [
    #['2021-06-18-ae-invivo-1', 3], # 98Hz, 1%, ok
    ['2021-06-18-ae-invivo-1', 6], # 98Hz, 2: 10%, ok OR 6: 5%
    #['2012-03-30-ah', 5], # 177Hz, 5%, 2.0, nice
    ##['2012-07-03-ak', 0], # 120Hz, 2.5%, 1.8, broader, the one model cell, nice triangle up to 1%!
    ##['2012-12-20-ac', 0], # 213Hz, 2.5%, 2.1, ok, model cell, weak triangle up to 1%!
    ['2017-07-18-ai-invivo-1', 2], # 78Hz, 5%, 2.3, weak, nice model cell with clear triangle up to 10%!
    ##['2019-06-28-ae', 0], # 477Hz, 10%, 2.6, weak
    ##['2020-10-27-aa-invivo-1', 4], # 259Hz, 0.5%, 2.0, ok
    ##['2020-10-27-ae-invivo-1', 4], # 375Hz, 0.5%, 4.3, nice, additional low freq line
    ###['2020-10-27-ag-invivo-1', 2], # 405Hz, 5%, 3.9, strong, is already the example
    ##['2021-08-03-ab-invivo-1', 1], # 140Hz, 0.5%, ok
    #['2020-10-29-ag-invivo-1', 2], # 164Hz, 5%, 1.6, no diagonal
    ['2018-08-24-ak', 1], # 145Hz, 5%, no diagonal
    ['2018-08-14-ac', 1], # 239Hz, 0: 10%, no diagonal OR 1: 5%
    ##['2018-08-29-af', 1], # 383Hz, 5%, no diagonal
]

# Directory holding the per-cell .npz data files.
data_path = Path('data') / 'cells'


def load_baseline(path, cell_name):
    """Load baseline data of a cell from `<cell_name>-baseline.npz`.

    Parameters
    ----------
    path: Path
        Directory containing the data files.
    cell_name: str
        Name of the cell.

    Returns
    -------
    eodf: float
        EOD frequency (key 'eodf').
    rate: float
        Baseline firing rate in Hz (key 'ratebase/Hz').
    cv: float
        Baseline CV (key 'cvbase').
    isis: ndarray
        ISI values of the histogram (key 'isis'). They are plotted in ms after
        multiplying by 1000, so presumably they are stored in seconds.
    pdf: ndarray
        ISI histogram (key 'isih').
    freqs: ndarray
        Frequencies of the baseline response spectrum (key 'freqs').
    prr: ndarray
        Baseline response power spectrum (key 'prr').
    """
    # the NpzFile keeps the archive open; close it once all arrays are read:
    with np.load(path / f'{cell_name}-baseline.npz') as data:
        eodf = float(data['eodf'])
        rate = float(data['ratebase/Hz'])
        cv = float(data['cvbase'])
        isis = data['isis']
        pdf = data['isih']
        freqs = data['freqs']
        prr = data['prr']
    return eodf, rate, cv, isis, pdf, freqs, prr


def load_noise(path, cell_name, run):
    """Load stimulus and spike trains of one noise-stimulus run.

    Reads `<cell_name>-spectral-data-sNN.npz` with NN the two-digit `run`.

    Returns
    -------
    contrast: ndarray
        Stimulus contrast as stored under key 'contrast'.
    time: ndarray
        Time axis of the stimulus.
    stimulus: ndarray
        Stimulus waveform (scaled by the contrast when plotted).
    spikes: list of ndarray
        Spike times of each trial, read from the consecutive keys
        'spikes_000', 'spikes_001', ... up to the first missing key
        (at most 1000 trials).
    """
    with np.load(path / f'{cell_name}-spectral-data-s{run:02d}.npz') as data:
        contrast = data['contrast']
        time = data['time']
        stimulus = data['stimulus']
        spikes = []
        for k in range(1000):
            key = f'spikes_{k:03d}'
            if key not in data.files:
                break
            spikes.append(data[key])
    return contrast, time, stimulus, spikes


def load_spectra(path, mode, cell_name, run=None):
    """Load the spectra of one run and compute gain and chi2 from them.

    Reads `<cell_name>-spectral-<mode>-sNN.npz` with NN the two-digit `run`.

    Parameters
    ----------
    path: Path
        Directory containing the data files.
    mode: str
        Analysis variant that is part of the file name (e.g. '100' or 'all').
    cell_name: str
        Name of the cell.
    run: int
        Stimulus run index. Required; the default None is only kept so the
        signature stays unchanged.

    Returns
    -------
    fcutoff: float
        Cutoff frequency (key 'fcutoff').
    contrast: float
        Stimulus contrast (key 'contrast').
    freqs: ndarray
        Frequencies of the spectra.
    gain: ndarray
        |prs|/pss.
    chi2: ndarray
        Matrix |prss|*0.5/(pss[i]*pss[j]).

    Raises
    ------
    ValueError
        If `run` is None, since the file name requires a run index.
    """
    if run is None:
        # formatting None with ':02d' would fail with an obscure TypeError
        raise ValueError(f'load_spectra() requires a run index for cell {cell_name}')
    with np.load(path / f'{cell_name}-spectral-{mode}-s{run:02d}.npz') as data:
        contrast = float(data['contrast'])
        fcutoff = float(data['fcutoff'])
        freqs = data['freqs']
        pss = data['pss']
        prs = data['prs']
        prss = data['prss']
    gain = np.abs(prs)/pss
    # denominator entry (i, j) is pss[i]*pss[j]:
    chi2 = np.abs(prss)*0.5/np.outer(pss, pss)
    return fcutoff, contrast, freqs, gain, chi2


def plot_isih(ax, s, rate, cv, isis, pdf):
    """Plot the baseline ISI histogram between 0 and 8 ms.

    The histogram is filled with `s.cell_color1` and annotated with the
    baseline CV and firing rate.
    """
    ax.show_spines('b')
    ax.fill_between(1000*isis, pdf, facecolor=s.cell_color1)
    ax.set_xlim(0, 8)
    ax.set_xticks_delta(2)
    ax.set_xlabel('ISI', 'ms')
    ax.text(0.6, 1.08, f'CV$_{{\\rm base}}$={cv:.2f}, $r={rate:.0f}$Hz',
            transform=ax.transAxes)


def plot_isih_small(ax, s, contrast, rate, cv, isis, pdf):
    """Plot a small ISI histogram (0 to 20 ms) annotated with CV, rate and contrast."""
    ax.show_spines('b')
    ax.fill_between(1000*isis, pdf, facecolor=s.cell_color1)
    ax.set_xlim(0, 20)
    ax.set_xticks_fixed([0, 5, 10, 15, 20],
                        ['0', '5', '10', '15', '20\\,ms'])
    # move the annotation further right for low firing rates:
    xt = 1 if rate > 80 else 1.3
    ax.text(xt, 1.05, f'CV$_{{\\rm base}}$={cv:.2f}',
            ha='right', transform=ax.transAxes)
    ax.text(xt, 0.6, f'$r={rate:.0f}$Hz',
            ha='right', transform=ax.transAxes)
    ax.text(xt, 0.15, f'$c={100*contrast:.0f}$\\%',
            ha='right', transform=ax.transAxes)


def plot_response_spectrum(ax, s, eodf, rate, freqs, prr):
    """Plot the baseline response power spectrum in dB relative to its maximum.

    Marks the largest peak below 0.7*`eodf` (labeled r) and the largest peak
    above 500 Hz (labeled f_EOD). `rate` is not used.
    """
    # Both index computations assume `freqs` is sorted ascending, so that
    # argmax of the masked prefix / suffix can be mapped back to full indices.
    rate_i = np.argmax(prr[freqs < 0.7*eodf])
    eod_i = np.argmax(prr[freqs > 500]) + np.argmax(freqs > 500)
    power_db = 10*np.log10(prr/np.max(prr))
    ax.show_spines('b')
    mask = (freqs > 30) & (freqs < 890)
    ax.plot(freqs[mask], power_db[mask], **s.lsC1)
    ax.plot(freqs[eod_i], power_db[eod_i] + 2, **s.psFEOD)
    ax.plot(freqs[rate_i], power_db[rate_i] + 2, **s.psF0)
    ax.set_ylim(-25, 5)
    # alternative linear scaling of the spectrum:
    #ax.plot(freqs[mask], 1e-3*prr[mask], **s.lsC1)
    #ax.plot(freqs[eod_i], 1e-3*prr[eod_i] + 2, **s.psFEOD)
    #ax.plot(freqs[rate_i], 1e-3*prr[rate_i] + 2, **s.psF0)
    #ax.set_ylim(0, 30)
    ax.set_xlim(0, 900)
    ax.set_xticks_delta(300)
    ax.set_xlabel('$f$', 'Hz')
    ax.text(freqs[eod_i], power_db[eod_i] + 4, '$f_{\\rm EOD}$',
            ha='center')
    ax.text(freqs[rate_i], power_db[rate_i] + 4, '$r$',
            ha='center')
    ax.yscalebar(1.05, 0, 10, 'dB', ha='right')
    #ax.text(freqs[eod_i], 1e-3*prr[eod_i] + 4, '$f_{\\rm EOD}$',
    #        ha='center')
    #ax.text(freqs[rate_i], 1e-3*prr[rate_i] + 4, '$r$',
    #        ha='center')
    #ax.yscalebar(1.05, 0, 5, 'kHz', ha='right')


def plot_response(ax, s, eodf, time1, stimulus1, contrast1, spikes1,
                  contrast2, spikes2, am=True):
    """Plot the stimulated EOD and spike rasters for two contrasts.

    Shows the time window from 0.3 to 0.4 s. The EOD carrier of frequency
    `eodf` is combined with `contrast1*stimulus1`, and at most 8 trials
    (trials 2 to 9) of `spikes1` and `spikes2` are drawn below it.

    Parameters
    ----------
    am: bool
        If True, the stimulus multiplies the carrier as (1 + stimulus).
        Otherwise it is added to the carrier.
    """
    t0 = 0.3
    t1 = 0.4
    maxtrials = 8
    # skip the first two trials and show at most maxtrials of the others:
    trials1 = spikes1[2:2 + maxtrials]
    trials2 = spikes2[2:2 + maxtrials]
    ax.show_spines('')
    # lineoffsets must have one entry per plotted trial, which may be fewer
    # than maxtrials:
    ax.eventplot(trials1, lineoffsets=np.arange(len(trials1)) - maxtrials + 1,
                 linelength=0.8, linewidths=1, color=s.cell_color1)
    ax.eventplot(trials2, lineoffsets=np.arange(len(trials2)) - 2*maxtrials,
                 linelength=0.8, linewidths=1, color=s.cell_color2)
    stim = contrast1*stimulus1
    if am:
        eod = np.sin(2*np.pi*eodf*time1) * (1 + stim)
    else:
        eod = np.sin(2*np.pi*eodf*time1) + stim
    ax.plot(time1, 4*eod + 7, **s.lsEOD)
    ax.plot(time1, 4*(1 + stim) + 7, **s.lsAM)
    ax.set_xlim(t0, t1)
    ax.set_ylim(-2*maxtrials - 0.5, 14)
    ax.xscalebar(1, -0.05, 0.01, None, '10\\,ms', ha='right')
    ax.text(t1 + 0.003, -0.5*maxtrials, f'${100*contrast1:.0f}$\\,\\%',
            va='center', color=s.cell_color1)
    ax.text(t1 + 0.003, -1.55*maxtrials, f'${100*contrast2:.0f}$\\,\\%',
            va='center', color=s.cell_color2)


def plot_gain(ax, s, contrast1, freqs1, gain1, contrast2, freqs2, gain2,
              fcutoff, ymax, dy):
    """Plot the gains |chi_1| of both runs up to `fcutoff`.

    Gains are scaled by 1e-2 so that they match the 'Hz/%' axis label.
    """
    ax.plot(freqs2, 1e-2*gain2, label=f'{100*contrast2:.0f}', **s.lsC2)
    ax.plot(freqs1, 1e-2*gain1, label=f'{100*contrast1:.0f}', **s.lsC1)
    ax.set_xlim(0, fcutoff)
    ax.set_ylim(0, ymax)
    ax.set_xticks_delta(fcutoff//3)
    ax.set_yticks_delta(dy)
    ax.set_xlabel('$f$', 'Hz')
    ax.set_ylabel(r'$|\chi_1|$', r'Hz/\%')


def plot_diagonals(ax, s, fbase, contrast1, freqs1, chi21, contrast2, freqs2,
                   chi22, fcutoff, ymax, toffs):
    """Plot the diagonal projections of two chi2 matrices and mark their peaks.

    For each run, `diag_projection()` projects chi2 onto f1 + f2 up to
    2*`fcutoff`. `peak_size()` returns the SI (printed and annotated) and the
    frequency at which the projection is marked. The marker height is the
    projection value at the frequency closest to that peak frequency.
    `toffs` shifts the annotations of the first run vertically.
    """
    diags = []
    sis = []
    sips = []
    sifs = []
    for contrast, freqs, chi2 in [[contrast1, freqs1, chi21],
                                  [contrast2, freqs2, chi22]]:
        dfreqs, diag = diag_projection(freqs, chi2, 2*fcutoff)
        diags.append([dfreqs, diag])
        sinorm, _, sif = peak_size(dfreqs, diag, fbase, median=False)
        sip = diag[np.argmin(np.abs(dfreqs - sif))]
        sis.append(sinorm)
        sips.append(sip)
        sifs.append(sif)
        print(f' SI at {100*contrast:.1f}% contrast: {sinorm:.2f}')
    ax.plot(diags[1][0], 1e-4*diags[1][1], **s.lsC2)
    ax.plot(diags[0][0], 1e-4*diags[0][1], **s.lsC1)
    # draw the peak markers slightly above the curves:
    offs = 0.05*ymax
    ax.plot(sifs[1], 1e-4*sips[1] + offs, clip_on=False, **s.psC2)
    ax.plot(sifs[0], 1e-4*sips[0] + offs, clip_on=False, **s.psC1)
    ax.set_xlim(0, 2*fcutoff)
    ax.set_ylim(0, ymax)
    ax.set_xticks_delta(fcutoff)
    ax.set_yticks_delta(ymax//3)
    ax.set_xlabel('$f_1 + f_2$', 'Hz')
    ax.text(sifs[1] - 0.15*fcutoff, 1e-4*sips[1], f'{100*contrast2:.0f}\\%',
            ha='right', color=s.cell_color2)
    ax.text(sifs[1] + 0.25*fcutoff, 1e-4*sips[1], f'SI={sis[1]:.1f}')
    ax.text(sifs[0] - 0.15*fcutoff, 1e-4*sips[0] + toffs,
            f'{100*contrast1:.0f}\\%', ha='right', color=s.cell_color1)
    ax.text(sifs[0] + 0.25*fcutoff, 1e-4*sips[0] + toffs,
            f'SI={sis[0]:.1f}')


def print_candidate_cells():
    """Print P-units matching selection criteria, to help find example cells.

    Not called by the script. Run it manually when looking for a new example.
    Requires the `thunderlab` package and the summary table
    'data/Apteronotus_leptorhynchus-Punit-data.csv'.
    """
    from thunderlab.tabledata import TableData
    data = TableData('data/Apteronotus_leptorhynchus-Punit-data.csv')
    data = data[(data['sinorm_nmax'] > 0) & (data['sinorm_nmax'] < 1.5), :]
    data = data[(data['contrast'] > 0.04) & (data['contrast'] < 0.06), :]
    #data = data[(data['respmod2'] > 150) & (data['respmod2'] < 200), :]
    data = data[(data['cvbase'] > 0.4) & (data['cvbase'] < 0.8), :]
    data = data[(data['ratebase'] > 220) & (data['ratebase'] < 300), :]
    for k in range(data.rows()):
        print(f'{data[k, "cell"]:<22s} s{data[k, "stimindex"]:02.0f}: '
              f'{100*data[k, "contrast"]:3g}%, r={data[k, "ratebase"]:3.0f}Hz, '
              f'CV={data[k, "cvbase"]:4.2f}, '
              f'rmod={data[k, "respmod2"]:3.0f}Hz, '
              f'SI={data[k, "sinorm_nmax"]:5.2f}')
    print()


def main():
    """Load the example cells, plot the figure and save it."""
    #mode = 'all'
    mode = '100'

    # data of the main example cell:
    cell_name = example_cell[0][0]
    print('Example P-unit:')
    eodf, rate, cv, isis, pdf, freqs, prr = load_baseline(data_path, cell_name)
    print(f' {cell_name:<22s}: fbase={rate:3.0f}Hz, CV={cv:.2f}')
    contrast1, time1, stimulus1, spikes1 = load_noise(data_path, *example_cell[0])
    contrast2, time2, stimulus2, spikes2 = load_noise(data_path, *example_cell[1])
    # the contrasts from the spectra files replace those from load_noise():
    fcutoff1, contrast1, freqs1, gain1, chi21 = load_spectra(data_path, mode, *example_cell[0])
    fcutoff2, contrast2, freqs2, gain2, chi22 = load_spectra(data_path, mode, *example_cell[1])
    print(f' contrast1: {100*contrast1:4.1f}% contrast2: {100*contrast2:4.1f}%')
    print(f' fcutoff1 : {fcutoff1:3.0f}Hz fcutoff2 : {fcutoff2:3.0f}Hz')
    print(f' duration1: {time1[-1]:4.1f}s duration2: {time2[-1]:4.1f}s')

    # plot style with the generic cell colors mapped to P-unit colors:
    s = plot_style()
    s.cell_color1 = s.punit_color1
    s.cell_color2 = s.punit_color2
    s.lsC1 = s.lsP1
    s.lsC2 = s.lsP2
    s.psC1 = s.psP1
    s.psC2 = s.psP2

    # figure layout:
    fig, (ax1, ax2, ax3) = \
        plt.subplots(3, 1, height_ratios=[3, 0, 3, 0.3, 4.7],
                     cmsize=(s.plot_width, 0.85*s.plot_width))
    fig.subplots_adjust(leftm=8, rightm=2, topm=2, bottomm=3.5,
                        wspace=0.4, hspace=0.42)
    axi, axp, axr = ax1.subplots(1, 3, width_ratios=[2, 3, 0, 10, 0.2])
    axg, axc1, axc2, axd = ax2.subplots(1, 4, wspace=0.2,
                                        width_ratios=[3.5, 0.5, 4, 4, 0.8, 3.5])
    axs = ax3.subplots(2, 4, wspace=0.4, hspace=0.35,
                       width_ratios=[1, 1, 0.1, 1, 1, 0.1],
                       height_ratios=[1, 4])

    # main example cell:
    axi.text(0, 1.08, 'P-unit:', transform=axi.transAxes,
             color=s.cell_color1, fontsize='large')
    plot_isih(axi, s, rate, cv, isis, pdf)
    plot_response_spectrum(axp, s, eodf, rate, freqs, prr)
    plot_response(axr, s, eodf, time1, stimulus1, contrast1, spikes1,
                  contrast2, spikes2, am=True)
    plot_gain(axg, s, contrast1, freqs1, gain1, contrast2, freqs2, gain2,
              fcutoff1, ymax=10, dy=5)
    # plot_chi2() returns an extra axes (presumably a colorbar) that is
    # removed for the first matrix:
    axc = plot_chi2(axc1, s, freqs2, chi22, fcutoff2, None, 6)
    axc.remove()
    axc1.plot([0, fcutoff2], [0, fcutoff2], zorder=20, **s.lsDiag)
    axc1.set_title(f'$c$={100*contrast2:g}\\,\\%', fontsize='medium',
                   color=s.cell_color2)
    plot_chi2(axc2, s, freqs1, chi21, fcutoff1, None, 6)
    axc2.set_title(f'$c$={100*contrast1:g}\\,\\%', fontsize='medium',
                   color=s.cell_color1)
    axc2.plot([0, fcutoff1], [0, fcutoff1], zorder=20, **s.lsDiag)
    plot_diagonals(axd, s, rate, contrast1, freqs1, chi21,
                   contrast2, freqs2, chi22, fcutoff1, ymax=14, toffs=2.5)
    fig.common_yticks(axc1, axc2)
    fig.tag([axi, axp, axr], xoffs=-3, yoffs=0)
    fig.tag([axg, axc1, axc2, axd], xoffs=-3, yoffs=2)
    print()

    # additional example cells, one column each:
    print('Additional example cells:')
    axs[0, 0].text(0, 1.6, 'P-units:', transform=axs[0, 0].transAxes,
                   color=s.cell_color1, fontsize='large')
    for k, (cell, run) in enumerate(example_cells):
        eodf, rate, cv, isis, pdf, _, _ = load_baseline(data_path, cell)
        fcutoff, contrast, freqs, gain, chi2 = load_spectra(data_path, mode, cell, run)
        print(f' {cell:<22s}: run={run:2d}, contrast={100*contrast:3.2g}%, '
              f'fbase={rate:3.0f}Hz, CV={cv:.2f}')
        plot_isih_small(axs[0, k], s, contrast, rate, cv, isis, pdf)
        vmax = 20 if k < 2 else 30
        axc = plot_chi2(axs[1, k], s, freqs, chi2, fcutoff, rate, vmax)
        # keep the extra axes only for every second column:
        if k % 2 == 0:
            axc.remove()
        if k == 1:
            axc.set_ylabel('')
    fig.common_yticks(axs[1, :])
    fig.tag([axs[0, :2], axs[0, 2:]], xoffs=-3, yoffs=1)
    fig.savefig()
    print()


if __name__ == '__main__':
    main()