diff --git a/figures/fig_invariance_thresh_lp_species.pdf b/figures/fig_invariance_thresh_lp_species.pdf index b2909e1..74f3ff3 100644 Binary files a/figures/fig_invariance_thresh_lp_species.pdf and b/figures/fig_invariance_thresh_lp_species.pdf differ diff --git a/python/fig_invariance_thresh-lp_species.py b/python/fig_invariance_thresh-lp_species.py index 1962609..e9901c7 100644 --- a/python/fig_invariance_thresh-lp_species.py +++ b/python/fig_invariance_thresh-lp_species.py @@ -1,6 +1,7 @@ import plotstyle_plt import numpy as np import matplotlib.pyplot as plt +from matplotlib.colors import LogNorm from mpl_toolkits.axes_grid1 import make_axes_locatable from itertools import product from thunderhopper.filetools import search_files @@ -129,6 +130,40 @@ def shorten_species(name): genus, species = name.split('_') return genus[0] + '. ' + species +def add_cross_axes(fig, n, long='col', fill='row', **grid_kwargs): + n_axes = n * (n - 1) // 2 + nrows = grid_kwargs.get('nrows', None) + ncols = grid_kwargs.get('ncols', None) + if nrows is None or ncols is None: + if nrows is not None: + ncols = int(np.ceil(n_axes / nrows)) + elif ncols is not None: + nrows = int(np.ceil(n_axes / ncols)) + else: + nrows = int(np.ceil(np.sqrt(n_axes))) + ncols = int(np.ceil(n_axes / nrows)) + if long == 'col' and ncols < nrows: + nrows, ncols = ncols, nrows + elif n_axes > nrows * ncols: + msg = f'Cannot place {n_axes} subplots in a {nrows}x{ncols} grid.' + raise ValueError(msg) + + row_inds = [i for i in range(n) for j in range(i + 1, n)] + col_inds = [j for i in range(n) for j in range(i + 1, n)] + if fill == 'col': + positions = [(j, i) for i, j in product(range(ncols), range(nrows))] + row_inds, col_inds = col_inds, row_inds + else: + positions = list(product(range(nrows), range(ncols))) + positions = np.array(positions[:n_axes]) + + grid = fig.add_gridspec(**(grid_kwargs | dict(nrows=nrows, ncols=ncols))) + axes = [] + for i, j in positions: + axes.append(fig.add_subplot(grid[i, j])) + return axes, positions, grid, row_inds, col_inds + + # GENERAL SETTINGS: target_species = [ 'Omocestus_rufipes', @@ -152,16 +187,16 @@ kern_specs = np.array([ [1, 0.008], [2, 0.004], [3, 0.002], -])[np.array([0, 1])] +])[np.array([0, 1, 2])] n_kernels = kern_specs.shape[0] # GRAPH SETTINGS: fig_kwargs = dict( - figsize=(32/2.54, 20/2.54), + figsize=(32/2.54, 32/2.54), ) super_grid_kwargs = dict( nrows=3, - ncols=1, + ncols=2, wspace=0, hspace=0, left=0, @@ -171,15 +206,16 @@ super_grid_kwargs = dict( height_ratios=[1, 4, 3] ) subfig_specs = dict( - song=(0, 0), - feat=(1, 0), - space=(2, 0) + song=(0, slice(None)), + feat=(1, slice(None)), + pure=(2, 0), + noise=(2, 1), ) feat_grid_kwargs = dict( nrows=2, ncols=n_species, wspace=0.25, - hspace=0.15, + hspace=0.1, left=0.06, right=0.985, bottom=0.1, @@ -196,19 +232,19 @@ song_grid_kwargs = dict( top=0.8 ) space_grid_kwargs = dict( - nrows=1, - ncols=2, - wspace=0.2, - hspace=0, - left=feat_grid_kwargs['left'], - right=feat_grid_kwargs['right'], - bottom=0.05, + nrows=None, + ncols=None, + wspace=0.1, + hspace=0.3, + left=0.05, + right=1, + bottom=0.1, top=0.95 ) anchor_kwargs = dict( aspect='equal', adjustable='box', - anchor=(0, 0.5) + anchor=(0.5, 0.5) ) inset_kwargs = dict( y0=0.7, @@ -226,8 +262,8 @@ fs = dict( bar=16, ) species_colors = load_colors('../data/species_colors.npz') -kernel_shades = [0, 0.5] -# scale_shades = [1, 0] +kernel_shades = [0, 0.75] +scale_shades = [1, 0] lw = dict( song=0.5, feat=3, @@ -246,11 +282,11 @@ space_kwargs = dict( ) xlabels = dict( feat='scale $\\alpha$', - space='$\\mu_{f_1}$' + space=[f'$\\mu_{{f_{i}}}$' for i in range(1, n_kernels + 1)], ) ylabels = dict( feat='$\\mu_f$', - space='$\\mu_{f_2}$', + space=[f'$\\mu_{{f_{i}}}$' for i in range(1, n_kernels + 1)], bar='scale $\\alpha$', ) xlab_feat_kwargs = dict( @@ -260,7 +296,7 @@ xlab_feat_kwargs = dict( va='bottom', ) xlab_space_kwargs = dict( - y=0, + y=-0.3, fontsize=fs['lab_tex'], ha='center', va='bottom', @@ -268,14 +304,14 @@ xlab_space_kwargs = dict( ylab_feat_kwargs = dict( x=0, fontsize=fs['lab_tex'], - ha='left', - va='center', + ha='center', + va='top', ) ylab_space_kwargs = dict( - x=0, + x=-0.2, fontsize=fs['lab_tex'], - ha='left', - va='center', + ha='center', + va='bottom', ) ylab_cbar_kwargs = dict( x=1, @@ -284,6 +320,7 @@ ylab_cbar_kwargs = dict( va='bottom', ) xloc = dict( + feat=(1,), space=0.5, ) yloc = dict( @@ -302,17 +339,24 @@ title_kwargs = dict( fontstyle='italic' ) letter_feat_kwargs = dict( - x=0, - yref=1, - ha='center', - va='top', + xref=0, + y=1, + ha='left', + va='center', fontsize=fs['letter'], ) +letter_song_kwargs = dict( + x=0, + y=1, + ha='left', + va='top', + fontsize=fs['letter'], +) letter_space_kwargs = dict( x=0, yref=1, - ha='center', - va='top', + ha='left', + va='center', fontsize=fs['letter'], ) song_bar_time = 1.0 @@ -325,33 +369,29 @@ song_bar_kwargs = dict( lw=0, clip_on=False, # text_pos=(-0.1, 0.5), - text_str=f'${int(1000 * song_bar_time)}\\,\\text{{ms}}$', - text_kwargs=dict( - fontsize=fs['bar'], - ha='right', - va='center', - ) + # text_str=f'${int(1000 * song_bar_time)}\\,\\text{{ms}}$', + # text_kwargs=dict( + # fontsize=fs['bar'], + # ha='right', + # va='center', + # ) ) kern_bar_time = 0.05 kern_bar_kwargs = dict( dur=kern_bar_time, - y0=inset_kwargs['y0'], - y1=inset_kwargs['y0'] + 0.03, + y0=inset_kwargs['y0'] - 0.03, + y1=inset_kwargs['y0'], color='k', lw=0 ) -cbar_bounds = [ - 0.05, - space_grid_kwargs['bottom'], - 0.15, - space_grid_kwargs['top'] - space_grid_kwargs['bottom'] -] noise_kwargs = dict( fc=(0.9, 0.9, 0.9), ec='none', lw=0, zorder=0.5, ) +low_rel_thresh = 0.05 +high_rel_thresh = 0.95 # EXECUTION: @@ -368,6 +408,7 @@ for i in range(n_species): hide_axis(ax, 'bottom') hide_axis(ax, 'left') song_axes[i] = ax +letter_subplot(song_subfig, 'a', **letter_song_kwargs) # Prepare feature invariance axes: feat_subfig = fig.add_subfigure(super_grid[subfig_specs['feat']]) @@ -377,12 +418,13 @@ for i, j in product(range(feat_grid_kwargs['nrows']), range(n_species)): ax = feat_subfig.add_subplot(feat_grid[i, j]) ax.yaxis.set_major_locator(plt.MultipleLocator(yloc['feat'])) ax.set_ylim(0, 1) + if j == 0: + ylabel(ax, ylabels['feat'], transform=feat_subfig, **ylab_feat_kwargs) feat_axes[i, j] = ax -super_xlabel(xlabels['feat'], feat_subfig, feat_axes[-1, 0], feat_axes[-1, -1], **xlab_feat_kwargs) -super_ylabel(ylabels['feat'], feat_subfig, feat_axes[-1, 0], feat_axes[0, 0], **ylab_feat_kwargs) [hide_ticks(ax, side='bottom') for ax in feat_axes[0, :]] [hide_ticks(ax, side='left') for ax in feat_axes[:, 1:].ravel()] -letter_subplots(feat_axes[0, :], labels='abc', ref=feat_subfig, **letter_feat_kwargs) +super_xlabel(xlabels['feat'], feat_subfig, feat_axes[-1, 0], feat_axes[-1, -1], **xlab_feat_kwargs) +letter_subplots(feat_axes[:, 0], labels='bc', ref=feat_subfig, **letter_feat_kwargs) # Prepare kernel insets: x0 = np.linspace(0, 1, n_kernels + 1)[:-1] + 1 / n_kernels / 2 @@ -395,36 +437,49 @@ for i in range(n_kernels): inset.axis('off') insets.append(inset) -# Prepare feature space axes: -space_subfig = fig.add_subfigure(super_grid[subfig_specs['space']]) -space_grid = space_subfig.add_gridspec(**space_grid_kwargs) -space_axes = np.zeros(space_grid_kwargs['ncols'], dtype=object) -for i in range(space_axes.size): - ax = space_subfig.add_subplot(space_grid[i]) - ax.set_xlim(0, 1) - ax.set_ylim(0, 1) - ax.xaxis.set_major_locator(plt.MultipleLocator(xloc['space'])) - ax.yaxis.set_major_locator(plt.MultipleLocator(yloc['space'])) - ax.set_aspect(**anchor_kwargs) - # ax.set_ylabel(ylabels['space'], **ylab_space_kwargs) - ylabel(ax, ylabels['space'], transform=space_subfig.transSubfigure, **ylab_space_kwargs) - space_axes[i] = ax -super_xlabel(xlabels['space'], space_subfig, space_axes[1], space_axes[1], **xlab_space_kwargs) -hide_ticks(space_axes[0], side='bottom') -letter_subplot(space_axes[0], 'd', ref=space_subfig, **letter_space_kwargs) +# Prepare pure feature space axes: +pure_subfig = fig.add_subfigure(super_grid[subfig_specs['pure']]) +outputs = add_cross_axes(pure_subfig, n_kernels, **space_grid_kwargs) +pure_axes, space_pos, space_grid, row_inds, col_inds = outputs +letter_subplot(pure_subfig, 'd', ref=pure_axes[0], **letter_space_kwargs) -# Prepare colorbars: -cbar_bounds[0] += space_axes[-1].get_position().x1 -bar_axes = [space_subfig.add_axes(cbar_bounds)] -bar_axes.extend(split_subplot(bar_axes[0], side=['right'] * (n_species - 1), - size=100, pad=0)) +# Prepare noise feature space axes: +noise_subfig = fig.add_subfigure(super_grid[subfig_specs['noise']]) +noise_axes = add_cross_axes(noise_subfig, n_kernels, **space_grid_kwargs)[0] +letter_subplot(noise_subfig, 'e', ref=noise_axes[0], **letter_space_kwargs) + +# Format feature space axes: +for ind, axes in enumerate(zip(pure_axes, noise_axes)): + for ax in axes: + ax.set_xlim(0, 1) + ax.set_ylim(0, 1) + ax.xaxis.set_major_locator(plt.MultipleLocator(xloc['space'])) + ax.yaxis.set_major_locator(plt.MultipleLocator(yloc['space'])) + ax.set_aspect(**anchor_kwargs) + xlabel(ax, xlabels['space'][col_inds[ind]], **xlab_space_kwargs) + ylabel(ax, ylabels['space'][row_inds[ind]], **ylab_space_kwargs) + +# Determine area to place colorbars: +rightmost = pure_axes[np.argmax(space_pos[:, 1])].get_position() +downmost = pure_axes[np.argmax(space_pos[:, 0])].get_position() +bar_bounds = [rightmost.x0, downmost.y0, rightmost.width, downmost.height] + +# Prepare pure colorbars: +pure_bars = [pure_subfig.add_axes(bar_bounds)] +pure_bars.extend(split_subplot(pure_bars[0], side=['right'] * (n_species - 1), + size=100, pad=0)) + +# Prepare noise colorbars: +noise_bars = [noise_subfig.add_axes(bar_bounds)] +noise_bars.extend(split_subplot(noise_bars[0], side=['right'] * (n_species - 1), + size=100, pad=0)) # Prepare kernel-specific color shading: kern_factors = np.linspace(*kernel_shades, n_kernels) kern_colors_bw = shade_colors((0., 0., 0.), kern_factors) # Plot results per species: -min_feat = np.zeros((n_species, n_kernels), dtype=float) +noise_feat = np.zeros((n_species, n_kernels), dtype=float) for i, species in enumerate(target_species): print(f'Processing {species}') @@ -464,21 +519,19 @@ for i, species in enumerate(target_species): scales = scales[nonzero_inds] pure_measure = pure_measure[nonzero_inds, :] noise_measure = noise_measure[nonzero_inds, :] - min_feat[i, :] = noise_measure.min(axis=0) # Prepare species-specific colors: base_color = species_colors[species] kern_colors = shade_colors(base_color, kern_factors) - scale_factors = np.linspace(1, 0, scales.size) + scale_factors = np.linspace(*scale_shades, scales.size) scale_cmap = create_listed_cmap(shade_colors(base_color, scale_factors)) scale_cmap_bw = create_listed_cmap(shade_colors((0., 0., 0.), scale_factors)) # Plot feature invariance curves: - pure_ax, noise_ax = feat_axes[:, i] symlog_kwargs['linthresh'] = scales[scales > 0][0] [ax.set_xscale('symlog', **symlog_kwargs) for ax in feat_axes[:, i]] - pure_ax.set_xscale('symlog', **symlog_kwargs) - noise_ax.set_xscale('symlog', **symlog_kwargs) + [ax.xaxis.set_major_locator(plt.LogLocator(base=10, subs=xloc['feat'])) for ax in feat_axes[:, i]] + pure_ax, noise_ax = feat_axes[:, i] handles = pure_ax.plot(scales, pure_measure, lw=lw['feat']) [h.set_color(c) for h, c in zip(handles, kern_colors)] handles = noise_ax.plot(scales, noise_measure, lw=lw['feat']) @@ -494,30 +547,67 @@ for i, species in enumerate(target_species): inset.set_ylim(ylims) time_bar(insets[0], parent=feat_axes[0, 0], **kern_bar_kwargs) - # Plot pure feature space: - from matplotlib.colors import LogNorm + # Plot invariance curves in feature space: norm = LogNorm(vmin=scales[scales > 0][0], vmax=scales[-1]) - handle = space_axes[0].scatter(pure_measure[:, 0], pure_measure[:, 1], - c=scales, cmap=scale_cmap, norm=norm, - zorder=zorder[species], **space_kwargs) + for ind, (pure_ax, noise_ax) in enumerate(zip(pure_axes, noise_axes)): + irow, icol = row_inds[ind], col_inds[ind] + pure_handle = pure_ax.scatter(pure_measure[:, icol], pure_measure[:, irow], + c=scales, cmap=scale_cmap, norm=norm, + zorder=zorder[species], **space_kwargs) - # Plot noise feature space: - space_axes[1].scatter(noise_measure[:, 0], noise_measure[:, 1], - c=scales, cmap=scale_cmap, norm=norm, - zorder=zorder[species], **space_kwargs) - - # Indicate scale color code: - space_subfig.colorbar(handle, cax=bar_axes[i]) - bar_axes[i].set_yscale('symlog', **symlog_kwargs) + noise_handle = noise_ax.scatter(noise_measure[:, icol], noise_measure[:, irow], + c=scales, cmap=scale_cmap, norm=norm, + zorder=zorder[species], **space_kwargs) + + # Indicate scale color code in pure subfigure: + pure_subfig.colorbar(pure_handle, cax=pure_bars[i]) + pure_bars[i].set_yscale('symlog', **symlog_kwargs) if i < n_species - 1: - hide_ticks(bar_axes[i], 'right', ticks=False) + hide_ticks(pure_bars[i], 'right', ticks=False) else: - ylabel(bar_axes[i], ylabels['bar'], transform=space_subfig.transSubfigure, **ylab_cbar_kwargs) + ylabel(pure_bars[i], ylabels['bar'], transform=pure_subfig.transSubfigure, **ylab_cbar_kwargs) + + # Indicate scale color code in noise subfigure: + noise_subfig.colorbar(noise_handle, cax=noise_bars[i]) + noise_bars[i].set_yscale('symlog', **symlog_kwargs) + if i < n_species - 1: + hide_ticks(noise_bars[i], 'right', ticks=False) + else: + ylabel(noise_bars[i], ylabels['bar'], transform=noise_subfig.transSubfigure, **ylab_cbar_kwargs) + + # Log feature noise floor: + noise_feat[i, :] = noise_measure.min(axis=0) + + # Indicate low and high plateaus: + min_feat = pure_measure.min(axis=0) + span_feat = pure_measure.max(axis=0) - min_feat + + low_thresh = min_feat + low_rel_thresh * span_feat + low_ind = np.nonzero((pure_measure >= low_thresh).all(axis=1))[0][0] + pure_bars[i].axhline(scales[low_ind], c='k', lw=3) + + high_thresh = min_feat + high_rel_thresh * span_feat + high_ind = np.nonzero((pure_measure >= high_thresh).any(axis=1))[0][0] + pure_bars[i].axhline(scales[high_ind], c='w', lw=3) + + # Indicate low and high plateaus: + min_feat = noise_measure.min(axis=0) + span_feat = noise_measure.max(axis=0) - min_feat + + low_thresh = min_feat + low_rel_thresh * span_feat + low_ind = np.nonzero((noise_measure >= low_thresh).all(axis=1))[0][0] + noise_bars[i].axhline(scales[low_ind], c='k', lw=3) + + high_thresh = min_feat + high_rel_thresh * span_feat + high_ind = np.nonzero((noise_measure >= high_thresh).any(axis=1))[0][0] + noise_bars[i].axhline(scales[high_ind], c='w', lw=3) if show_noise: # Indicate feature noise floor: - min_feat = min_feat.mean(axis=0) - space_axes[-1].add_patch(plt.Rectangle((0, 0), min_feat[0], min_feat[1], **noise_kwargs)) + noise_feat = noise_feat.mean(axis=0) + for ind, ax in enumerate(noise_axes): + irow, icol = row_inds[ind], col_inds[ind] + ax.add_patch(plt.Rectangle((0, 0), noise_feat[icol], noise_feat[irow], **noise_kwargs)) if save_path is not None: fig.savefig(save_path) diff --git a/python/plot_functions.py b/python/plot_functions.py index 53b3671..dff93a8 100644 --- a/python/plot_functions.py +++ b/python/plot_functions.py @@ -18,6 +18,8 @@ def hide_axis(ax, side='bottom'): def get_trans_artist(artist): artist_type = type(artist).__name__ + if 'Transform' in artist_type: + return artist if artist_type == 'Axes': return artist.transAxes elif artist_type == 'Figure': @@ -117,6 +119,7 @@ def xlabel(ax, label, x=None, y=-0.1, fontsize=20, transform=None, **kwargs): if x is None: x = 0.5 if transform is not None: + transform = get_trans_artist(transform) x = (ax.transAxes + transform.inverted()).transform((x, 0))[0] ax.xaxis.set_label_coords(x, y, transform=transform) return ax.set_xlabel(label, fontsize=fontsize, **kwargs) @@ -125,6 +128,7 @@ def ylabel(ax, label, x=-0.2, y=None, fontsize=20, transform=None, **kwargs): if y is None: y = 0.5 if transform is not None: + transform = get_trans_artist(transform) y = (ax.transAxes + transform.inverted()).transform((0, y))[1] ax.yaxis.set_label_coords(x, y, transform=transform) return ax.set_ylabel(label, fontsize=fontsize, **kwargs)