"""Appendix figure: invariance of kernel features to input scaling.

Grid layout: one column per target species, one row per selected kernel.
Each panel shows, against the scale factor alpha (log x-axis):
  - one trace per recording (mean feature value) with a +/- SD band,
  - the across-recording mean as a dashed line.
The first column carries insets with the kernel waveforms; the first inset
also gets a time scale bar. The figure is saved to ``save_path`` and shown.
"""
import plotstyle_plt  # noqa: F401 -- presumably applies a plot style on import; verify
import numpy as np
import matplotlib.pyplot as plt
from thunderhopper.filetools import search_files
from thunderhopper.modeltools import load_data
from plot_functions import ylabel, ylimits, super_xlabel, title_subplot, time_bar
from color_functions import load_colors, shade_colors
from misc_functions import shorten_species
from IPython import embed  # noqa: F401 -- debugging hook, not used below

# GENERAL SETTINGS:
mode = ['pure', 'noise'][1]
target_species = [
    'Chorthippus_biguttulus',
    'Chorthippus_mollis',
    'Chrysochraon_dispar',
    # 'Euchorthippus_declivus',
    'Gomphocerippus_rufus',
    'Omocestus_rufipes',
    'Pseudochorthippus_parallelus',
]
data_path = '../data/inv/thresh_lp/condensed/'
save_path = f'../figures/fig_invariance_thresh-lp_{mode}_appendix.pdf'

# ANALYSIS SETTINGS:
exclude_zero = True

# SUBSET SETTINGS:
thresh_rel = np.array([0.5, 1, 3])[0]
# Indices of the kernels to plot. They select rows of the spec table below AND
# are used to index the kernel axis of the loaded data, the kernel waveforms,
# and the label/title numbering, so changing the subset changes what is
# plotted (not just the number of rows).
# NOTE(review): assumes the data files store kernels in the same order as this
# table -- TODO confirm against the script that condensed the data.
kern_inds = np.array([0, 1, 2])
kern_specs = np.array([
    [1, 0.008],
    [2, 0.004],
    [3, 0.002],
])[kern_inds]
n_kernels = kern_specs.shape[0]

# GRAPH SETTINGS:
fig_kwargs = dict(
    figsize=(32/2.54, 16/2.54),
    nrows=n_kernels,
    ncols=len(target_species),
    sharex=True,
    sharey=True,
    # Always return a 2D axes array, so that axes[row, col] indexing below
    # keeps working with a single species or a single kernel.
    squeeze=False,
    gridspec_kw=dict(
        wspace=0.4,
        hspace=0.2,
        left=0.07,
        right=0.98,
        bottom=0.1,
        top=0.95,
    ),
)

# PLOT SETTINGS:
species_colors = load_colors('../data/species_colors.npz')
kern_shades = [0, 0.75]
kern_colors = shade_colors((0., 0., 0.), np.linspace(*kern_shades, n_kernels))
line_kwargs = dict(
    lw=2,
    alpha=0.5,
    zorder=2,
)
fill_kwargs = dict(
    alpha=0.3,
    zorder=1,
)
mean_kwargs = dict(
    # c=(0.5,) * 3,
    lw=2,
    alpha=1,
    zorder=3,
    ls='--',
)
# Color of the dashed across-recording mean line, per species:
mean_colors = {
    'Chorthippus_biguttulus': (1,) * 3,
    'Chorthippus_mollis': (0,) * 3,
    'Chrysochraon_dispar': (0,) * 3,
    'Euchorthippus_declivus': (0,) * 3,
    'Gomphocerippus_rufus': (0,) * 3,
    'Omocestus_rufipes': (0,) * 3,
    'Pseudochorthippus_parallelus': (1,) * 3,
}
kern_kwargs = dict(
    lw=2,
)
inset_bounds = [0.05, 0.6, 0.3, 0.25]
kern_bar_time = 0.05
kern_bar_kwargs = dict(
    dur=kern_bar_time,
    y0=0.1,
    y1=0.2,
    color='k',
    lw=0,
    clip_on=False,
    text_pos=(0.5, -1),
    # round() rather than int(): int() truncates, so a product that lands
    # just below an integer due to float rounding would be labelled 1 ms short.
    text_str=f'${round(kern_bar_time * 1000)}\\,\\text{{ms}}$',
    text_kwargs=dict(
        fontsize=12,
        ha='center',
        va='top',
    ),
)
xlab = 'scale $\\alpha$'
# One y-label per plotted kernel, numbered by its position in the full set:
ylabs = [f'$\\mu_{{f_{i + 1}}}$' for i in kern_inds]
xlab_kwargs = dict(
    y=0,
    fontsize=16,
    ha='center',
    va='bottom',
)
ylab_kwargs = dict(
    x=0,
    fontsize=20,
    ha='center',
    va='top',
)
title_kwargs = dict(
    x=0.5,
    yref=0.99,
    ha='center',
    va='top',
    fontsize=16,
    fontstyle='italic',
)
# Panel-letter settings (currently not used below):
letter_kwargs = dict(
    x=0.005,
    y=0.99,
    fontsize=22,
    ha='left',
    va='top',
)

# Prepare graph (axes are shared, so scale/limits/locator set on one panel
# apply to all of them):
fig, axes = plt.subplots(**fig_kwargs)
axes[0, 0].set_xscale('log')
axes[0, 0].set_ylim(0, 1)
axes[0, 0].yaxis.set_major_locator(plt.MultipleLocator(0.5))
super_xlabel(xlab, fig, axes[-1, 0], axes[-1, -1], **xlab_kwargs)
# Row labels and kernel-waveform insets go into the first column only:
insets = []
for ax, ylab in zip(axes[:, 0], ylabs):
    ylabel(ax, ylab, **ylab_kwargs, transform=fig.transFigure)
    insets.append(ax.inset_axes(inset_bounds))

# Run through species (columns):
for i, (species, spec_axes) in enumerate(zip(target_species, axes.T)):
    title_subplot(spec_axes[0], shorten_species(species), ref=fig, **title_kwargs)

    # Load species data (first matching file is used):
    paths = search_files(species, incl=[mode, 'unnormed'], dir=data_path)
    if len(paths) == 0:
        raise FileNotFoundError(
            f'No {mode!r}/unnormed data file for {species} in {data_path}'
        )
    path = paths[0]
    data, config = load_data(path, files=['scales', 'mean_feat', 'sd_feat', 'thresh_rel'])
    scales = data['scales']
    # Feature arrays are indexed below as (scale, kernel, threshold, recording).
    means = data['mean_feat']
    sds = data['sd_feat']

    # Reduce to single threshold (tolerant float match against stored values):
    thresh_hits = np.nonzero(np.isclose(data['thresh_rel'], thresh_rel))[0]
    if thresh_hits.size == 0:
        raise ValueError(f'thresh_rel={thresh_rel} not found in {path}')
    ind = thresh_hits[0]
    means = means[:, :, ind, :]
    sds = sds[:, :, ind, :]

    if exclude_zero:
        # Exclude zero scale (cannot be shown on the log x-axis):
        inds = scales > 0
        scales = scales[inds]
        means = means[inds, :, :]
        sds = sds[inds, :, :]

    # Run through kernels (rows):
    for j, (ax, inset) in enumerate(zip(spec_axes, insets)):
        kern_ind = kern_inds[j]
        if i == 0:
            # Indicate kernel waveform (drawn once, in the first column):
            inset.plot(config['k_times'], config['kernels'][:, kern_ind],
                       c=kern_colors[j], **kern_kwargs)
            inset.set_xlim(config['k_times'][[0, -1]])
            ylimits(config['kernels'], inset, pad=0.05)
            inset.set_title(rf'$k_{{{kern_ind + 1}}}$', fontsize=15)
            if j == 0:
                time_bar(inset, **kern_bar_kwargs)
            inset.axis('off')

        # Plot recording-specific traces (mean with +/- SD band):
        color = species_colors[species]
        for k in range(means.shape[-1]):
            mean_trace = means[:, kern_ind, k]
            sd_trace = sds[:, kern_ind, k]
            ax.plot(scales, mean_trace, c=color, **line_kwargs)
            ax.fill_between(scales, mean_trace - sd_trace, mean_trace + sd_trace,
                            color=color, **fill_kwargs)

        # Plot kernel-specific mean trace across recordings:
        ax.plot(scales, means[:, kern_ind, :].mean(axis=-1),
                c=mean_colors[species], **mean_kwargs)

# Save graph:
fig.savefig(save_path)
plt.show()