import numpy as np
import matplotlib as mpl
import plottools.plottools as pt
from spectral import diag_projection, peak_size
from plottools.spines import spines_params
from plottools.labels import labels_params
from plottools.colors import lighter, darker


def significance_str(p):
    """Format a p-value as a LaTeX math string for figure annotations.

    Parameters
    ----------
    p: float
        The p-value.

    Returns
    -------
    s: str
        ``$p=0.xx$`` for p >= 0.05, otherwise the tightest of
        ``$p<0.05$``, ``$p<0.01$``, ``$p<0.001$`` that holds, or
        ``$p\\ll 0.001$`` for p < 0.0001.
    """
    # `>=` (not `>`) so that e.g. p == 0.05 is not reported as p < 0.05.
    if p >= 0.05:
        return f'$p={p:.2f}$'
    elif p >= 0.01:
        return '$p<0.05$'
    elif p >= 0.001:
        return '$p<0.01$'
    elif p >= 0.0001:
        return '$p<0.001$'
    else:
        return '$p\\ll 0.001$'


def noise_files(data_path, cell_name, alpha=None):
    """Find the chi2 result files of a cell, ordered by their trailing number.

    Parameters
    ----------
    data_path: path-like with a `glob()` method
        Directory to search (presumably a `pathlib.Path` — TODO confirm).
    cell_name: str
        Prefix of the file names.
    alpha: float or None
        If None, match ``{cell_name}-chi2-split-*.npz``; otherwise match
        ``{cell_name}-chi2-noise-{1000*alpha:03.0f}-*.npz``.

    Returns
    -------
    files: list or None
        Matching files sorted by the integer after the last '-' of their
        stem, or None if nothing matched.
    nums: list of int or 0
        The corresponding integers, or 0 if nothing matched.

    Raises
    ------
    ValueError
        If the part of a stem after the last '-' is not an integer.
    """
    if alpha is None:
        file_pattern = f'{cell_name}-chi2-split-*.npz'
    else:
        file_pattern = f'{cell_name}-chi2-noise-{1000*alpha:03.0f}-*.npz'
    files = list(data_path.glob(file_pattern))
    if not files:
        return None, 0
    # Single stable sort: by trailing number, ties broken by stem.
    files.sort(key=lambda fn: (int(fn.stem.split('-')[-1]), fn.stem))
    nums = [int(fn.stem.split('-')[-1]) for fn in files]
    return files, nums


def _nice_color_range(vmax):
    """Round a color-scale maximum up to a "nice" value and pick a tick spacing.

    With ``ten = 10**floor(log10(vmax))``, `vmax` is rounded up to the
    smallest ``fac*ten`` with fac in (1, 1.2, 1.5, 2, 3, 4, 6, 8, 10), and the
    tick spacing is ``delta*ten`` for the delta paired with that fac.

    Parameters
    ----------
    vmax: float
        Requested maximum of the color scale.

    Returns
    -------
    vmax: float
        Rounded-up maximum.
    tick_delta: float
        Spacing of colorbar ticks.
    """
    if not np.isfinite(vmax) or vmax <= 0:
        # log10 of a non-positive or NaN value gives -inf/NaN, which would
        # produce a zero or NaN tick spacing; fall back to a unit range.
        vmax = 1.0
    ten = 10**np.floor(np.log10(vmax))
    for fac, delta in zip([1, 1.2, 1.5, 2, 3, 4, 6, 8, 10],
                          [0.5, 0.4, 0.5, 1, 1, 2, 3, 4, 5]):
        if fac*ten >= vmax:
            return fac*ten, ten*delta
    # Not reached for finite positive vmax, since vmax < 10*ten.
    return vmax, ten


def plot_chi2(ax, s, freqs, chi2, fcutoff, rate=None, vmax=None):
    """Plot the second-order susceptibility |chi2| as a color map with colorbar.

    Parameters
    ----------
    ax: plottools-enhanced matplotlib axes
        Axes to draw into (uses plottools extensions such as
        `set_xticks_delta` and two-argument `set_xlabel`).
    s: namespace
        Style namespace; not used by this function.
    freqs: ndarray
        Frequencies of both chi2 axes, in Hz. Plotting starts at the first
        non-negative frequency.
    chi2: 2D ndarray
        Susceptibility matrix indexed by (freqs, freqs). It is multiplied
        by 1e-4 for display (unit label Hz/%^2).
    fcutoff: float
        Upper frequency limit of the plot, in Hz.
    rate: float or None
        If given, the peak size returned by `peak_size()` on the diagonal
        projection is printed as "SI(r)" into the plot.
    vmax: float or None
        Maximum of the color scale (in the scaled units). If None, the
        0.996 quantile of the displayed chi2 is used. In both cases it is
        rounded up to a "nice" value.

    Returns
    -------
    cax: axes
        The inset axes holding the colorbar.
    """
    ax.set_visible(True)
    ax.set_aspect('equal')
    # First index with a non-negative frequency:
    i0 = np.argmin(freqs < 0)
    # First index above the cutoff; argmax of an all-False mask is 0,
    # in which case all frequencies are kept:
    i1 = np.argmax(freqs > fcutoff)
    if i1 == 0:
        i1 = len(freqs)
    freqs = freqs[i0:i1]
    chi2 = 1e-4*chi2[i0:i1, i0:i1]   # Hz/%^2
    vquantile = 0.996
    if vmax is None:
        vmax = np.quantile(chi2, vquantile)
    vmax, tick_delta = _nice_color_range(vmax)
    pc = ax.pcolormesh(freqs, freqs, chi2, vmin=0, vmax=vmax,
                       rasterized=True)
    ax.set_xlim(0, fcutoff)
    ax.set_ylim(0, fcutoff)
    df = 100 if fcutoff > 250 else 50
    ax.set_xticks_delta(df)
    ax.set_yticks_delta(df)
    ax.set_xlabel('$f_1$', 'Hz')
    ax.set_ylabel('$f_2$', 'Hz')
    if rate is not None:
        dfreqs, diag = diag_projection(freqs, chi2, 2*fcutoff)
        nli, _, _ = peak_size(dfreqs, diag, rate, median=False)
        ax.text(0.95, 0.88, f'SI($r$)={nli:.1f}', ha='right', zorder=50,
                color='white', fontsize='medium', transform=ax.transAxes,
                bbox=dict(boxstyle='round,pad=0.1', ec='none', fc='black',
                          alpha=0.4))
    cax = ax.inset_axes([1.04, 0, 0.05, 1])
    cax.set_spines_outward('lrbt', 0)
    cb = ax.get_figure().colorbar(pc, cax=cax)
    cb.outline.set_color('none')
    cb.outline.set_linewidth(0)
    cax.set_ylabel(r'$|\chi_2|$', r'Hz/\%$^2$')
    cax.set_yticks_delta(tick_delta)
    return cax


def plot_style():
    """Set up matplotlib/plottools rc settings and build the plot styles.

    Resets matplotlib to its rc defaults (`mpl.rcdefaults()`) and then
    applies axes, colormap ('YR'), color, figure, label, legend, scalebar,
    spine, tag, text and tick parameters via plottools.

    Returns
    -------
    ns: class used as a namespace
        Holds line widths, colors, the plot width, and the generated line
        ('ls*'), point ('ps*'), line/point/fill and arrow styles.
    """
    palette = pt.palettes['muted']
    lwthick = 1.6
    lwmid = 1.0
    lwthin = 0.5
    lwspines = 1.0
    names = ['A1', 'A2', 'A3', 'B1', 'B2', 'B3', 'B4', 'C1', 'C2', 'C3', 'C4']
    colors = [palette['red'], palette['orange'], palette['yellow'],
              palette['blue'], palette['purple'], palette['magenta'],
              palette['lightblue'],
              palette['lightgreen'], palette['green'], palette['darkgreen'],
              palette['cyan']]
    dashes = ['-', '-', '-', '-', '-', '-', '-.', '-', '-', '-', '-']
    markers = [('o', 1.0), ('p', 1.1), ('h', 1.1),
               ((3, 1, 60), 1.25), ((3, 1, 0), 1.25), ((3, 1, 90), 1.25),
               ((3, 1, 30), 1.25),
               ('s', 0.9), ('D', 0.85), ('*', 1.6), ((4, 1, 45), 1.4)]

    # A bare class serves as the attribute namespace the style makers fill.
    class ns:
        pass

    ns.colors = palette
    ns.lwthick = lwthick
    ns.lwmid = lwmid
    ns.lwthin = lwthin
    ns.plot_width = 16.5
    pt.make_linepointfill_styles(ns, names, colors, dashes, markers,
                                 lwthick=lwthick, lwthin=lwthin,
                                 markerlarge=5, markersmall=4,
                                 mec=0.0, mew=0.5, fillalpha=0.4)
    pt.make_line_styles(ns, 'ls', 'Spine', '', palette['black'], '-',
                        lwspines, clip_on=False)
    pt.make_line_styles(ns, 'ls', 'Grid', '', palette['gray'], '--',
                        0.7*lwthin)
    pt.make_line_styles(ns, 'ls', 'Dotted', '', palette['gray'], ':',
                        0.7*lwthin)
    pt.make_line_styles(ns, 'ls', 'Marker', '', palette['black'], '-',
                        lwthick, clip_on=False)
    pt.make_line_styles(ns, 'ls', 'Line', '', palette['black'], '-', lwthin)
    pt.make_line_styles(ns, 'ls', 'EOD', '', palette['gray'], '-', lwthin)
    pt.make_line_styles(ns, 'ls', 'AM', '', palette['red'], '-', lwmid)
    pt.make_line_styles(ns, 'ls', 'AMsplit', '', palette['orange'], '-',
                        lwmid)
    pt.make_line_styles(ns, 'ls', 'Noise', '', palette['gray'], '-', lwmid)
    pt.make_line_styles(ns, 'ls', 'Median', '', palette['black'], '-',
                        lwthick)
    pt.make_line_styles(ns, 'ls', 'Max', '', palette['black'], '-', lwmid)
    ns.lsStim = dict(color='gray', lw=ns.lwmid)
    ns.lsRaster = dict(color='black', lw=ns.lwthin)
    ns.lsPower = dict(color='gray', lw=ns.lwmid)
    ns.lsF0 = dict(color='blue', lw=ns.lwthick)
    ns.lsF01 = dict(color='green', lw=ns.lwthick)
    ns.lsF02 = dict(color='purple', lw=ns.lwthick)
    ns.lsF012 = dict(color='orange', lw=ns.lwthick)
    ns.lsF01_2 = dict(color='red', lw=ns.lwthick)
    ns.lsF0m = dict(color=lighter('blue', 0.5), lw=ns.lwthin)
    ns.lsF01m = dict(color=lighter('green', 0.6), lw=ns.lwthin)
    ns.lsF02m = dict(color=lighter('purple', 0.5), lw=ns.lwthin)
    ns.lsF012m = dict(color=darker('orange', 0.9), lw=ns.lwthin)
    ns.lsF01_2m = dict(color=darker('red', 0.9), lw=ns.lwthin)
    ns.psFEOD = dict(color='black', marker='o', linestyle='none',
                     markersize=5, mec='none', mew=0)
    ns.psF0 = dict(color='blue', marker='o', linestyle='none',
                   markersize=5, mec='none', mew=0)
    ns.psF01 = dict(color='green', marker='o', linestyle='none',
                    markersize=5, mec='none', mew=0)
    ns.psF02 = dict(color='purple', marker='o', linestyle='none',
                    markersize=5, mec='none', mew=0)
    ns.psF012 = dict(color='orange', marker='o', linestyle='none',
                     markersize=5, mec='none', mew=0)
    ns.psF01_2 = dict(color='red', marker='o', linestyle='none',
                      markersize=5, mec='none', mew=0)
    ns.model_color1 = palette['purple']
    ns.model_color2 = lighter(ns.model_color1, 0.6)
    ns.punit_color1 = palette['blue']
    ns.punit_color2 = lighter(ns.punit_color1, 0.6)
    ns.ampul_color1 = palette['green']
    ns.ampul_color2 = lighter(ns.ampul_color1, 0.6)
    pt.make_line_styles(ns, 'ls', 'M%d', '',
                        [ns.model_color1, ns.model_color2], '-', lwmid)
    pt.make_point_styles(ns, 'ps', 'M%d', '',
                         [ns.model_color1, ns.model_color2], '-',
                         markers=('o', 1), markersizes=4)
    pt.make_line_styles(ns, 'ls', 'P%d', '',
                        [ns.punit_color1, ns.punit_color2], '-', lwmid)
    pt.make_point_styles(ns, 'ps', 'P%d', '',
                         [ns.punit_color1, ns.punit_color2], '-',
                         markers=('o', 1), markersizes=4)
    pt.make_line_styles(ns, 'ls', 'A%d', '',
                        [ns.ampul_color1, ns.ampul_color2], '-', lwmid)
    pt.make_point_styles(ns, 'ps', 'A%d', '',
                         [ns.ampul_color1, ns.ampul_color2], '-',
                         markers=('o', 1), markersizes=4)
    pt.make_line_styles(ns, 'ls', 'Diag', '', palette['white'], '--', lwthin)
    pt.arrow_style(ns, 'Line', dist=3.0, style='>', shrink=0, lw=0.6,
                   color=palette['black'], head_length=4, head_width=4,
                   bbox=dict(boxstyle='round,pad=0.1', facecolor='white',
                             edgecolor='none', alpha=1.0))
    pt.arrow_style(ns, 'Hertz', dist=3.0, style='>', shrink=0, lw=0.6,
                   color=palette['black'], head_length=3, head_width=3,
                   heads='<>', text='%.0f\u2009Hz', rotation='vertical',
                   fontsize='x-small',
                   bbox=dict(boxstyle='round,pad=0.1', facecolor='white',
                             edgecolor='none', alpha=0.6))
    pt.arrow_style(ns, 'Point', dist=3.0, style='>>', shrink=0, lw=1.1,
                   color=palette['black'], head_length=5, head_width=4)
    pt.arrow_style(ns, 'PointSmall', dist=3.0, style='>>', shrink=0, lw=1,
                   color=palette['black'], head_length=5, head_width=5,
                   fontsize='small')
    pt.arrow_style(ns, 'Marker', dist=3.0, style='>>', shrink=0, lw=0.9,
                   color=palette['black'], head_length=5, head_width=5,
                   fontsize='small', ha='center', va='center',
                   bbox=dict(boxstyle='round, pad=0.1, rounding_size=0.4',
                             facecolor=palette['white'], edgecolor='none',
                             alpha=0.6))
    # rc settings:
    mpl.rcdefaults()
    pt.axes_params(0.0, 0.0, 0.0, color='none')
    cmcolors = [pt.lighter(palette['yellow'], 0.2),
                pt.lighter(palette['orange'], 0.5),
                palette['orange'], palette['red'], palette['red']]
    cmvalues = [0.0, 0.3, 0.7, 0.95, 1.0]
    pt.colormap('YR', cmcolors, cmvalues)
    cycle_colors = ['blue', 'red', 'orange', 'lightgreen', 'magenta',
                    'yellow', 'cyan', 'pink']
    pt.colors_params(palette, cycle_colors, 'viridis')
    pt.figure_params(palette['white'], format='pdf', compression=6,
                     fonttype=3, stripfonts=True)
    pt.labels_params('{label} [{unit}]', labelsize='medium', labelpad=6)
    pt.legend_params(fontsize='medium', frameon=False, borderpad=0.0,
                     handlelength=1.5, handletextpad=1, numpoints=1,
                     scatterpoints=1, labelspacing=0.5, columnspacing=0.5)
    pt.scalebars_params(format_large='%.0f', format_small='%.1f', lw=2.2,
                        color=palette['black'], capsize=0, clw=0.5,
                        font=dict(fontsize='medium', fontstyle='normal'))
    pt.spines_params(spines='lb', spines_offsets={'lrtb': 3},
                     spines_bounds={'lrtb': 'full'})
    pt.tag_params(xoffs='auto', yoffs='auto', label='%A',
                  minor_label=r'%A$_{\text{%mi}}$',
                  font=dict(fontsize='x-large', fontstyle='normal',
                            fontweight='normal'))
    pt.text_params(font_size=7, font_family='sans-serif', latex=True,
                   preamble=('p:sfmath', 'p:marvosym', 'p:xfrac',
                             'p:SIunits'))
    pt.ticks_params(xtick_dir='out', xtick_size=2.5)
    return ns