Files
paper_2025/python/fig_invariance_full-short.py
j-hartling e70d100655 Added loads of units in nearly all graphs.
Overhauled fig_invariance_full.pdf.
Added some legends, somewhere.
2026-04-28 19:43:05 +02:00

295 lines
8.4 KiB
Python

import plotstyle_plt
import string
import numpy as np
import matplotlib.pyplot as plt
from thunderhopper.filetools import search_files
from thunderhopper.modeltools import load_data
from thunderhopper.filtertools import find_kern_specs
from color_functions import load_colors
from plot_functions import hide_ticks, ylabel, super_xlabel, letter_subplots,\
ylimits, title_subplot
from misc_functions import exclude_zero_scale, reduce_kernel_set
from IPython import embed
# GENERAL SETTINGS:
# Species to plot; change the trailing index to select another candidate.
target_species = (
    'Chorthippus_biguttulus',
    'Chorthippus_mollis',
    'Chrysochraon_dispar',
    'Euchorthippus_declivus',
    'Gomphocerippus_rufus',
    'Omocestus_rufipes',
    'Pseudochorthippus_parallelus',
)[5]
# Normalization modes; each one gets its own row of panels (top to bottom).
modes = ['unnormed', 'norm-base', 'norm-min', 'norm-max']
full_folder = '../data/inv/full/condensed/'
short_folder = '../data/inv/short/condensed/'
save_path = '../figures/fig_invariance_full_short.pdf'
# Keyword arguments passed to load_data() for every data file.
load_kwargs = {
    'files': ['scales', 'mean_feat', 'sd_feat'],
    'keywords': ['thresh'],
}
# ANALYSIS SETTINGS:
# If enabled, the data are passed through exclude_zero_scale() with the
# subset kwargs below.
exclude_zero = True
scale_subset_kwargs = {'combis': [['mean', 'sd'], ['feat']]}
kern_subset_kwargs = {
    'combis': [['mean', 'sd'], ['feat']],
    'keys': ['thresh_abs'],
}
# Relative thresholds; one curve is plotted per entry.
thresh_rel = np.array([0.0, 0.5, 1.0, 1.5, 2.0, 2.5, 3.0])
# Lower/upper percentile pairs, one spread band per row
# (e.g. append [0, 100] as a second row to also show the full range).
percentiles = np.array([[25, 75]])
# SUBSET SETTINGS:
# Alternative type selection: [1, -1, 2, -2, ..., 10, -10].
types = np.array([1, 2, 3])
sigmas = np.array([0.001, 0.002, 0.004, 0.008, 0.016, 0.032])
kernels = None
# The kernel set is reduced unless all three criteria are None.
reduce_kernels = not all(criterion is None for criterion in (kernels, types, sigmas))
# GRAPH SETTINGS:
# Square figure, 32 cm per side (figsize is given in inches).
fig_kwargs = {'figsize': (32 / 2.54, 32 / 2.54)}
# Outer grid: two side-by-side subfigure cells spanning the whole canvas.
super_grid_kwargs = {
    'nrows': 1, 'ncols': 2,
    'wspace': 0, 'hspace': 0,
    'left': 0, 'right': 1,
    'bottom': 0, 'top': 1,
}
# Outer-grid cell (row, col) of each analysis column.
subfig_specs = {'full': (0, 0), 'short': (0, 1)}
# Fraction of a subfigure's width used by its axes; the rest is margin.
col_width = 0.85
col_rest = 1 - col_width
# Left column: one row per mode, all margin on the left (room for y-labels).
full_grid_kwargs = {
    'nrows': len(modes), 'ncols': 1,
    'wspace': 0, 'hspace': 0.1,
    'left': col_rest, 'right': 1,
    'bottom': 0.05, 'top': 0.9,
}
# Right column: identical layout, but the margin is split between both sides.
short_grid_kwargs = {
    **full_grid_kwargs,
    'left': col_rest / 2,
    'right': 1 - col_rest / 2,
}
# PLOT SETTINGS:
# Font sizes per text element.
fs = {
    'lab_norm': 16,
    'lab_tex': 20,
    'letter': 22,
    'tit_norm': 16,
    'tit_tex': 20,
    'bar': 16,
}
colors = load_colors('../data/stage_colors.npz')
feat_colors = load_colors('../data/feat_colors_all.npz')
# Line widths for feature curves and plateau markers.
lw = {'feat': 3, 'plateau': 1.5}
line_kwargs = {'lw': lw['feat']}
fill_kwargs = {'alpha': 0.15}
xlabels = {'super': r'scale $\alpha$'}
ylabels = {
    'unnormed': r'$\mu_{f_i}$',
    'norm-base': r'$\mu_{f_i}\,/\,\mu_{f_i}\,[\,\eta\,]$',
    'norm-min': r'$\mu_{f_i}\,/\,\min\,[\,\mu_{f_i}\,]$',
    'norm-max': r'$\mu_{f_i}\,/\,\max\,[\,\mu_{f_i}\,]$',
}
xlab_kwargs = {
    'y': 0,
    'fontsize': fs['lab_norm'],
    'ha': 'center',
    'va': 'bottom',
}
ylab_kwargs = {
    'x': 0,
    'fontsize': fs['lab_tex'],
    'ha': 'center',
    'va': 'top',
}
# Y-limits per mode; a None entry is replaced from the data at runtime.
ylims = {
    'unnormed': [0, 1],
    'norm-base': [0, None],
    'norm-min': [0, None],
    'norm-max': [0, 1],
}
# Major y-tick spacing per mode (only referenced by disabled locator code).
yloc = dict.fromkeys(('unnormed', 'norm-base', 'norm-min', 'norm-max'), 0.5)
title_kwargs = {
    'x': 0.5,
    'y': 1,
    'ha': 'center',
    'va': 'bottom',
    'fontsize': fs['tit_norm'],
}
titles = {
    'full': r'Including $x_\text{dB}$',
    'short': r'Excluding $x_\text{dB}$',
}
letter_kwargs = {
    'xref': 0.01,
    'y': 1,
    'ha': 'left',
    'va': 'center',
    'fontsize': fs['letter'],
}
# Panel letters alternate between columns: a, c, e, ... left; b, d, f, ... right.
letters = {
    'full': string.ascii_lowercase[::2],
    'short': string.ascii_lowercase[1::2],
}
plateau_settings = {
    'low': 0.05,
    'high': 0.95,
    'first': True,
    'last': True,
    'condense': None,
}
plateau_line_kwargs = {
    'lw': lw['plateau'],
    'ls': '--',
    'zorder': 1,
}
plateau_dot_kwargs = {
    'marker': 'o',
    'markersize': 8,
    'markeredgewidth': 1,
    'clip_on': False,
}
# EXECUTION:
# Prepare overall graph:
fig = plt.figure(**fig_kwargs)
super_grid = fig.add_gridspec(**super_grid_kwargs)
# Prepare full analysis axes:
full_subfig = fig.add_subfigure(super_grid[*subfig_specs['full']])
full_grid = full_subfig.add_gridspec(**full_grid_kwargs)
full_axes = np.zeros((len(modes),), dtype=object)
for i, mode in enumerate(modes):
full_axes[i] = full_subfig.add_subplot(full_grid[i, 0])
# full_axes[i].yaxis.set_major_locator(plt.MultipleLocator(yloc[mode]))
full_axes[i].set_xscale('symlog', linthresh=0.01, linscale=0.5)
ylabel(full_axes[i], ylabels[mode], transform=full_subfig.transSubfigure, **ylab_kwargs)
if i == 0:
title_subplot(full_axes[i], titles['full'], **title_kwargs)
if i < full_grid_kwargs['nrows'] - 1:
hide_ticks(full_axes[i], 'bottom')
letter_subplots(full_axes, letters['full'], ref=full_subfig, **letter_kwargs)
# Prepare short analysis axes:
short_subfig = fig.add_subfigure(super_grid[*subfig_specs['short']])
short_grid = short_subfig.add_gridspec(**short_grid_kwargs)
short_axes = np.zeros((len(modes),), dtype=object)
for i, mode in enumerate(modes):
short_axes[i] = short_subfig.add_subplot(short_grid[i, 0])
# short_axes[i].yaxis.set_major_locator(plt.MultipleLocator(yloc[mode]))
short_axes[i].set_xscale('symlog', linthresh=0.01, linscale=0.5)
hide_ticks(short_axes[i], 'left')
if i == 0:
title_subplot(short_axes[i], titles['short'], **title_kwargs)
if i < short_grid_kwargs['nrows'] - 1:
hide_ticks(short_axes[i], 'bottom')
letter_subplots(short_axes, letters['short'], ref=short_subfig, **letter_kwargs)
super_xlabel(xlabels['super'], fig, full_axes[-1], short_axes[-1],
left_fig=full_subfig, right_fig=short_subfig, **xlab_kwargs)
def _plot_condensed(ax, scales, median, spread):
    """Draw the kernel-condensed invariance curves of one analysis into ``ax``.

    One median curve is drawn per entry of ``thresh_rel``. Each curve gets
    its percentile band(s) from ``spread``, filled in the curve's own color.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        Target axis.
    scales : array-like
        Scale values used as x-coordinates; first and last entry set the
        x-limits.
    median : ndarray
        Median over kernels, indexed as ``median[scale, threshold]``.
    spread : ndarray
        Result of ``np.nanpercentile`` with the 2D ``percentiles`` array,
        indexed as ``spread[band, lower/upper, scale, threshold]``.
    """
    for i in range(len(thresh_rel)):
        line, = ax.plot(scales, median[:, i], **line_kwargs)
        # Color each band explicitly after its line. Otherwise matplotlib's
        # automatic color cycling decides, and it is not guaranteed to keep
        # fills and lines in step (e.g. with more than one band per line).
        band_kwargs = {'color': line.get_color(), **fill_kwargs}
        for lower, upper in spread[:, :, :, i]:
            ax.fill_between(scales, lower, upper, **band_kwargs)
    ax.set_xlim(scales[0], scales[-1])


# Run through normalization modes:
for mode, full_ax, short_ax in zip(modes, full_axes, short_axes):
    # Load invariance data:
    full_path = search_files(target_species, incl=mode, dir=full_folder)[0]
    short_path = search_files(target_species, incl=mode, dir=short_folder)[0]
    full_data, config = load_data(full_path, **load_kwargs)
    short_data, _ = load_data(short_path, **load_kwargs)
    # Reduce datasets (kernel indices come from the full dataset's config
    # and are applied to both analyses):
    if reduce_kernels:
        kern_inds = find_kern_specs(config['k_specs'], kernels, types, sigmas)
        full_data = reduce_kernel_set(full_data, kern_inds, **kern_subset_kwargs)
        short_data = reduce_kernel_set(short_data, kern_inds, **kern_subset_kwargs)
        config['k_specs'] = config['k_specs'][kern_inds, :]
        config['kernels'] = config['kernels'][:, kern_inds]
    if exclude_zero:
        full_data = exclude_zero_scale(full_data, **scale_subset_kwargs)
        short_data = exclude_zero_scale(short_data, **scale_subset_kwargs)
    # Average over recordings (last axis):
    full_measure = full_data['mean_feat'].mean(axis=-1)
    short_measure = short_data['mean_feat'].mean(axis=-1)
    # Condense over kernels (axis 1) into a median and percentile band(s):
    full_median = np.nanmedian(full_measure, axis=1)
    full_spread = np.nanpercentile(full_measure, percentiles, axis=1)
    short_median = np.nanmedian(short_measure, axis=1)
    short_spread = np.nanpercentile(short_measure, percentiles, axis=1)
    # Determine shared ylims (open limits are filled from both medians):
    if None in ylims[mode]:
        min_val, max_val = ylims[mode]
        full_limits = ylimits(full_median, minval=min_val, maxval=max_val)
        short_limits = ylimits(short_median, minval=min_val, maxval=max_val)
        ylims[mode] = [min(full_limits[0], short_limits[0]),
                       max(full_limits[1], short_limits[1])]
    if np.inf in ylims[mode]:
        # Debug hook: drop into an IPython shell to inspect the offending data.
        embed()
    # Plot full, then short analysis results on the shared y-range:
    for ax, data, median, spread in (
        (full_ax, full_data, full_median, full_spread),
        (short_ax, short_data, short_median, short_spread),
    ):
        _plot_condensed(ax, data['scales'], median, spread)
        ax.set_ylim(ylims[mode])
# Write the figure to the configured path before handing it to the viewer
# (after show() a non-interactive backend may leave an empty canvas):
if save_path is not None:
    fig.savefig(save_path)
plt.show()