"""Appendix figure: feature invariance across species and thresholds.

Builds a grid of subplots with one row per relative threshold in
``thresh_rel`` and one column per species in ``target_species``. Each panel
shows:

* the swarm of per-feature curves (``measure_feat`` over ``scales``),
* their median across features as a dashed black line,
* a right-hand inset with the distribution of the values at the largest scale
  ("saturation levels"),
* a bottom inset with the distribution of the per-feature saturation scales
  and, for every row but the first, a dot at the saturation scale of the
  median curve.

The figure is written to ``save_path`` (if not None) and then shown.
"""

# Imported for its side effects only; it is never referenced by name below
# (presumably it configures the plotting style -- TODO confirm).
import plotstyle_plt
import numpy as np
import matplotlib.pyplot as plt
from thunderhopper.modeltools import load_data
from thunderhopper.filetools import search_files
from thunderhopper.filtertools import find_kern_specs
from misc_functions import shorten_species, x_dist, y_dist, get_saturation
from color_functions import load_colors
from plot_functions import (
    reorder_by_sd, ylabel, super_xlabel, super_ylabel,
    title_subplot, assign_colors, strip_zeros, hide_axis,
    hide_ticks,
)
# Not used by the script itself; kept so the import set is unchanged.
from IPython import embed

# GENERAL SETTINGS:
# Species to plot (one subplot column each). Comment entries in or out to
# change the selection.
target_species = [
    # 'Chorthippus_biguttulus',
    # 'Chorthippus_mollis',
    # 'Chrysochraon_dispar',
    # 'Euchorthippus_declivus',
    'Gomphocerippus_rufus',
    'Omocestus_rufipes',
    'Pseudochorthippus_parallelus',
]
# File name pattern of the example recording used for each species.
example_files = {
    'Chorthippus_biguttulus': 'Chorthippus_biguttulus_GBC_94-17s73.1ms-19s977ms',
    'Chorthippus_mollis': 'Chorthippus_mollis_DJN_41_T28C-46s4.58ms-1m15s697ms',
    'Chrysochraon_dispar': 'Chrysochraon_dispar_DJN_26_T28C_DT-32s134ms-34s432ms',
    'Euchorthippus_declivus': 'Euchorthippus_declivus_FTN_79-2s167ms-2s563ms',
    'Gomphocerippus_rufus': 'Gomphocerippus_rufus_FTN_91-3-884ms-10s427ms',
    'Omocestus_rufipes': 'Omocestus_rufipes_DJN_32-40s724ms-48s779ms',
    'Pseudochorthippus_parallelus': 'Pseudochorthippus_parallelus_GBC_88-6s678ms-9s32.3ms'
}
search_path = '../data/inv/full/'
save_path = '../figures/fig_invariance_cross_species_thresh_appendix.pdf'

# ANALYSIS SETTINGS:
# Drop all samples with scale <= 0 before plotting.
exclude_zero = True
# Relative thresholds, one subplot row each. The third axis of the loaded
# 'measure_feat' array is indexed by position in this array.
thresh_rel = np.array([0, 0.5, 1, 1.5, 2, 2.5, 3])

# SUBSET SETTINGS:
# Kernel selection criteria handed to find_kern_specs(); None disables a
# criterion.
types = np.array([1, -1, 2, -2, 3, -3, 4, -4])
# types = [1, -1, 2, -2, 3, -3, 4, -4, 5, -5, 6, -6, 7, -7, 8, -8, 9, -9, 10, -10]
sigmas = np.array([0.001, 0.002, 0.004, 0.008, 0.016])
# sigmas = [0.001, 0.002, 0.004, 0.008, 0.016, 0.032]
kernels = None
# Only subset the kernels if at least one criterion is given.
reduce_kernels = any(var is not None for var in [kernels, types, sigmas])

# GRAPH SETTINGS:
fig_kwargs = dict(
    figsize=(32/2.54, 32/2.54),
    nrows=thresh_rel.size,
    ncols=len(target_species),
    sharex=True,
    sharey=True,
    # Always return a 2D axes array, so that the axes[row, col] indexing
    # below also works for a single species or a single threshold.
    squeeze=False,
    gridspec_kw=dict(
        wspace=0.2,
        hspace=0.75,
        left=0.1,
        right=0.95,
        bottom=0.08,
        top=0.98,
    )
)
# Inset bounds as passed to Axes.inset_axes(): [x0, y0, width, height].
inset_x_bounds = [0, -0.5, 1, 0.4]
inset_y_bounds = [1.01, 0, 0.1, 1]

# PLOT SETTINGS:
fs = dict(
    lab_norm=16,
    lab_tex=20,
    letter=22,
    tit_norm=16,
    tit_tex=20,
    bar=16,
)
lw = dict(
    swarm=1,
    single=3,
    dist=2,
)
base_color = load_colors('../data/stage_colors.npz')['feat']
kern_colors = load_colors('../data/feat_colors_subset.npz')
median_kwargs = dict(
    c='k',
    lw=lw['single'],
    ls='--',
    zorder=3
)
xlab = 'scale $\\alpha$'
xlab_kwargs = dict(
    y=0,
    fontsize=fs['lab_norm'],
    ha='center',
    va='bottom'
)
ylab = '$\\mu_{f_i}$'
ylab_super_kwargs = dict(
    x=0,
    fontsize=fs['lab_norm'],
    ha='left',
    va='center'
)
ylab_ax_kwargs = dict(
    x=0.03,
    fontsize=fs['lab_norm'],
    ha='center',
    va='top'
)
# Spacing of the major y-ticks.
yloc = 0.5
title_kwargs = dict(
    x=0.5,
    yref=1,
    fontsize=fs['tit_norm'],
    ha='center',
    va='top',
    fontstyle='italic'
)
# Keyword arguments forwarded to get_saturation().
plateau_settings = dict(
    low=0.05,
    high=0.95,
    first=True,
    last=True,
    condense=None,
)
plateau_dot_kwargs = dict(
    marker='o',
    mfc=base_color,
    mec='k',
    ms=8,
    mew=1,
    clip_on=False,
    zorder=6
)
# Keyword arguments forwarded to x_dist(); 'edges' is added once the scale
# range of the first species is known.
x_dist_kwargs = dict(
    line_kwargs=dict(
        c=base_color,
        lw=lw['dist'],
    ),
    fill_kwargs=dict(
        color=base_color,
        alpha=1,
    ),
    nbins=100,
    log=True,
)
# Keyword arguments forwarded to y_dist().
y_dist_kwargs = dict(
    line_kwargs=dict(
        c=base_color,
        lw=lw['dist'],
    ),
    fill_kwargs=dict(
        color=base_color,
        alpha=1,
    ),
    edges=np.linspace(0, 1, 101),
    log=False,
)

# EXECUTION:

# Prepare graph:
fig, axes = plt.subplots(**fig_kwargs)
# Axes are shared, so limits and locators set on one panel apply to all.
axes[0, 0].set_ylim(0, 1)
axes[0, 0].yaxis.set_major_locator(plt.MultipleLocator(yloc))
super_xlabel(xlab, fig, axes[-1, 0], axes[-1, -1], **xlab_kwargs)
super_ylabel(ylab, fig, axes[0, 0], axes[-1, 0], **ylab_super_kwargs)
# Column titles: species names.
for ax, species in zip(axes[0, :], target_species):
    title_subplot(ax, shorten_species(species), ref=fig, **title_kwargs)
# Row labels: threshold of each row.
for ax, thresh in zip(axes[:, 0], thresh_rel):
    title = f'$\\Theta_i\\,=\\,{strip_zeros(thresh)}\\,\\cdot\\,\\sigma_{{\\eta_i}}$'
    ylabel(ax, title, transform=fig.transFigure, **ylab_ax_kwargs)
for ax in axes[-1, :]:
    hide_ticks(ax, 'bottom')

# Run through species:
for i, species in enumerate(target_species):
    print(f'Processing {species}...')

    # Load invariance data:
    matches = search_files(example_files[species], dir=search_path)
    # len() rather than truthiness: the return type of search_files() is not
    # visible from here and might be an array.
    if len(matches) == 0:
        raise FileNotFoundError(
            f'No file matching "{example_files[species]}" in {search_path}'
        )
    path = matches[0]
    # NOTE(review): 'thresh_rel' is loaded but never read; the plots assume
    # that the stored thresholds match the thresh_rel setting above -- confirm.
    data, config = load_data(path, ['scales', 'measure_feat', 'thresh_rel'])
    scales, measure = data['scales'], data['measure_feat']

    # Reduce data:
    if exclude_zero:
        inds = np.nonzero(scales > 0)[0]
        scales, measure = scales[inds], measure[inds, ...]
    if reduce_kernels:
        kern_inds = find_kern_specs(config['k_specs'], kernels, types, sigmas)
        measure = measure[:, kern_inds, :]
        config['kernels'] = config['kernels'][:, kern_inds]
        config['k_specs'] = config['k_specs'][kern_inds, :]

    if i == 0:
        # Update settings:
        # Histogram edges and the symlog linear range are both derived from
        # the smallest positive scale of the first species.
        min_scale = scales[scales > 0][0]
        x_dist_kwargs['edges'] = np.geomspace(min_scale, scales[-1],
                                              x_dist_kwargs['nbins'] + 1)
        symlog_kwargs = dict(linthresh=min_scale, linscale=0.5)

    # Run through thresholds:
    for j in range(thresh_rel.size):
        ax = axes[j, i]
        curves = measure[:, :, j]

        # Plot swarm of feature-specific intensity curves:
        handles = ax.plot(scales, curves, lw=lw['swarm'])
        assign_colors(handles, config['k_specs'][:, 0], kern_colors)
        reorder_by_sd(handles, curves)

        # Plot single compressed intensity curve:
        compressed = np.median(curves, axis=1)
        ax.plot(scales, compressed, **median_kwargs)

        # Plot distribution of saturation levels:
        inset = ax.inset_axes(inset_y_bounds)
        inset.set_ylim(0, 1)
        inset.axis('off')
        y_dist(inset, measure[-1, :, j], **y_dist_kwargs)

        # Plot distribution of saturation points:
        crit_inds = np.array(get_saturation(curves, **plateau_settings)[1])
        nan_mask = np.isnan(crit_inds)
        if nan_mask.any():
            print(f'WARNING: No saturation points found for {species} at threshold {thresh_rel[j]}')
            crit_inds = crit_inds[~nan_mask]
        # Indices may arrive as floats (NaN forces a float array); always
        # cast so that they are valid for indexing.
        crit_inds = crit_inds.astype(int)
        crit_scales = scales[crit_inds]
        inset = ax.inset_axes(inset_x_bounds)
        inset.set_xlim(scales[0], scales[-1])
        inset.set_xscale('symlog', **symlog_kwargs)
        hide_axis(inset, 'left')
        if j < thresh_rel.size - 1:
            # Only the insets of the bottom row keep their tick labels.
            hide_ticks(inset, 'bottom')
        x_dist(inset, crit_scales, **x_dist_kwargs)

        if j > 0:
            # Plot single saturation point:
            crit_ind = get_saturation(compressed, **plateau_settings)[1]
            crit_scale = scales[crit_ind]
            inset.plot(crit_scale, 0, **plateau_dot_kwargs)

# Posthocs:
# Applied last and to one panel only; sharex propagates it to the others.
axes[0, 0].set_xscale('symlog', **symlog_kwargs)
axes[0, 0].set_xlim(scales[0], scales[-1])
if save_path is not None:
    fig.savefig(save_path)
print('Done.')
plt.show()