# File: paper_2025/python/fig_invariance_thresh-lp_appendix.py
# (192 lines, 4.9 KiB, Python)

import plotstyle_plt
import numpy as np
import matplotlib.pyplot as plt
from thunderhopper.filetools import search_files
from thunderhopper.modeltools import load_data
from plot_functions import ylabel, ylimits, super_xlabel, title_subplot, time_bar
from color_functions import load_colors, shade_colors
from misc_functions import shorten_species
from IPython import embed
# GENERAL SETTINGS:
# Species to plot, one column of panels each (in this order). Each name is
# also the search key for that species' data file in data_path.
target_species = [
    'Chorthippus_biguttulus',
    'Chorthippus_mollis',
    'Chrysochraon_dispar',
    'Euchorthippus_declivus',
    'Gomphocerippus_rufus',
    'Omocestus_rufipes',
    'Pseudochorthippus_parallelus',
]
# Input folder with the condensed threshold/low-pass invariance data:
data_path = '../data/inv/thresh_lp/condensed/'
# Output file for the finished figure:
save_path = '../figures/fig_invariance_thresh-lp_appendix.pdf'
# ANALYSIS SETTINGS:
# Drop the zero entry of the scale axis before plotting (the x-axis is log-scaled):
exclude_zero = True
# SUBSET SETTINGS:
# Relative threshold to display, picked by position from the listed options.
thresh_rel = np.array([0.5, 1, 3])[0]
# One row per kernel, reduced to the rows chosen by the index array.
# Columns presumably (kernel type, kernel width in s) -- TODO confirm.
kern_specs = np.array([
    [1, 0.008],
    [2, 0.004],
    [3, 0.002],
])[np.array([0, 1, 2])]
# NOTE(review): only the number of selected rows is used further down; kernel
# data are indexed by position 0..n_kernels-1, so a non-prefix selection
# (e.g. [0, 2]) would not pick the matching kernels -- confirm before changing.
n_kernels = len(kern_specs)
# GRAPH SETTINGS:
# Panel grid: one row per kernel, one column per species.
fig_kwargs = dict(
    figsize=(32/2.54, 16/2.54),  # 32 x 16 cm, converted to inches
    nrows=n_kernels,
    ncols=len(target_species),
    sharex=True,
    sharey=True,
    # Keep `axes` a 2D array even when only one kernel or one species is
    # selected; the plotting code indexes it as axes[0, 0], axes[:, 0],
    # axes[-1, -1] and axes.T, which fails on a squeezed 1D array.
    squeeze=False,
    gridspec_kw=dict(
        wspace=0.4,
        hspace=0.2,
        left=0.07,
        right=0.98,
        bottom=0.1,
        top=0.95,
    ),
)
# PLOT SETTINGS:
# Per-species trace colours, looked up by species name when plotting:
species_colors = load_colors('../data/species_colors.npz')
# Inset kernel colours: one shade of black per kernel, spaced evenly between
# the two shade levels below (meaning of "shade" is defined by shade_colors --
# presumably 0 = base colour, larger = lighter; TODO confirm).
kern_shades = [0, 0.75]
kern_colors = shade_colors(
    (0., 0., 0.),
    np.linspace(kern_shades[0], kern_shades[1], n_kernels),
)
# Per-recording feature traces:
line_kwargs = dict(
    lw=2,
    alpha=0.5,
    zorder=2,
)
# Per-recording +-SD bands (lower zorder, drawn beneath the traces):
fill_kwargs = dict(
    alpha=0.3,
    zorder=1,
)
# Across-recording mean trace (drawn on top). Its colour is not set here; it is
# passed separately per species from mean_colors.
mean_kwargs = dict(
    lw=2,
    alpha=1,
    zorder=3,
    ls='--'
)
# Colour of each species' mean trace: white for two species, black otherwise
# (presumably chosen for contrast against the species colour -- confirm).
mean_colors = {
    'Chorthippus_biguttulus': (1,) * 3,
    'Chorthippus_mollis': (0,) * 3,
    'Chrysochraon_dispar': (0,) * 3,
    'Euchorthippus_declivus': (0,) * 3,
    'Gomphocerippus_rufus': (0,) * 3,
    'Omocestus_rufipes': (0,) * 3,
    'Pseudochorthippus_parallelus': (1,) * 3,
}
# Kernel waveforms drawn in the insets:
kern_kwargs = dict(
    lw=2,
)
# Inset placement [x0, y0, width, height] relative to the parent axes
# (matplotlib Axes.inset_axes default coordinates):
inset_bounds = [0.05, 0.6, 0.3, 0.25]
# Length of the time scale bar in the first kernel inset, in seconds
# (the label below converts it to ms):
kern_bar_time = 0.05
# Options for the project helper time_bar (semantics of y0/y1/text_pos are
# defined there -- not visible here).
kern_bar_kwargs = dict(
    dur=kern_bar_time,
    y0=0.1,
    y1=0.2,
    color='k',
    lw=0,
    clip_on=False,
    text_pos=(0.5, -1),
    # round() rather than int(): int() truncates float products that land just
    # below an integer (e.g. 1.005 * 1000 -> 1004.9999...), mislabelling the bar.
    text_str=f'${round(kern_bar_time * 1000)}\\,\\text{{ms}}$',
    text_kwargs=dict(
        fontsize=12,
        ha='center',
        va='top',
    )
)
# Shared x-axis label text:
xlab = r'scale $\alpha$'
# One y-label per kernel row: $\mu_{f_1}$, $\mu_{f_2}$, ...
ylabs = [rf'$\mu_{{f_{k}}}$' for k in range(1, n_kernels + 1)]
# Text options for the shared x-label (passed to super_xlabel):
xlab_kwargs = {
    'y': 0,
    'fontsize': 16,
    'ha': 'center',
    'va': 'bottom',
}
# Text options for the per-row y-labels (passed to ylabel):
ylab_kwargs = {
    'x': 0,
    'fontsize': 20,
    'ha': 'center',
    'va': 'top',
}
# Options for the species titles above each column (passed to title_subplot;
# the meaning of `yref` is defined there):
title_kwargs = {
    'x': 0.5,
    'yref': 0.99,
    'ha': 'center',
    'va': 'top',
    'fontsize': 16,
    'fontstyle': 'italic',
}
# Options for panel letters (not used in the visible part of this script):
letter_kwargs = {
    'x': 0.005,
    'y': 0.99,
    'fontsize': 22,
    'ha': 'left',
    'va': 'top',
}
# Prepare graph:
fig, axes = plt.subplots(**fig_kwargs)
# fig_kwargs shares x and y across panels, so configuring the first panel
# applies the scale, limits and tick spacing to all of them.
axes[0, 0].set_xscale('log')
axes[0, 0].set_ylim(0, 1)
axes[0, 0].yaxis.set_major_locator(plt.MultipleLocator(0.5))
super_xlabel(xlab, fig, axes[-1, 0], axes[-1, -1], **xlab_kwargs)
# Row labels and one kernel inset per row, both attached to the first column:
insets = []
for ax, ylab in zip(axes[:, 0], ylabs):
    ylabel(ax, ylab, **ylab_kwargs, transform=fig.transFigure)
    insets.append(ax.inset_axes(inset_bounds))
# Run through species (one column of panels each):
for i, (species, spec_axes) in enumerate(zip(target_species, axes.T)):
    title_subplot(spec_axes[0], shorten_species(species), ref=fig, **title_kwargs)
    # Load species data:
    paths = search_files(species, dir=data_path)
    if len(paths) == 0:
        # Fail with a clear message instead of a bare IndexError on [0]:
        raise FileNotFoundError(f'No data file found for {species} in {data_path}')
    path = paths[0]
    data, config = load_data(path, files=['scales', 'mean_feat', 'sd_feat', 'thresh_rel'])
    scales = data['scales']
    # Axis layout used below: 0 = scale, 1 = kernel, 2 = threshold, 3 = recording.
    means = data['mean_feat']
    sds = data['sd_feat']
    # Reduce to single threshold. Match with a tolerance rather than exact float
    # equality, which silently fails if stored thresholds carry rounding error:
    thresh_inds = np.nonzero(np.isclose(data['thresh_rel'], thresh_rel))[0]
    if thresh_inds.size == 0:
        raise ValueError(f'thresh_rel={thresh_rel} not available for {species} '
                         f'(stored values: {data["thresh_rel"]})')
    ind = thresh_inds[0]
    means = means[:, :, ind, :]
    sds = sds[:, :, ind, :]
    if exclude_zero:
        # Exclude zero scale (cannot be shown on the log x-axis):
        inds = scales > 0
        scales = scales[inds]
        means = means[inds, :, :]
        sds = sds[inds, :, :]
    # Run through kernels (one row of panels each):
    for j, (ax, inset) in enumerate(zip(spec_axes, insets)):
        if i == 0:
            # Indicate kernel waveform (insets sit in the first column only):
            inset.plot(config['k_times'], config['kernels'][:, j],
                       c=kern_colors[j], **kern_kwargs)
            inset.set_xlim(config['k_times'][[0, -1]])
            ylimits(config['kernels'], inset, pad=0.05)
            inset.set_title(rf'$k_{{{j+1}}}$', fontsize=15)
            if j == 0:
                # Time scale bar only in the first inset:
                time_bar(inset, **kern_bar_kwargs)
            # Hide axes of every inset, not only the one with the time bar:
            inset.axis('off')
        # Plot recording-specific traces (mean +- SD):
        for k in range(means.shape[-1]):
            ax.plot(scales, means[:, j, k], c=species_colors[species], **line_kwargs)
            spread = (means[:, j, k] - sds[:, j, k], means[:, j, k] + sds[:, j, k])
            ax.fill_between(scales, *spread, color=species_colors[species], **fill_kwargs)
        # Plot kernel-specific mean trace (average over recordings):
        ax.plot(scales, means[:, j, :].mean(axis=-1), c=mean_colors[species], **mean_kwargs)
# Save graph:
fig.savefig(save_path)
plt.show()