"""Build the single-recording figure saved to
``../figures/fig_invariance_thresh_lp_single.pdf`` from the 'noise' and 'pure'
variants of one example recording found in ``../data/inv/thresh_lp/``.
"""
import plotstyle_plt
import numpy as np
import matplotlib.pyplot as plt
from itertools import product
from thunderhopper.filetools import search_files
from thunderhopper.modeltools import load_data
from thunderhopper.filtertools import find_kern_specs
from misc_functions import get_saturation
from color_functions import load_colors, shade_colors
from plot_functions import shift_subplot, hide_axis, ylimits, xlabel, ylabel,\
                           super_ylabel, plot_line, plot_barcode, strip_zeros,\
                           time_bar, letter_subplot, letter_subplots, title_subplot,\
                           set_clip_box
from IPython import embed


def add_snip_axes(fig, grid_kwargs, col_shift=None):
    """Fill a (sub)figure with a full grid of snippet axes.

    Parameters
    ----------
    fig : matplotlib Figure or SubFigure
        Container that receives the grid via ``fig.add_gridspec``.
    grid_kwargs : dict
        Keyword arguments forwarded to ``fig.add_gridspec``.
    col_shift : float, optional
        ``dx`` passed to ``shift_subplot`` for every axes in the first column.
        Defaults to the module-level ``snip_col_shift``.

    Returns
    -------
    axes : numpy object array of shape (nrows, ncols)
        The created axes. ``hide_axis(ax, 'left')`` is applied from the third
        column on, and ``hide_axis(ax, 'bottom')`` to all axes.
    """
    if col_shift is None:
        # Resolved at call time: the global is defined in the settings below.
        col_shift = snip_col_shift
    grid = fig.add_gridspec(**grid_kwargs)
    axes = np.zeros((grid.nrows, grid.ncols), dtype=object)
    for i, j in product(range(grid.nrows), range(grid.ncols)):
        axes[i, j] = fig.add_subplot(grid[i, j])
        if j == 0:
            shift_subplot(axes[i, j], dx=col_shift)
    for ax in axes[:, 2:].flatten():
        hide_axis(ax, 'left')
    for ax in axes.flatten():
        hide_axis(ax, 'bottom')
    return axes


def plot_snippets(axes, time, snippets, ymin=None, ymax=None, ypad=0.05,
                  thresh=None, fill_kwargs=None, **kwargs):
    """Plot each snippet column into its own axes with common y-limits.

    Parameters
    ----------
    axes : iterable of matplotlib Axes
        Target axes, paired with the columns of ``snippets`` (``snippets.T``).
    time : array
        x-values shared by all snippets.
    snippets : 2D array
        Presumably (time, n_snippets); one column per axes.
    ymin, ymax : float, optional
        Forwarded to ``ylimits`` as ``minval``/``maxval``.
    ypad : float
        Padding forwarded to ``ylimits``.
    thresh : float, optional
        If given, the area between ``thresh`` and each snippet is filled
        wherever the snippet exceeds it.
    fill_kwargs : dict, optional
        Keyword arguments for ``ax.fill_between``.
    **kwargs
        Forwarded to ``plot_line``.

    Returns
    -------
    handles : list
        One ``plot_line`` return value per axes.
    """
    if fill_kwargs is None:
        fill_kwargs = {}
    ymin, ymax = ylimits(snippets, minval=ymin, maxval=ymax, pad=ypad)
    handles = []
    for ax, snippet in zip(axes, snippets.T):
        handles.append(plot_line(ax, time, snippet, ymin=ymin, ymax=ymax, **kwargs))
        if thresh is not None:
            ax.fill_between(time, thresh, snippet, where=(snippet > thresh),
                            **fill_kwargs)
    return handles


def plot_bi_snippets(axes, time, binary, **kwargs):
    """Draw each column of ``binary`` as a barcode in the matching axes.

    ``**kwargs`` are forwarded to ``plot_barcode``. Returns None.
    """
    # Distinct loop name: the original reused ``binary``, shadowing the argument.
    for ax, column in zip(axes, binary.T):
        plot_barcode(ax, time, column[:, None], **kwargs)
    return None


def side_distributions(axes, snippets, inset_bounds, thresh, nbins=50,
                       limits=None, fill_kwargs=None, **kwargs):
    """Attach a sideways histogram (density) of each snippet to its axes.

    Parameters
    ----------
    axes : iterable of matplotlib Axes
        Host axes, paired with the columns of ``snippets`` (``snippets.T``).
    snippets : 2D array
        Presumably (time, n_snippets); one column per axes.
    inset_bounds : sequence of 4 floats
        Bounds passed to ``ax.inset_axes``.
    thresh : float
        The histogram is filled for bin centers above this value.
    nbins : int
        Number of histogram bins.
    limits : pair of floats, optional
        Histogram range. By default, the data range padded by 10 % of each
        bound's magnitude.
    fill_kwargs : dict, optional
        Keyword arguments for ``inset.fill_betweenx``.
    **kwargs
        Forwarded to ``inset.plot``.

    Returns
    -------
    insets : list of matplotlib Axes
        The created inset axes (axis lines switched off).
    """
    if fill_kwargs is None:
        fill_kwargs = {}
    if limits is None:
        # Pad each bound *away* from the data. Identical to the former
        # ``[min, max] * 1.1`` when min < 0 < max. Scaling by 1.1 moved the
        # lower bound inward for min > 0 (and the upper for max < 0),
        # dropping data from the histogram.
        lower, upper = snippets.min(), snippets.max()
        limits = np.array([lower - 0.1 * abs(lower), upper + 0.1 * abs(upper)])
    edges = np.linspace(*limits, nbins + 1)
    centers = edges[:-1] + (edges[1] - edges[0]) / 2
    insets = []
    for ax, snippet in zip(axes, snippets.T):
        pdf, _ = np.histogram(snippet, edges, density=True)
        inset = ax.inset_axes(inset_bounds)
        # Histogram is drawn sideways: density on x, value on y.
        handle = inset.plot(pdf, centers, **kwargs)[0]
        set_clip_box(handle, inset, bounds=[[-0.05, 0], [1.05, 1]])
        handle = inset.fill_betweenx(centers, pdf.min(), pdf,
                                     where=(centers > thresh), **fill_kwargs)
        set_clip_box(handle, inset, bounds=[[-0.05, 0], [1.05, 1]])
        inset.set_xlim(0, pdf.max())
        # Align the value axis with the host axes:
        inset.set_ylim(ax.get_ylim())
        inset.axis('off')
        insets.append(inset)
    return insets


# GENERAL SETTINGS:
example_file = 'Omocestus_rufipes_DJN_32-40s724ms-48s779ms'
data_path = search_files(example_file, incl='noise', dir='../data/inv/thresh_lp/')[0]
stages = ['conv', 'bi', 'feat']
load_kwargs = dict(
    files=stages,
    keywords=['scales', 'snip', 'measure', 'thresh']
)
save_path = '../figures/fig_invariance_thresh_lp_single.pdf'
exclude_zero = True

# GRAPH SETTINGS:
fig_kwargs = dict(
    figsize=(32/2.54, 32/2.54),
)
super_grid_kwargs = dict(
    nrows=None,
    ncols=3,
    wspace=0,
    hspace=0,
    left=0,
    right=1,
    bottom=0,
    top=1,
)
input_rows = 1
snip_rows = 2
subfig_specs = dict(
    input=(slice(input_rows), slice(-1)),
    snip=[np.array([input_rows, input_rows + snip_rows]), slice(-1)],
    big=(slice(None), -1),
)
snip_col_shift = -0.07
snip_grid_kwargs = dict(
    nrows=len(stages),
    ncols=None,
    wspace=0.3,
    hspace=0,
    left=0.23 - snip_col_shift,
    right=0.93,
    bottom=0.15,
    top=0.95,
    height_ratios=[4, 1, 2]
)
input_grid_kwargs = dict(
    nrows=1,
    ncols=None,
    wspace=snip_grid_kwargs['wspace'],
    hspace=0,
    left=snip_grid_kwargs['left'],
    right=snip_grid_kwargs['right'],
    bottom=0.15,
    top=0.75,
)
big_grid_kwargs = dict(
    nrows=2,
    ncols=1,
    wspace=0,
    hspace=0.15,
    left=0.2,
    right=0.96,
    bottom=0.05,
    top=0.99
)
dist_inset_bounds = [1.02, 0, 0.2, 1]

# PLOT SETTINGS:
fs = dict(
    lab_norm=16,
    lab_tex=20,
    letter=22,
    tit_norm=16,
    tit_tex=20,
    bar=16,
)
colors = load_colors('../data/stage_colors.npz')
shade_factors = [0.2, -0.2]
lw = dict(
    inv=1.5,
    conv=1.5,
    bi=0.1,
    feat=3,
    big=4,
    thresh=1.5,
    kern=2.5,
    plateau=1.5,
)
xlabels = dict(
    alpha='scale $\\alpha$',
    sigma='$\\sigma_{\\text{adapt}}$',
)
ylabels = dict(
    inv='$x_{\\text{adapt}}$\n$[\\text{dB}]$',
    conv='$c$\n$[\\text{dB}]$',
    bi='$b$',
    feat='$f$',
    big='$\\mu_f$',
)
xlab_alpha_kwargs = dict(
    y=0.5,
    fontsize=fs['lab_norm'],
    ha='center',
    va='bottom',
)
xlab_sigma_kwargs = dict(
    y=0,
    fontsize=fs['lab_tex'],
    ha=xlab_alpha_kwargs['ha'],
    va='bottom',
)
ylab_snip_kwargs = dict(
    x=0.1,
    fontsize=fs['lab_tex'],
    rotation=0,
    ha='center',
    va='center',
)
ylab_super_kwargs = dict(
    x=0,
    fontsize=fs['lab_tex'],
    ha='left',
    va='center',
)
ylab_big_kwargs = dict(
    x=0,
    fontsize=fs['lab_tex'],
    ha='center',
    va='top',
)
ypad = dict(
    inv=0.05,
    conv=0.05,
    big=0.1
)
yloc = dict(
    inv=(2, 200),
    conv=(0.02, 2),
    bi=(1, 1),
    feat=(1, 1),
    big=0.2,
)
title_kwargs = dict(
    x=0.5,
    yref=1,
    ha='center',
    va='top',
    fontsize=fs['tit_norm'],
)
letter_snip_kwargs = dict(
    x=0,
    y=1,
    ha='left',
    va='top',
    fontsize=fs['letter'],
)
letter_big_kwargs = dict(
    xref=0,
    y=1,
    ha='left',
    va='top',
    fontsize=fs['letter'],
)
kern_kwargs = dict(
    c='k',
    lw=lw['kern'],
)
dist_kwargs = dict(
    c='k',
    lw=1,
)
dist_fill_kwargs = dict(
    color=colors['bi'],
    lw=0.1,
)
thresh_kwargs = dict(
    color='k',
    lw=lw['thresh'],
    ls='--',
    zorder=3,
)
bar_time = 0.1
bar_kwargs = dict(
    dur=bar_time,
    y0=-0.5,
    y1=-0.35,
    xshift=1,
    color='k',
    lw=0,
    clip_on=False,
    text_pos=(-0.1, 0.5),
    text_str=f'${int(1000 * bar_time)}\\,\\text{{ms}}$',
    text_kwargs=dict(
        fontsize=fs['bar'],
        ha='right',
        va='center',
    )
)
leg_kwargs = dict(
    ncols=2,
    loc='center',
    bbox_to_anchor=(0, 0.95, 1, 0.05),
    frameon=False,
    fontsize=fs['tit_norm'],
    handlelength=1.5,
    columnspacing=1,
)
cap_kwargs = dict(
    color='k',
    alpha=0.5,
    lw=0,
    zorder=5,
)
plateau_settings = dict(
    low=0.05,
    high=0.95,
    first=True,
    last=True,
    condense=None,
)
plateau_line_kwargs = dict(
    lw=lw['plateau'],
    ls='--',
    zorder=1,
)
plateau_dot_kwargs = dict(
    marker='o',
    markersize=8,
    markeredgewidth=1,
    clip_on=False,
)
# Zoom window as fractions of the full snippet duration:
zoom_rel = np.array([0.5, 0.515])

# SUBSET SETTINGS:
kern_specs = np.array([
    [1, 0.008],
    [2, 0.004],
    [3, 0.002],
])[np.array([1])]

# PREPARATION:
# Get saturation level of invariant envelope from log-hp analysis:
inv_path = search_files(example_file, dir='../data/inv/log_hp/')[0]
sigma_cap = load_data(inv_path, files='measure_inv')[0]['measure_inv'][-1]

# EXECUTION:
print(f'Processing {data_path}')

# Load invariance data:
noise_data, config = load_data(data_path, **load_kwargs)
pure_data, _ = load_data(data_path.replace('noise', 'pure'), **load_kwargs)

# Unpack shared variables:
scales = noise_data['scales']
plot_scales = noise_data['example_scales']
thresh_rel = noise_data['thresh_rel']
thresh_abs = noise_data['thresh_abs']

# Reduce to kernel subset and crop to zoom frame:
t_full = np.arange(noise_data['snip_conv'].shape[0]) / config['env_rate']
zoom_abs = zoom_rel * t_full[-1]
zoom_inds = (t_full >= zoom_abs[0]) & (t_full <= zoom_abs[1])
kern_ind = find_kern_specs(config['k_specs'], kerns=kern_specs)[0]
# NOTE(review): axis order presumably (time, kernel, scale[, threshold]) -
# inferred from the indexing below, confirm against the data files.
noise_data['snip_inv'] = noise_data['snip_inv'][zoom_inds, :]
noise_data['snip_conv'] = noise_data['snip_conv'][zoom_inds, kern_ind, :]
noise_data['snip_bi'] = noise_data['snip_bi'][zoom_inds, kern_ind, :, :]
noise_data['snip_feat'] = noise_data['snip_feat'][zoom_inds, kern_ind, :, :]
noise_data['measure_feat'] = noise_data['measure_feat'][:, kern_ind, :]
pure_data['measure_feat'] = pure_data['measure_feat'][:, kern_ind, :]
config['kernels'] = config['kernels'][:, kern_ind]
thresh_abs = thresh_abs[:, kern_ind]
# Time axis of the cropped snippets, restarting at zero:
t_full = np.arange(noise_data['snip_conv'].shape[0]) / config['env_rate']

if exclude_zero:
    # Exclude zero scale:
    inds = scales > 0
    scales = scales[inds]
    noise_data['measure_inv'] = noise_data['measure_inv'][inds]
    noise_data['measure_feat'] = noise_data['measure_feat'][inds, :]
    pure_data['measure_feat'] = pure_data['measure_feat'][inds, :]

# Get threshold-specific colors:
factors = np.linspace(*shade_factors, thresh_rel.size)
shaded = dict(
    conv=shade_colors(colors['conv'], factors),
    bi=shade_colors(colors['bi'], factors),
    feat=shade_colors(colors['feat'], factors),
)

# Adjust grid parameters to loaded data:
# One input row plus `snip_rows` super-grid rows per threshold:
super_grid_kwargs['nrows'] = snip_rows * thresh_rel.size + input_rows
input_grid_kwargs['ncols'] = plot_scales.size
snip_grid_kwargs['ncols'] = plot_scales.size

# Prepare overall graph:
fig = plt.figure(**fig_kwargs)
super_grid = fig.add_gridspec(**super_grid_kwargs)

# Prepare input snippet axes:
input_subfig = fig.add_subfigure(super_grid[subfig_specs['input']])
input_axes = add_snip_axes(input_subfig, input_grid_kwargs).ravel()
input_axes[0].yaxis.set_major_locator(plt.MultipleLocator(yloc['inv'][0]))
input_axes[1].yaxis.set_major_locator(plt.MultipleLocator(yloc['inv'][1]))
ylabel(input_axes[0], ylabels['inv'], transform=input_subfig.transSubfigure,
       **ylab_snip_kwargs)
for ax, scale in zip(input_axes, plot_scales):
    title_subplot(ax, f'$\\alpha={strip_zeros(scale)}$', ref=input_subfig,
                  **title_kwargs)
letter_subplot(input_subfig, 'a', **letter_snip_kwargs)

# Prepare snippet axes (one subfigure per threshold):
snip_subfigs, snip_axes = [], []
for i in range(thresh_rel.size):
    # Shift the [start, stop] row pair down by `snip_rows` per threshold:
    subfig_spec = subfig_specs['snip'].copy()
    subfig_spec[0] = slice(*(subfig_spec[0] + i * snip_rows))
    snip_subfig = fig.add_subfigure(super_grid[tuple(subfig_spec)])
    axes = add_snip_axes(snip_subfig, snip_grid_kwargs)
    for ax in axes[1:, 1]:
        hide_axis(ax, 'left')
    super_ylabel(f'$\\Theta={strip_zeros(thresh_rel[i])}\\cdot\\sigma_{{\\eta}}$',
                 snip_subfig, axes[-1, 0], axes[0, 0], **ylab_super_kwargs)
    # First two columns of each stage row get their own tick spacing:
    for (ax1, ax2), stage in zip(axes[:, :2], stages):
        ax1.yaxis.set_major_locator(plt.MultipleLocator(yloc[stage][0]))
        ax2.yaxis.set_major_locator(plt.MultipleLocator(yloc[stage][1]))
        ylabel(ax1, ylabels[stage], transform=snip_subfig.transSubfigure,
               **ylab_snip_kwargs)
    # Time scale bar only below the last threshold's snippets:
    if i == thresh_rel.size - 1:
        axes[-1, -1].set_xlim(t_full[0], t_full[-1])
        time_bar(axes[-1, -1], **bar_kwargs)
    snip_subfigs.append(snip_subfig)
    snip_axes.append(axes)
letter_subplots(snip_subfigs, 'bcd', **letter_snip_kwargs)

# Prepare analysis axes:
big_subfig = fig.add_subfigure(super_grid[subfig_specs['big']])
big_grid = big_subfig.add_gridspec(**big_grid_kwargs)

alpha_ax = big_subfig.add_subplot(big_grid[0, 0])
alpha_ax.set_xlim(scales[0], scales[-1])
alpha_ax.set_xscale('symlog', linthresh=scales[scales > 0][0], linscale=0.5)
ylimits(pure_data['measure_feat'], alpha_ax, minval=0, pad=ypad['big'])
alpha_ax.yaxis.set_major_locator(plt.MultipleLocator(yloc['big']))
xlabel(alpha_ax, xlabels['alpha'], **xlab_alpha_kwargs, transform=big_subfig)
ylabel(alpha_ax, ylabels['big'], transform=big_subfig.transSubfigure,
       **ylab_big_kwargs)
letter_subplot(alpha_ax, 'e', ref=big_subfig, **letter_big_kwargs)

sigma_ax = big_subfig.add_subplot(big_grid[1, 0])
sigma_ax.set_xlim(1, noise_data['measure_inv'].max())
# NOTE(review): linthresh comes from `scales` although this axis shows
# sigma values - confirm this is intended.
sigma_ax.set_xscale('symlog', linthresh=scales[scales > 0][0], linscale=0.5)
# Same y-range as the alpha axis (both from the pure-song measures):
ylimits(pure_data['measure_feat'], sigma_ax, minval=0, pad=ypad['big'])
sigma_ax.yaxis.set_major_locator(plt.MultipleLocator(yloc['big']))
xlabel(sigma_ax, xlabels['sigma'], **xlab_sigma_kwargs, transform=big_subfig)
ylabel(sigma_ax, ylabels['big'], transform=big_subfig.transSubfigure,
       **ylab_big_kwargs)
letter_subplot(sigma_ax, 'f', ref=big_subfig, **letter_big_kwargs)

# Plot intensity-adapted snippets:
plot_snippets(input_axes, t_full, noise_data['snip_inv'], ypad=ypad['inv'],
              c=colors['inv'], lw=lw['inv'])
ylimits(noise_data['snip_inv'][:, 0], input_axes[0], pad=ypad['inv'])

# Indicate kernel waveform over 1st intensity-adapted snippet:
input_axes[0].plot(config['k_times'] + 0.5 * t_full[-1], config['kernels'],
                   **kern_kwargs)

# Plot representation snippets per threshold:
for i, axes in enumerate(snip_axes):
    # Shared (global) fill style, recoloured for the current threshold:
    dist_fill_kwargs['color'] = shaded['bi'][i]

    # Plot kernel response snippets:
    plot_snippets(axes[0, :], t_full, noise_data['snip_conv'],
                  thresh=thresh_abs[i], ypad=ypad['conv'],
                  fill_kwargs=dist_fill_kwargs, c=shaded['conv'][i],
                  lw=lw['conv'])
    ylim_zoom = ylimits(noise_data['snip_conv'][:, 0], axes[0, 0],
                        pad=ypad['conv'], maxval=thresh_abs[-1])

    # Indicate absolute threshold value:
    handle = axes[0, 0].axhline(thresh_abs[i], **thresh_kwargs)
    set_clip_box(handle, axes[0, 0], bounds=[[0, 0], [1, 1.05]])

    # Plot kernel response distributions (1st column uses its zoomed limits):
    side_distributions(axes[0, :1], noise_data['snip_conv'][:, :1],
                       dist_inset_bounds, thresh_abs[i], nbins=50,
                       limits=ylim_zoom, fill_kwargs=dist_fill_kwargs,
                       **dist_kwargs)
    side_distributions(axes[0, 1:], noise_data['snip_conv'][:, 1:],
                       dist_inset_bounds, thresh_abs[i], nbins=50,
                       fill_kwargs=dist_fill_kwargs, **dist_kwargs)

    # Plot binary snippets:
    plot_bi_snippets(axes[1, :], t_full, noise_data['snip_bi'][:, :, i],
                     color=shaded['bi'][i], lw=lw['bi'])

    # Plot feature snippets:
    handles = plot_snippets(axes[2, :], t_full, noise_data['snip_feat'][:, :, i],
                            ymin=0, ymax=1, c=shaded['feat'][i], lw=lw['feat'])
    for h, ax in zip(handles, axes[2, :]):
        set_clip_box(h[0], ax, bounds=[[0, -0.05], [1, 1.05]])

# Get saturation index per threshold:
saturation_inds = []
for i in range(thresh_rel.size):
    ind = get_saturation(noise_data['measure_feat'][:, i], **plateau_settings)[1]
    saturation_inds.append(ind)

# Plot pure-song analysis results over alpha:
handles = alpha_ax.plot(scales, pure_data['measure_feat'], lw=lw['big'],
                        ls='dotted')
for h, c in zip(handles, shaded['feat']):
    h.set_color(c)

# Plot noise-song analysis results over alpha:
handles = alpha_ax.plot(scales, noise_data['measure_feat'], lw=lw['big'])
for h, c in zip(handles, shaded['feat']):
    h.set_color(c)

# Indicate threshold-specific saturation (dot at axis bottom, line up to curve):
for i, ind in enumerate(saturation_inds):
    color = shaded['feat'][i]
    # White underlay keeps the semi-transparent marker legible:
    alpha_ax.plot(scales[ind], 0, c='w', alpha=1, zorder=5.5,
                  **plateau_dot_kwargs,
                  transform=alpha_ax.get_xaxis_transform())
    alpha_ax.plot(scales[ind], 0, mfc=color, mec='k', alpha=0.75, zorder=6,
                  **plateau_dot_kwargs,
                  transform=alpha_ax.get_xaxis_transform())
    alpha_ax.vlines(scales[ind], alpha_ax.get_ylim()[0],
                    noise_data['measure_feat'][ind, i], color=color,
                    **plateau_line_kwargs)

# Add proxy legend:
# Empty lines serve only as legend entries (solid = noisy, dotted = pure song):
h1 = alpha_ax.plot([], [], c='k', lw=lw['big'],
                   label='$\\alpha\\cdot s(t) + \\eta(t)$')[0]
h2 = alpha_ax.plot([], [], c='k', lw=lw['big'], ls='dotted',
                   label='$\\alpha\\cdot s(t)$')[0]
alpha_ax.legend(handles=[h1, h2], **leg_kwargs)

# Plot noise-song analysis results over sigma:
sigmas = noise_data['measure_inv']
handles = sigma_ax.plot(sigmas, noise_data['measure_feat'], lw=lw['big'])
for h, c in zip(handles, shaded['feat']):
    h.set_color(c)

# Indicate threshold-specific saturation. On this axis the curves use sigma as
# x, so the markers go at the sigma of the saturating scale. The previous
# `scales[ind]` was copied from the alpha axis and missed the curves.
# NOTE(review): assumes measure_inv holds one sigma value per scale (1-D).
for i, ind in enumerate(saturation_inds):
    color = shaded['feat'][i]
    sigma_sat = sigmas[ind]
    # White underlay keeps the semi-transparent marker legible:
    sigma_ax.plot(sigma_sat, 0, c='w', alpha=1, zorder=5.5,
                  **plateau_dot_kwargs,
                  transform=sigma_ax.get_xaxis_transform())
    sigma_ax.plot(sigma_sat, 0, mfc=color, mec='k', alpha=0.75, zorder=6,
                  **plateau_dot_kwargs,
                  transform=sigma_ax.get_xaxis_transform())
    sigma_ax.vlines(sigma_sat, sigma_ax.get_ylim()[0],
                    noise_data['measure_feat'][ind, i], color=color,
                    **plateau_line_kwargs)

# Indicate sigma range capped by log-hp mechanism:
sigma_ax.axvspan(sigma_cap, sigma_ax.get_xlim()[1], **cap_kwargs)

if save_path is not None:
    fig.savefig(save_path)
plt.show()
print('Done.')
# Drops into an interactive IPython shell for inspection after plotting:
embed()