Files
paper_2025/python/fig_invariance_thresh-lp_species.py
j-hartling 298969a067 Cross-checked and polished remainders of fig_invariance_thresh_lp_species.pdf.
Added misc_functions.py for anything not plot-related.
2026-03-31 09:36:55 +02:00

634 lines
22 KiB
Python

import plotstyle_plt
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import LogNorm
from mpl_toolkits.axes_grid1 import make_axes_locatable
from itertools import product
from thunderhopper.filetools import search_files
from thunderhopper.modeltools import load_data
from thunderhopper.filtertools import find_kern_specs
from misc_functions import shorten_species, get_saturation
from color_functions import load_colors, shade_colors, create_listed_cmap
from plot_functions import hide_axis, title_subplot, ylimits, xlabel, ylabel,\
plot_line, time_bar,letter_subplot, letter_subplots,\
hide_ticks, super_xlabel, reorder_by_norm
from IPython import embed
def force_sequence(*vars, skip_None=False, equal_size=False):
    """ Ensures single-loop compatibility of one or several input variables.

    Lists and tuples are always treated as sequences. Any other input is
    classified with np.ndim() into sequence-likes (>=1D arrays and other
    array-likes) and scalars (int, float, bool, 0D arrays, strings, dicts,
    None). Scalar variables are promoted to 1D sequences by either tuple
    wrapping or expanding by one array dimension (only 0D arrays). All
    single-entry sequences can be repeated to match the length of the
    longest sequence. Input variables that are None can be excluded from
    these treatments.

    Parameters
    ----------
    *vars : tuple (m,) of inputs (any type)
        Input variables to be checked, promoted, and equalized as required.
        Ragged nested lists are accepted as plain sequences.
    skip_None : bool, optional
        If True, None inputs fall through unmodified. The default is False.
    equal_size : bool, optional
        If True, counts the number of elements in each passed or promoted
        sequence (using len(), meaning that elements are defined as entries
        along the first sequence axis) and repeats single-element sequences to
        match the maximum count. Arrays with shape[0] == 1 are not tiled but
        tuple-wrapped and repeated (so each element is the array itself) to
        avoid deep copies. The default is False.

    Returns
    -------
    vars : array-like or None or list (m,) of array-likes or Nones
        Treated output variables, each either a >=1D sequence-like or None.
        Single variables are returned without list wrapper. Calling without
        any input variables returns an empty list.

    Raises
    ------
    ValueError
        Breaks if equal_size is True and a sequence has incompatible length,
        i.e. any number of elements other than 1, 0 (Nones) or the maximum.
    """
    # Enforce input iterability (own list, so the builtin vars() stays intact
    # for readers and the caller's tuple is never touched):
    seqs, sizes = list(vars), []
    for i, var in enumerate(seqs):
        if skip_None and var is None:
            # Maintain None, which counts as zero elements:
            sizes.append(0)
            continue
        # Lists and tuples are sequences by definition. Checking them first
        # also spares np.ndim() an array conversion, which raises a ValueError
        # for ragged nested lists in recent NumPy versions.
        if not isinstance(var, (list, tuple)) and np.ndim(var) == 0:
            # Make each input variable at least 1D sequence-like:
            seqs[i] = var[None] if isinstance(var, np.ndarray) else (var,)
        # Count sequence elements:
        sizes.append(len(seqs[i]))
    # Check early exits:
    if len(seqs) == 1:
        return seqs[0]
    # Default covers calls without any input variables:
    target = max(sizes, default=0)
    if not equal_size or target <= 1 or all(n == target for n in sizes):
        return seqs
    # Validate compatibility of element counts:
    if not all(n in (0, 1, target) for n in sizes):
        msg = f'Given a maximum sequence length of {target}, all variables '\
              f'must either have 1 or {target} elements or be None: {sizes}'
        raise ValueError(msg)
    # Equalize sequence length across input variables:
    for i, (var, size) in enumerate(zip(seqs, sizes)):
        if size == 1:
            seqs[i] = ((var,) if isinstance(var, np.ndarray) else var) * target
    return seqs
def split_subplot(ax, side='right', size=10, pad=10):
    """ Appends one or more child subplots to the sides of a parent subplot.

    A fresh axes divider is opened on the parent, which shrinks to make room
    for every appended child. Both size and pad are percentages of the
    parent's remaining width (for 'left'/'right') or height (for
    'top'/'bottom'): size=100 yields a child as large as the shrunken parent,
    independent of padding, and pad=100 yields a gap as large as that same
    extent. Passing 1D sequence-likes of equal length for any of side, size
    or pad performs several splits with the same divider; scalar arguments
    are reused for every split. Repeated calls on the same parent are
    allowed but each opens a new divider, so size and pad values are not
    comparable across calls.

    Parameters
    ----------
    ax : matplotlib axes
        Parent subplot to be divided.
    side : str or 1D array-like of str (m,)
        Where to attach each child: 'bottom', 'left', 'top' or 'right'.
        The default is 'right'.
    size : int or float or 1D array-like of ints or floats (m,), optional
        Extent of each child as percentage of the parent's new width or
        height. Children appended to the same side are stacked in the given
        order, the first one closest to the parent. The default is 10.
    pad : int or float or 1D array-like of ints or floats (m,), optional
        Gap between each child and the parent as percentage of the parent's
        new width or height. The default is 10.

    Returns
    -------
    matplotlib axes or list of matplotlib axes (m,)
        The single appended child if all arguments are scalars, otherwise a
        list of all appended children.
    """
    divider = make_axes_locatable(ax)
    # Only scalar arguments: a single child, returned without a list:
    if all(np.ndim(arg) == 0 for arg in (side, size, pad)):
        return divider.append_axes(side, size=f'{size}%', pad=f'{pad}%')
    # Otherwise broadcast all arguments to a common number of splits:
    sides, sizes, pads = force_sequence(side, size, pad, equal_size=True)
    children = []
    for where, extent, gap in zip(sides, sizes, pads):
        children.append(divider.append_axes(where, f'{extent}%', f'{gap}%'))
    return children
def add_cross_axes(fig, n, long='col', fill='row', **grid_kwargs):
    """ Adds one subplot per unique pair out of n variables to a new grid.

    Creates n * (n - 1) // 2 subplots, one for each index pair (i, j) with
    i < j, on a gridspec opened on the given figure. The grid shape is taken
    from grid_kwargs where specified, otherwise chosen to be roughly square.
    Grid slots are filled in order until all pairs are placed.

    Parameters
    ----------
    fig : matplotlib figure or subfigure
        Parent to open the gridspec on and to add the subplots to.
    n : int
        Number of variables to pair up. Must be at least 2.
    long : str, optional
        If 'col' and both nrows and ncols are determined automatically,
        ensures that the grid has at least as many columns as rows. Has no
        effect if nrows or ncols is specified. The default is 'col'.
    fill : str, optional
        If 'col', fills the grid column by column and swaps the returned
        row and column variable indices. Any other value fills the grid row
        by row. The default is 'row'.
    **grid_kwargs : dict, optional
        Keyword arguments passed to fig.add_gridspec(). Entries nrows and
        ncols fix the grid shape if given and not None. If only one of them
        is given, the other is chosen to fit all pairs.

    Returns
    -------
    axes : list (n * (n - 1) // 2,) of matplotlib axes
        Newly added subplots, one per variable pair.
    positions : 2D array of ints (n * (n - 1) // 2, 2)
        Grid (row, column) index of each subplot.
    grid : matplotlib gridspec
        Gridspec holding all subplots.
    row_inds : list (n * (n - 1) // 2,) of ints
        First variable index i of each pair (j if fill is 'col').
    col_inds : list (n * (n - 1) // 2,) of ints
        Second variable index j of each pair (i if fill is 'col').

    Raises
    ------
    ValueError
        If n is smaller than 2, or if the specified nrows and ncols provide
        fewer grid slots than there are variable pairs.
    """
    if n < 2:
        raise ValueError(f'Need at least 2 variables to form pairs, got {n}.')
    n_axes = n * (n - 1) // 2
    nrows = grid_kwargs.get('nrows', None)
    ncols = grid_kwargs.get('ncols', None)
    # Resolve grid shape:
    if nrows is None and ncols is None:
        # Roughly square grid, optionally with the long side along columns:
        nrows = int(np.ceil(np.sqrt(n_axes)))
        ncols = int(np.ceil(n_axes / nrows))
        if long == 'col' and ncols < nrows:
            nrows, ncols = ncols, nrows
    elif nrows is None:
        nrows = int(np.ceil(n_axes / ncols))
    elif ncols is None:
        ncols = int(np.ceil(n_axes / nrows))
    elif n_axes > nrows * ncols:
        msg = f'Cannot place {n_axes} subplots in a {nrows}x{ncols} grid.'
        raise ValueError(msg)
    # Variable indices of each unique pair (i < j):
    pairs = [(i, j) for i in range(n) for j in range(i + 1, n)]
    row_inds = [i for i, _ in pairs]
    col_inds = [j for _, j in pairs]
    # Grid slots in filling order:
    if fill == 'col':
        slots = [(r, c) for c in range(ncols) for r in range(nrows)]
        row_inds, col_inds = col_inds, row_inds
    else:
        slots = [(r, c) for r in range(nrows) for c in range(ncols)]
    positions = np.array(slots[:n_axes])
    grid = fig.add_gridspec(**(grid_kwargs | dict(nrows=nrows, ncols=ncols)))
    axes = [fig.add_subplot(grid[i, j]) for i, j in positions]
    return axes, positions, grid, row_inds, col_inds
# GENERAL SETTINGS:
# Species to process, in plotting order (one figure column each):
target_species = [
    'Omocestus_rufipes',
    'Chorthippus_biguttulus',
    'Chorthippus_mollis',
    'Chrysochraon_dispar',
    'Gomphocerippus_rufus',
    'Pseudochorthippus_parallelus',
]
n_species = len(target_species)
# Keywords forwarded to load_data() for the invariance archives:
load_kwargs = {'keywords': ['scales', 'measure', 'thresh']}
# Output file of the figure (set to None to skip saving):
save_path = '../figures/fig_invariance_thresh_lp_species.pdf'
# Drop scale 0 from all loaded invariance curves:
exclude_zero = True
# Shade the mean feature noise floor in the noise feature space plots:
show_floor = True

# SUBSET SETTINGS:
# Relative threshold to display, picked from the available candidates:
thresh_rel = np.array([0.5, 1, 3])[0]
# Kernel specifications to display, selected by row index from all kernels
# (presumably pairs of kernel type and width in seconds - TODO confirm):
kern_specs = np.array([
    [1, 0.008],
    [2, 0.004],
    [3, 0.002],
])[[0, 1, 2]]
n_kernels = len(kern_specs)
# GRAPH SETTINGS:
# Figure size of 32 x 32 cm, converted to inches for matplotlib:
fig_kwargs = {
    'figsize': (32 / 2.54, 32 / 2.54),
}
# Top-level grid holding the subfigures (song, feature curves, spaces):
super_grid_kwargs = {
    'nrows': 3,
    'ncols': 2,
    'wspace': 0,
    'hspace': 0,
    'left': 0,
    'right': 1,
    'bottom': 0,
    'top': 1,
    'height_ratios': [1, 4, 3],
}
# Slot of each subfigure in the top-level grid as (row, column):
# song snippets and feature curves span both columns, the pure and noise
# feature spaces share the bottom row.
subfig_specs = {
    'song': (0, slice(None)),
    'feat': (1, slice(None)),
    'pure': (2, 0),
    'noise': (2, 1),
}
# Feature invariance curves: pure and noise row, one column per species:
feat_grid_kwargs = {
    'nrows': 2,
    'ncols': n_species,
    'wspace': 0.35,
    'hspace': 0.1,
    'left': 0.06,
    'right': 0.985,
    'bottom': 0.1,
    'top': 0.94,
}
# Song snippets, horizontally aligned with the feature curve columns:
song_grid_kwargs = {
    'nrows': 1,
    'ncols': n_species,
    'wspace': feat_grid_kwargs['wspace'],
    'hspace': 0,
    'left': feat_grid_kwargs['left'],
    'right': feat_grid_kwargs['right'],
    'bottom': 0.1,
    'top': 0.8,
}
# Feature space grids (shape left to add_cross_axes() via None):
space_grid_kwargs = {
    'nrows': None,
    'ncols': None,
    'wspace': 0,
    'hspace': 0.4,
    'left': 0.15,
    'right': 0.9,
    'bottom': 0.13,
    'top': 0.95,
}
# Square feature space boxes, shrunk within their grid slot:
anchor_kwargs = {
    'aspect': 'equal',
    'adjustable': 'box',
}
# Kernel inset geometry in axes coordinates of the top-left feature axes:
inset_kwargs = {
    'y0': 0.7,
    'w': 0.3,
    'h': 0.2,
}
# PLOT SETTINGS:
# Font sizes for plain-text labels, TeX labels, panel letters, titles, bars:
fs = dict(
    lab_norm=16,
    lab_tex=20,
    letter=22,
    tit_norm=16,
    tit_tex=20,
    bar=16,
)
species_colors = load_colors('../data/species_colors.npz')
# Shading factor ranges spread over kernels and over scales, respectively:
kernel_shades = [0, 0.75]
scale_shades = [1, 0]
# Line colors of the lower and upper plateau markers on the colorbars:
noise_colors = [(0.5, 0.5, 0.5), (0.7, 0.7, 0.7)]
lw = dict(
    song=0.5,
    feat=3,
    kern=2.5,
    plateau=3,
)
# Drawing order of each species' traces in the feature space plots:
zorder = dict(
    Omocestus_rufipes=2,
    Chorthippus_biguttulus=2.5,
    Chorthippus_mollis=2.4,
    Chrysochraon_dispar=2,
    Gomphocerippus_rufus=2,
    Pseudochorthippus_parallelus=2,
)
space_kwargs = dict(
    s=30,
)
xlabels = dict(
    feat='scale $\\alpha$',
    space=[f'$\\mu_{{f_{i}}}$' for i in range(1, n_kernels + 1)],
)
ylabels = dict(
    feat='$\\mu_{f_i}$',
    space=[f'$\\mu_{{f_{i}}}$' for i in range(1, n_kernels + 1)],
    bar='scale $\\alpha$',
)
xlab_feat_kwargs = dict(
    y=0,
    fontsize=fs['lab_norm'],
    ha='center',
    va='bottom',
)
xlab_space_kwargs = dict(
    y=-0.2,
    fontsize=fs['lab_tex'],
    ha='center',
    va='top',
)
ylab_feat_kwargs = dict(
    x=0,
    fontsize=fs['lab_tex'],
    ha='center',
    va='top',
)
ylab_space_kwargs = dict(
    x=-0.3,
    rotation=0,
    fontsize=fs['lab_tex'],
    ha='right',
    va='center',
)
ylab_cbar_kwargs = dict(
    x=-2,
    fontsize=fs['lab_norm'],
    ha='center',
    va='bottom',
)
# Tick settings: LogLocator subs for the feature curve x-axes, otherwise
# MultipleLocator spacings:
xloc = dict(
    feat=(1,),
    space=0.5,
)
yloc = dict(
    feat=0.5,
    space=0.5
)
# Symlog axis settings ('linthresh' is set per species during execution):
symlog_kwargs = dict(
    linscale=0.5,
)
title_kwargs = dict(
    x=0.5,
    yref=1,
    ha='center',
    va='top',
    fontsize=fs['tit_norm'],
    fontstyle='italic'
)
letter_feat_kwargs = dict(
    xref=0,
    y=1,
    ha='left',
    va='center',
    fontsize=fs['letter'],
)
letter_song_kwargs = dict(
    x=0,
    y=1,
    ha='left',
    va='top',
    fontsize=fs['letter'],
)
letter_space_kwargs = dict(
    x=0,
    yref=1,
    ha='left',
    va='center',
    fontsize=fs['letter'],
)
# Time scale bar of the song snippets (duration in s). The label uses the
# general format so that fractional durations are not truncated by int():
song_bar_time = 1.0
song_bar_kwargs = dict(
    dur=song_bar_time,
    y0=-0.1,
    y1=0,
    xshift=0,
    color='k',
    lw=0,
    clip_on=False,
    text_pos=(1.25, 0.5),
    text_str=f'${song_bar_time:g}\\,\\text{{s}}$',
    text_kwargs=dict(
        fontsize=fs['bar'],
        ha='left',
        va='center',
    )
)
# Time scale bar of the kernel insets (duration in s, labelled in ms). The
# general format absorbs float round-off (e.g. 0.029 * 1000) that int()
# would truncate to the next lower integer:
kern_bar_time = 0.05
kern_bar_kwargs = dict(
    dur=kern_bar_time,
    y0=0.1,
    y1=0.2,
    color='k',
    lw=0,
    clip_on=False,
    text_pos=(0.6, -1),
    text_str=f'${kern_bar_time * 1000:g}\\,\\text{{ms}}$',
    text_kwargs=dict(
        fontsize=fs['bar'],
        ha='center',
        va='top',
    )
)
# Rectangle style of the feature noise floor:
floor_kwargs = dict(
    fc=(0.9, 0.9, 0.9),
    ec='none',
    lw=0,
    zorder=0.5,
)
# Keyword arguments forwarded to get_saturation() for the plateau markers:
plateau_settings = dict(
    low=0.05,
    high=0.95,
    first=True,
    last=True,
    condense='norm',
)
# EXECUTION:
# Prepare overall graph:
fig = plt.figure(**fig_kwargs)
super_grid = fig.add_gridspec(**super_grid_kwargs)

# Prepare song axes (one per species, bottom and left axis hidden):
song_subfig = fig.add_subfigure(super_grid[subfig_specs['song']])
song_grid = song_subfig.add_gridspec(**song_grid_kwargs)
song_axes = np.zeros((n_species,), dtype=object)
for i in range(n_species):
    ax = song_subfig.add_subplot(song_grid[i])
    hide_axis(ax, 'bottom')
    hide_axis(ax, 'left')
    song_axes[i] = ax
letter_subplot(song_subfig, 'a', **letter_song_kwargs)

# Prepare feature invariance axes (rows: pure and noise, columns: species):
feat_subfig = fig.add_subfigure(super_grid[subfig_specs['feat']])
feat_grid = feat_subfig.add_gridspec(**feat_grid_kwargs)
feat_axes = np.zeros((feat_grid_kwargs['nrows'], n_species), dtype=object)
for i, j in product(range(feat_grid_kwargs['nrows']), range(n_species)):
    ax = feat_subfig.add_subplot(feat_grid[i, j])
    ax.yaxis.set_major_locator(plt.MultipleLocator(yloc['feat']))
    ax.set_ylim(0, 1)
    if j == 0:
        ylabel(ax, ylabels['feat'], transform=feat_subfig, **ylab_feat_kwargs)
    feat_axes[i, j] = ax
# Hide bottom ticks of the top row and left ticks of all but the first column:
for ax in feat_axes[0, :]:
    hide_ticks(ax, side='bottom')
for ax in feat_axes[:, 1:].ravel():
    hide_ticks(ax, side='left')
super_xlabel(xlabels['feat'], feat_subfig, feat_axes[-1, 0], feat_axes[-1, -1],
             **xlab_feat_kwargs)
letter_subplots(feat_axes[:, 0], labels='bc', ref=feat_subfig,
                **letter_feat_kwargs)

# Prepare kernel insets, centred in equal horizontal slots of the top-left
# feature axes:
x0 = np.linspace(0, 1, n_kernels + 1)[:-1] + 1 / n_kernels / 2
x0 -= inset_kwargs['w'] / 2
insets = []
for i in range(n_kernels):
    bounds = [x0[i], inset_kwargs['y0'], inset_kwargs['w'], inset_kwargs['h']]
    inset = feat_axes[0, 0].inset_axes(bounds)
    inset.set_title(rf'$k_{{{i+1}}}$', fontsize=fs['tit_tex'])
    inset.axis('off')
    insets.append(inset)

# Prepare pure feature space axes (one per pair of kernels):
pure_subfig = fig.add_subfigure(super_grid[subfig_specs['pure']])
outputs = add_cross_axes(pure_subfig, n_kernels, **space_grid_kwargs)
pure_axes, space_pos, space_grid, row_inds, col_inds = outputs
letter_subplot(pure_subfig, 'd', ref=pure_axes[0], **letter_space_kwargs)
# Prepare noise feature space axes (same layout as the pure ones):
noise_subfig = fig.add_subfigure(super_grid[subfig_specs['noise']])
noise_axes = add_cross_axes(noise_subfig, n_kernels, **space_grid_kwargs)[0]
letter_subplot(noise_subfig, 'e', ref=noise_axes[0], **letter_space_kwargs)

# Format feature space axes:
# Largest grid row and column index, floored at 1 so that single-row or
# single-column layouts do not divide by zero when computing anchors:
pos_span = np.maximum(space_pos.max(axis=0), 1)
for ind, pair_axes in enumerate(zip(pure_axes, noise_axes)):
    irow, icol = row_inds[ind], col_inds[ind]
    # Anchor each equal-aspect box towards its position in the grid
    # (first row at the top, first column at the left):
    anchor = space_pos[ind] / pos_span
    anchor[0] = 1 - anchor[0]
    for ax in pair_axes:
        ax.set_xlim(0, 1)
        ax.set_ylim(0, 1)
        ax.xaxis.set_major_locator(plt.MultipleLocator(xloc['space']))
        ax.yaxis.set_major_locator(plt.MultipleLocator(yloc['space']))
        ax.set_aspect(anchor=tuple(anchor[::-1]), **anchor_kwargs)
        xlabel(ax, xlabels['space'][icol], **xlab_space_kwargs)
        ylabel(ax, ylabels['space'][irow], **ylab_space_kwargs)

# Determine area to place colorbars: the column of the rightmost and the row
# of the downmost feature space axes (for 3 kernels in a 2x2 grid, this is
# the free bottom-right slot):
rightmost = pure_axes[np.argmax(space_pos[:, 1])].get_position()
downmost = pure_axes[np.argmax(space_pos[:, 0])].get_position()
bar_bounds = [rightmost.x0, downmost.y0, rightmost.width, downmost.height]
# Prepare pure colorbars (one per species, side by side):
pure_bars = [pure_subfig.add_axes(bar_bounds)]
pure_bars.extend(split_subplot(pure_bars[0], side=['right'] * (n_species - 1),
                               size=100, pad=0))
# Prepare noise colorbars (one per species, side by side):
noise_bars = [noise_subfig.add_axes(bar_bounds)]
noise_bars.extend(split_subplot(noise_bars[0], side=['right'] * (n_species - 1),
                                size=100, pad=0))
# Prepare kernel-specific color shading:
kern_factors = np.linspace(*kernel_shades, n_kernels)
kern_colors_bw = shade_colors((0., 0., 0.), kern_factors)
# Plot results per species:
# Per-species, per-kernel extremes of the invariance curves (filled in loop):
min_noise_feat = np.zeros((n_species, n_kernels), dtype=float)
max_pure_feat = np.zeros((n_species, n_kernels), dtype=float)
max_noise_feat = np.zeros((n_species, n_kernels), dtype=float)
# Scatter handles per feature space axes, collected for later reordering:
pure_space_handles = {ax: [] for ax in pure_axes}
noise_space_handles = {ax: [] for ax in noise_axes}
for i, species in enumerate(target_species):
    print(f'Processing {species}')
    # Fetch species-specific recording file:
    song_path = search_files(species, dir='../data/processed/')[0]
    # Load song data:
    song_data, _ = load_data(song_path, files='filt')
    song, rate = song_data['filt'], song_data['filt_rate']
    # Plot species snippet:
    song_ax = song_axes[i]
    time = np.arange(song.shape[0]) / rate
    plot_line(song_ax, time, song, ypad=0.05, c='k', lw=lw['song'])
    title_subplot(song_ax, shorten_species(species), ref=song_subfig,
                  **title_kwargs)
    time_bar(song_ax, **song_bar_kwargs)
    # Drop the bar label after the first species (presumably text_pos=None
    # suppresses it in time_bar):
    song_bar_kwargs['text_pos'] = None

    # Fetch species-specific invariance files:
    pure_path = search_files(species, incl='pure', dir='../data/inv/thresh_lp/')[0]
    noise_path = search_files(species, incl='noise', dir='../data/inv/thresh_lp/')[0]
    # Load invariance data:
    pure_data, config = load_data(pure_path, **load_kwargs)
    noise_data, _ = load_data(noise_path, **load_kwargs)
    scales = pure_data['scales']
    # Reduce to kernel subset and a single threshold. The threshold is
    # matched with a tolerance to be robust against float round-off in the
    # stored values. Indices and scales are taken from the pure data and
    # assumed to apply to the noise data as well:
    thresh_ind = np.nonzero(np.isclose(pure_data['thresh_rel'], thresh_rel))[0][0]
    kern_inds = find_kern_specs(config['k_specs'], kerns=kern_specs)
    config['k_specs'] = config['k_specs'][kern_inds]
    config['kernels'] = config['kernels'][:, kern_inds]
    pure_measure = pure_data['measure_feat'][:, kern_inds, thresh_ind]
    noise_measure = noise_data['measure_feat'][:, kern_inds, thresh_ind]
    if exclude_zero:
        # Reduce to nonzero scales:
        nonzero_inds = scales > 0
        scales = scales[nonzero_inds]
        pure_measure = pure_measure[nonzero_inds, :]
        noise_measure = noise_measure[nonzero_inds, :]

    # Prepare species-specific colors:
    base_color = species_colors[species]
    kern_colors = shade_colors(base_color, kern_factors)
    scale_factors = np.linspace(*scale_shades, scales.size)
    scale_cmap = create_listed_cmap(shade_colors(base_color, scale_factors))

    # Plot feature invariance curves on symlog x-axes whose linear range ends
    # at the first positive scale (the smallest, if scales ascend):
    symlog_kwargs['linthresh'] = scales[scales > 0][0]
    pure_ax, noise_ax = feat_axes[:, i]
    for ax in (pure_ax, noise_ax):
        ax.set_xscale('symlog', **symlog_kwargs)
        ax.xaxis.set_major_locator(plt.LogLocator(base=10, subs=xloc['feat']))
    for ax, measure in ((pure_ax, pure_measure), (noise_ax, noise_measure)):
        handles = ax.plot(scales, measure, lw=lw['feat'])
        for handle, color in zip(handles, kern_colors):
            handle.set_color(color)
    if i == 0:
        # Indicate kernel waveforms (taken from the first species' config):
        ylims = ylimits(config['kernels'], pad=0.05)
        xlims = (config['k_times'][0], config['k_times'][-1])
        for kern, inset, c in zip(config['kernels'].T, insets, kern_colors_bw):
            inset.plot(config['k_times'], kern, c=c, lw=lw['kern'])
            inset.set_xlim(xlims)
            inset.set_ylim(ylims)
        time_bar(insets[0], **kern_bar_kwargs)

    # Plot invariance curves in feature space (one kernel pair per axes):
    norm = LogNorm(vmin=scales[scales > 0][0], vmax=scales[-1])
    for ind, (pure_space_ax, noise_space_ax) in enumerate(zip(pure_axes, noise_axes)):
        irow, icol = row_inds[ind], col_inds[ind]
        pure_handle = pure_space_ax.scatter(
            pure_measure[:, icol], pure_measure[:, irow], c=scales,
            cmap=scale_cmap, norm=norm, zorder=zorder[species], **space_kwargs)
        pure_space_handles[pure_space_ax].append(pure_handle)
        noise_handle = noise_space_ax.scatter(
            noise_measure[:, icol], noise_measure[:, irow], c=scales,
            cmap=scale_cmap, norm=norm, zorder=zorder[species], **space_kwargs)
        noise_space_handles[noise_space_ax].append(noise_handle)

    # Indicate scale color code and saturation plateaus on the colorbars of
    # the pure and the noise subfigure (all scatters of a species share the
    # same colormap and norm, so the last handle is representative):
    bar_specs = (
        (pure_subfig, pure_bars[i], pure_handle, pure_measure),
        (noise_subfig, noise_bars[i], noise_handle, noise_measure),
    )
    for subfig, bar_ax, handle, measure in bar_specs:
        subfig.colorbar(handle, cax=bar_ax)
        bar_ax.set_yscale('symlog', **symlog_kwargs)
        hide_ticks(bar_ax, 'right', ticks=False)
        if i == 0:
            # Only the leftmost colorbar carries ticks and axis label:
            bar_ax.tick_params(axis='y', which='both', left=True, labelleft=True)
            ylabel(bar_ax, ylabels['bar'], **ylab_cbar_kwargs)
        low_ind, high_ind = get_saturation(measure, **plateau_settings)
        bar_ax.axhline(scales[low_ind], c=noise_colors[0], lw=lw['plateau'])
        bar_ax.axhline(scales[high_ind], c=noise_colors[1], lw=lw['plateau'])

    # Log per-kernel minimum and maximum of the invariance curves:
    min_noise_feat[i, :] = noise_measure.min(axis=0)
    max_pure_feat[i, :] = pure_measure.max(axis=0)
    max_noise_feat[i, :] = noise_measure.max(axis=0)
# Reorder the feature space traces of each kernel pair by the norm of the
# species' per-kernel maxima in that pair (see reorder_by_norm):
for pure_space_ax, noise_space_ax, irow, icol in zip(pure_axes, noise_axes,
                                                     row_inds, col_inds):
    pair_cols = [icol, irow]
    reorder_by_norm(pure_space_handles[pure_space_ax], max_pure_feat[:, pair_cols])
    reorder_by_norm(noise_space_handles[noise_space_ax], max_noise_feat[:, pair_cols])

if show_floor:
    # Shade the feature noise floor, i.e. the per-kernel minimum of the noise
    # curves averaged across species, as a rectangle from the origin:
    noise_feat = min_noise_feat.mean(axis=0)
    for ax, irow, icol in zip(noise_axes, row_inds, col_inds):
        floor_patch = plt.Rectangle((0, 0), noise_feat[icol], noise_feat[irow],
                                    **floor_kwargs)
        ax.add_patch(floor_patch)

# Save (unless disabled), display, and drop into an IPython shell:
if save_path is not None:
    fig.savefig(save_path)
plt.show()
print('Done.')
embed()