import plotstyle_plt
import string
import numpy as np
import matplotlib.pyplot as plt
from thunderhopper.filetools import search_files
from thunderhopper.modeltools import load_data
from thunderhopper.filtertools import find_kern_specs
from color_functions import load_colors
from plot_functions import (hide_ticks, ylabel, super_xlabel, letter_subplots,
                            ylimits, title_subplot)
from misc_functions import exclude_zero_scale, reduce_kernel_set
from IPython import embed

# GENERAL SETTINGS:
# Species to plot; the trailing index selects one entry of the candidate list.
target_species = [
    'Chorthippus_biguttulus',         # 0
    'Chorthippus_mollis',             # 1
    'Chrysochraon_dispar',            # 2
    'Euchorthippus_declivus',         # 3
    'Gomphocerippus_rufus',           # 4
    'Omocestus_rufipes',              # 5
    'Pseudochorthippus_parallelus',   # 6
][5]
# Normalization modes, one figure row each (also used to pick data files).
modes = ['unnormed', 'norm-base', 'norm-min', 'norm-max']
full_folder = '../data/inv/full/condensed/'
short_folder = '../data/inv/short/condensed/'
save_path = '../figures/fig_invariance_full_short.pdf'
# Forwarded to load_data() for every data file.
load_kwargs = {
    'files': ['scales', 'mean_feat', 'sd_feat'],
    'keywords': ['thresh'],
}

# ANALYSIS SETTINGS:
exclude_zero = True
scale_subset_kwargs = {
    'combis': [['mean', 'sd'], ['feat']],
}
kern_subset_kwargs = {
    'combis': [['mean', 'sd'], ['feat']],
    'keys': ['thresh_abs'],
}
# Relative thresholds; one median curve is drawn per entry.
thresh_rel = np.array([0, 0.5, 1, 1.5, 2, 2.5, 3])
# Each [low, high] row becomes one shaded band (e.g. add [0, 100] for range).
percentiles = np.array([
    [25, 75],
])

# SUBSET SETTINGS:
# Kernel types to keep (full set: 1, -1, 2, -2, ..., 10, -10).
types = np.array([1, 2, 3])
# Kernel widths to keep.
sigmas = np.array([0.001, 0.002, 0.004, 0.008, 0.016, 0.032])
kernels = None
# Only subset the kernel set when at least one selector is given.
reduce_kernels = not all(selector is None
                         for selector in (kernels, types, sigmas))

# GRAPH SETTINGS:
fig_kwargs = {
    'figsize': (32 / 2.54, 32 / 2.54),
}
super_grid_kwargs = {
    'nrows': 1, 'ncols': 2,
    'wspace': 0, 'hspace': 0,
    'left': 0, 'right': 1, 'bottom': 0, 'top': 1,
}
# (row, col) of each subfigure inside the super grid.
subfig_specs = {
    'full': (0, 0),
    'short': (0, 1),
}
col_width = 0.85
col_rest = 1 - col_width
full_grid_kwargs = {
    'nrows': len(modes), 'ncols': 1,
    'wspace': 0, 'hspace': 0.1,
    'left': col_rest, 'right': 1, 'bottom': 0.05, 'top': 0.9,
}
short_grid_kwargs = dict( nrows=len(modes), ncols=1, wspace=full_grid_kwargs['wspace'], hspace=full_grid_kwargs['hspace'], left=col_rest / 2, right=1 - col_rest / 2, bottom=full_grid_kwargs['bottom'], top=full_grid_kwargs['top'] ) # PLOT SETTINGS: fs = dict( lab_norm=16, lab_tex=20, letter=22, tit_norm=16, tit_tex=20, bar=16, ) colors = load_colors('../data/stage_colors.npz') feat_colors = load_colors('../data/feat_colors_all.npz') lw = dict( feat=3, plateau=1.5 ) line_kwargs = dict( lw=lw['feat'] ) fill_kwargs = dict( alpha=0.15 ) xlabels = dict( super='scale $\\alpha$', ) ylabels = { 'unnormed': '$\\mu_{f_i}$', 'norm-base': '$\\mu_{f_i}\\,/\\,\\mu_{f_i}\\,[\\,\\eta\\,]$', 'norm-min': '$\\mu_{f_i}\\,/\\,\\min\\,[\\,\\mu_{f_i}\\,]$', 'norm-max': '$\\mu_{f_i}\\,/\\,\\max\\,[\\,\\mu_{f_i}\\,]$' } xlab_kwargs = dict( y=0, fontsize=fs['lab_norm'], ha='center', va='bottom', ) ylab_kwargs = dict( x=0, fontsize=fs['lab_tex'], ha='center', va='top', ) ylims = { 'unnormed': [0, 1], 'norm-base': [0, None], 'norm-min': [0, None], 'norm-max': [0, 1] } yloc = { 'unnormed': 0.5, 'norm-base': 0.5, 'norm-min': 0.5, 'norm-max': 0.5 } title_kwargs = dict( x=0.5, y=1, ha='center', va='bottom', fontsize=fs['tit_norm'], ) titles = dict( full='Including $x_\\text{dB}$', short='Excluding $x_\\text{dB}$', ) letter_kwargs = dict( xref=0.01, y=1, ha='left', va='center', fontsize=fs['letter'], ) letters = dict( full=string.ascii_lowercase[0::2], short=string.ascii_lowercase[1::2], ) plateau_settings = dict( low=0.05, high=0.95, first=True, last=True, condense=None, ) plateau_line_kwargs = dict( lw=lw['plateau'], ls='--', zorder=1, ) plateau_dot_kwargs = dict( marker='o', markersize=8, markeredgewidth=1, clip_on=False, ) # EXECUTION: # Prepare overall graph: fig = plt.figure(**fig_kwargs) super_grid = fig.add_gridspec(**super_grid_kwargs) # Prepare full analysis axes: full_subfig = fig.add_subfigure(super_grid[*subfig_specs['full']]) full_grid = full_subfig.add_gridspec(**full_grid_kwargs) 
full_axes = np.zeros((len(modes),), dtype=object) for i, mode in enumerate(modes): full_axes[i] = full_subfig.add_subplot(full_grid[i, 0]) # full_axes[i].yaxis.set_major_locator(plt.MultipleLocator(yloc[mode])) full_axes[i].set_xscale('symlog', linthresh=0.01, linscale=0.5) ylabel(full_axes[i], ylabels[mode], transform=full_subfig.transSubfigure, **ylab_kwargs) if i == 0: title_subplot(full_axes[i], titles['full'], **title_kwargs) if i < full_grid_kwargs['nrows'] - 1: hide_ticks(full_axes[i], 'bottom') letter_subplots(full_axes, letters['full'], ref=full_subfig, **letter_kwargs) # Prepare short analysis axes: short_subfig = fig.add_subfigure(super_grid[*subfig_specs['short']]) short_grid = short_subfig.add_gridspec(**short_grid_kwargs) short_axes = np.zeros((len(modes),), dtype=object) for i, mode in enumerate(modes): short_axes[i] = short_subfig.add_subplot(short_grid[i, 0]) # short_axes[i].yaxis.set_major_locator(plt.MultipleLocator(yloc[mode])) short_axes[i].set_xscale('symlog', linthresh=0.01, linscale=0.5) hide_ticks(short_axes[i], 'left') if i == 0: title_subplot(short_axes[i], titles['short'], **title_kwargs) if i < short_grid_kwargs['nrows'] - 1: hide_ticks(short_axes[i], 'bottom') letter_subplots(short_axes, letters['short'], ref=short_subfig, **letter_kwargs) super_xlabel(xlabels['super'], fig, full_axes[-1], short_axes[-1], left_fig=full_subfig, right_fig=short_subfig, **xlab_kwargs) # Run through normalization modes: for mode, full_ax, short_ax in zip(modes, full_axes, short_axes): # Load invariance data: full_path = search_files(target_species, incl=mode, dir=full_folder)[0] short_path = search_files(target_species, incl=mode, dir=short_folder)[0] full_data, config = load_data(full_path, **load_kwargs) short_data, _ = load_data(short_path, **load_kwargs) # Reduce datasets: if reduce_kernels: kern_inds = find_kern_specs(config['k_specs'], kernels, types, sigmas) full_data = reduce_kernel_set(full_data, kern_inds, **kern_subset_kwargs) short_data = 
reduce_kernel_set(short_data, kern_inds, **kern_subset_kwargs) config['k_specs'] = config['k_specs'][kern_inds, :] config['kernels'] = config['kernels'][:, kern_inds] if exclude_zero: full_data = exclude_zero_scale(full_data, **scale_subset_kwargs) short_data = exclude_zero_scale(short_data, **scale_subset_kwargs) # Average over recordings: full_measure = full_data['mean_feat'].mean(axis=-1) short_measure = short_data['mean_feat'].mean(axis=-1) # Condense over kernels: full_median = np.nanmedian(full_measure, axis=1) full_spread = np.nanpercentile(full_measure, percentiles, axis=1) short_median = np.nanmedian(short_measure, axis=1) short_spread = np.nanpercentile(short_measure, percentiles, axis=1) # Determine shared ylims: if None in ylims[mode]: min_val, max_val = ylims[mode] full_limits = ylimits(full_median, minval=min_val, maxval=max_val) short_limits = ylimits(short_median, minval=min_val, maxval=max_val) ylims[mode] = [min(full_limits[0], short_limits[0]), max(full_limits[1], short_limits[1])] if np.inf in ylims[mode]: embed() # Plot full analysis results: for i, thresh in enumerate(thresh_rel): full_ax.plot(full_data['scales'], full_median[:, i], lw=lw['feat']) for spread in full_spread[:, :, :, i]: full_ax.fill_between(full_data['scales'], *spread, **fill_kwargs) full_ax.set_xlim(full_data['scales'][0], full_data['scales'][-1]) full_ax.set_ylim(ylims[mode]) # Plot short analysis results: for i, thresh in enumerate(thresh_rel): short_ax.plot(short_data['scales'], short_median[:, i], lw=lw['feat']) for spread in short_spread[:, :, :, i]: short_ax.fill_between(short_data['scales'], *spread, **fill_kwargs) short_ax.set_xlim(short_data['scales'][0], short_data['scales'][-1]) short_ax.set_ylim(ylims[mode]) plt.show()