Cross-checked and polished remainders of fig_invariance_thresh_lp_species.pdf.

Added misc_functions.py for anything not plot-related.
This commit is contained in:
j-hartling
2026-03-31 09:36:55 +02:00
parent 411d50ffcf
commit 298969a067
12 changed files with 251 additions and 165 deletions

View File

@@ -7,11 +7,11 @@ from itertools import product
from thunderhopper.filetools import search_files
from thunderhopper.modeltools import load_data
from thunderhopper.filtertools import find_kern_specs
from misc_functions import shorten_species, get_saturation
from color_functions import load_colors, shade_colors, create_listed_cmap
from plot_functions import hide_axis, title_subplot, ylimits, xlabel, ylabel, super_ylabel,\
plot_line, plot_barcode, strip_zeros, time_bar,\
letter_subplot, letter_subplots, hide_ticks,\
super_xlabel, super_ylabel, assign_colors
from plot_functions import hide_axis, title_subplot, ylimits, xlabel, ylabel,\
plot_line, time_bar,letter_subplot, letter_subplots,\
hide_ticks, super_xlabel, reorder_by_norm
from IPython import embed
def force_sequence(*vars, skip_None=False, equal_size=False):
@@ -126,10 +126,6 @@ def split_subplot(ax, side='right', size=10, pad=10):
inputs = zip(*force_sequence(side, size, pad, equal_size=True))
return [div.append_axes(s, f'{n}%', f'{p}%') for s, n, p in inputs]
def shorten_species(name):
genus, species = name.split('_')
return genus[0] + '. ' + species
def add_cross_axes(fig, n, long='col', fill='row', **grid_kwargs):
n_axes = n * (n - 1) // 2
nrows = grid_kwargs.get('nrows', None)
@@ -179,7 +175,7 @@ load_kwargs = dict(
)
save_path = '../figures/fig_invariance_thresh_lp_species.pdf'
exclude_zero = True
show_noise = True
show_floor = True
# SUBSET SETTINGS:
thresh_rel = np.array([0.5, 1, 3])[0]
@@ -214,7 +210,7 @@ subfig_specs = dict(
feat_grid_kwargs = dict(
nrows=2,
ncols=n_species,
wspace=0.25,
wspace=0.35,
hspace=0.1,
left=0.06,
right=0.985,
@@ -234,17 +230,16 @@ song_grid_kwargs = dict(
space_grid_kwargs = dict(
nrows=None,
ncols=None,
wspace=0.1,
hspace=0.3,
left=0.05,
right=1,
bottom=0.1,
wspace=0,
hspace=0.4,
left=0.15,
right=0.9,
bottom=0.13,
top=0.95
)
anchor_kwargs = dict(
aspect='equal',
adjustable='box',
anchor=(0.5, 0.5)
)
inset_kwargs = dict(
y0=0.7,
@@ -264,10 +259,12 @@ fs = dict(
species_colors = load_colors('../data/species_colors.npz')
kernel_shades = [0, 0.75]
scale_shades = [1, 0]
noise_colors = [(0.5, 0.5, 0.5), (0.7, 0.7, 0.7)]
lw = dict(
song=0.5,
feat=3,
kern=3
kern=2.5,
plateau=3,
)
zorder = dict(
Omocestus_rufipes=2,
@@ -285,7 +282,7 @@ xlabels = dict(
space=[f'$\\mu_{{f_{i}}}$' for i in range(1, n_kernels + 1)],
)
ylabels = dict(
feat='$\\mu_f$',
feat='$\\mu_{f_i}$',
space=[f'$\\mu_{{f_{i}}}$' for i in range(1, n_kernels + 1)],
bar='scale $\\alpha$',
)
@@ -296,10 +293,10 @@ xlab_feat_kwargs = dict(
va='bottom',
)
xlab_space_kwargs = dict(
y=-0.3,
y=-0.2,
fontsize=fs['lab_tex'],
ha='center',
va='bottom',
va='top',
)
ylab_feat_kwargs = dict(
x=0,
@@ -308,13 +305,14 @@ ylab_feat_kwargs = dict(
va='top',
)
ylab_space_kwargs = dict(
x=-0.2,
x=-0.3,
rotation=0,
fontsize=fs['lab_tex'],
ha='center',
va='bottom',
ha='right',
va='center',
)
ylab_cbar_kwargs = dict(
x=1,
x=-2,
fontsize=fs['lab_norm'],
ha='center',
va='bottom',
@@ -368,30 +366,43 @@ song_bar_kwargs = dict(
color='k',
lw=0,
clip_on=False,
# text_pos=(-0.1, 0.5),
# text_str=f'${int(1000 * song_bar_time)}\\,\\text{{ms}}$',
# text_kwargs=dict(
# fontsize=fs['bar'],
# ha='right',
# va='center',
# )
text_pos=(1.25, 0.5),
text_str=f'${int(song_bar_time)}\\,\\text{{s}}$',
text_kwargs=dict(
fontsize=fs['bar'],
ha='left',
va='center',
)
)
kern_bar_time = 0.05
kern_bar_kwargs = dict(
dur=kern_bar_time,
y0=inset_kwargs['y0'] - 0.03,
y1=inset_kwargs['y0'],
y0=0.1,
y1=0.2,
color='k',
lw=0
lw=0,
clip_on=False,
text_pos=(0.6, -1),
text_str=f'${int(kern_bar_time * 1000)}\\,\\text{{ms}}$',
text_kwargs=dict(
fontsize=fs['bar'],
ha='center',
va='top',
)
)
noise_kwargs = dict(
floor_kwargs = dict(
fc=(0.9, 0.9, 0.9),
ec='none',
lw=0,
zorder=0.5,
)
low_rel_thresh = 0.05
high_rel_thresh = 0.95
plateau_settings = dict(
low=0.05,
high=0.95,
first=True,
last=True,
condense='norm',
)
# EXECUTION:
@@ -450,14 +461,17 @@ letter_subplot(noise_subfig, 'e', ref=noise_axes[0], **letter_space_kwargs)
# Format feature space axes:
for ind, axes in enumerate(zip(pure_axes, noise_axes)):
irow, icol = row_inds[ind], col_inds[ind]
for ax in axes:
ax.set_xlim(0, 1)
ax.set_ylim(0, 1)
ax.xaxis.set_major_locator(plt.MultipleLocator(xloc['space']))
ax.yaxis.set_major_locator(plt.MultipleLocator(yloc['space']))
ax.set_aspect(**anchor_kwargs)
xlabel(ax, xlabels['space'][col_inds[ind]], **xlab_space_kwargs)
ylabel(ax, ylabels['space'][row_inds[ind]], **ylab_space_kwargs)
anchor = space_pos[ind] / space_pos.max(axis=0)
anchor[0] = 1 - anchor[0]
ax.set_aspect(anchor=tuple(anchor[::-1]), **anchor_kwargs)
xlabel(ax, xlabels['space'][icol], **xlab_space_kwargs)
ylabel(ax, ylabels['space'][irow], **ylab_space_kwargs)
# Determine area to place colorbars:
rightmost = pure_axes[np.argmax(space_pos[:, 1])].get_position()
@@ -479,7 +493,11 @@ kern_factors = np.linspace(*kernel_shades, n_kernels)
kern_colors_bw = shade_colors((0., 0., 0.), kern_factors)
# Plot results per species:
noise_feat = np.zeros((n_species, n_kernels), dtype=float)
min_noise_feat = np.zeros((n_species, n_kernels), dtype=float)
max_pure_feat = np.zeros((n_species, n_kernels), dtype=float)
max_noise_feat = np.zeros((n_species, n_kernels), dtype=float)
pure_space_handles = {ax: [] for ax in pure_axes}
noise_space_handles = {ax: [] for ax in noise_axes}
for i, species in enumerate(target_species):
print(f'Processing {species}')
@@ -496,6 +514,7 @@ for i, species in enumerate(target_species):
plot_line(song_ax, time, song, ypad=0.05, c='k', lw=lw['song'])
title_subplot(song_ax, shorten_species(species), ref=song_subfig, **title_kwargs)
time_bar(song_ax, **song_bar_kwargs)
song_bar_kwargs['text_pos'] = None
# Fetch species-specific invariance files:
pure_path = search_files(species, incl='pure', dir='../data/inv/thresh_lp/')[0]
@@ -545,7 +564,8 @@ for i, species in enumerate(target_species):
inset.plot(config['k_times'], kern, c=c, lw=lw['kern'])
inset.set_xlim(xlims)
inset.set_ylim(ylims)
time_bar(insets[0], parent=feat_axes[0, 0], **kern_bar_kwargs)
# time_bar(insets[0], parent=feat_axes[0, 0], **kern_bar_kwargs)
time_bar(insets[0], **kern_bar_kwargs)
# Plot invariance curves in feature space:
norm = LogNorm(vmin=scales[scales > 0][0], vmax=scales[-1])
@@ -554,60 +574,56 @@ for i, species in enumerate(target_species):
pure_handle = pure_ax.scatter(pure_measure[:, icol], pure_measure[:, irow],
c=scales, cmap=scale_cmap, norm=norm,
zorder=zorder[species], **space_kwargs)
pure_space_handles[pure_ax].append(pure_handle)
noise_handle = noise_ax.scatter(noise_measure[:, icol], noise_measure[:, irow],
c=scales, cmap=scale_cmap, norm=norm,
zorder=zorder[species], **space_kwargs)
noise_space_handles[noise_ax].append(noise_handle)
# Indicate scale color code in pure subfigure:
pure_subfig.colorbar(pure_handle, cax=pure_bars[i])
pure_bars[i].set_yscale('symlog', **symlog_kwargs)
if i < n_species - 1:
hide_ticks(pure_bars[i], 'right', ticks=False)
else:
ylabel(pure_bars[i], ylabels['bar'], transform=pure_subfig.transSubfigure, **ylab_cbar_kwargs)
hide_ticks(pure_bars[i], 'right', ticks=False)
if i == 0:
pure_bars[0].tick_params(axis='y', which='both', left=True, labelleft=True)
ylabel(pure_bars[0], ylabels['bar'], **ylab_cbar_kwargs)
# Indicate scale color code in noise subfigure:
noise_subfig.colorbar(noise_handle, cax=noise_bars[i])
noise_bars[i].set_yscale('symlog', **symlog_kwargs)
if i < n_species - 1:
hide_ticks(noise_bars[i], 'right', ticks=False)
else:
ylabel(noise_bars[i], ylabels['bar'], transform=noise_subfig.transSubfigure, **ylab_cbar_kwargs)
hide_ticks(noise_bars[i], 'right', ticks=False)
if i == 0:
noise_bars[0].tick_params(axis='y', which='both', left=True, labelleft=True)
ylabel(noise_bars[0], ylabels['bar'], **ylab_cbar_kwargs)
# Log feature noise floor:
noise_feat[i, :] = noise_measure.min(axis=0)
# Indicate plateaus of pure invariance curves:
low_ind, high_ind = get_saturation(pure_measure, **plateau_settings)
pure_bars[i].axhline(scales[low_ind], c=noise_colors[0], lw=lw['plateau'])
pure_bars[i].axhline(scales[high_ind], c=noise_colors[1], lw=lw['plateau'])
# Indicate low and high plateaus:
min_feat = pure_measure.min(axis=0)
span_feat = pure_measure.max(axis=0) - min_feat
# Indicate plateaus of noise invariance curves:
low_ind, high_ind = get_saturation(noise_measure, **plateau_settings)
noise_bars[i].axhline(scales[low_ind], c=noise_colors[0], lw=lw['plateau'])
noise_bars[i].axhline(scales[high_ind], c=noise_colors[1], lw=lw['plateau'])
# Log start and end of invariance curve:
min_noise_feat[i, :] = noise_measure.min(axis=0)
max_pure_feat[i, :] = pure_measure.max(axis=0)
max_noise_feat[i, :] = noise_measure.max(axis=0)
low_thresh = min_feat + low_rel_thresh * span_feat
low_ind = np.nonzero((pure_measure >= low_thresh).all(axis=1))[0][0]
pure_bars[i].axhline(scales[low_ind], c='k', lw=3)
# Sort feature space traces by distance of endpoint to origin:
for ind, (pure_ax, noise_ax) in enumerate(zip(pure_axes, noise_axes)):
irow, icol = row_inds[ind], col_inds[ind]
reorder_by_norm(pure_space_handles[pure_ax], max_pure_feat[:, [icol, irow]])
reorder_by_norm(noise_space_handles[noise_ax], max_noise_feat[:, [icol, irow]])
high_thresh = min_feat + high_rel_thresh * span_feat
high_ind = np.nonzero((pure_measure >= high_thresh).any(axis=1))[0][0]
pure_bars[i].axhline(scales[high_ind], c='w', lw=3)
# Indicate low and high plateaus:
min_feat = noise_measure.min(axis=0)
span_feat = noise_measure.max(axis=0) - min_feat
low_thresh = min_feat + low_rel_thresh * span_feat
low_ind = np.nonzero((noise_measure >= low_thresh).all(axis=1))[0][0]
noise_bars[i].axhline(scales[low_ind], c='k', lw=3)
high_thresh = min_feat + high_rel_thresh * span_feat
high_ind = np.nonzero((noise_measure >= high_thresh).any(axis=1))[0][0]
noise_bars[i].axhline(scales[high_ind], c='w', lw=3)
if show_noise:
if show_floor:
# Indicate feature noise floor:
noise_feat = noise_feat.mean(axis=0)
noise_feat = min_noise_feat.mean(axis=0)
for ind, ax in enumerate(noise_axes):
irow, icol = row_inds[ind], col_inds[ind]
ax.add_patch(plt.Rectangle((0, 0), noise_feat[icol], noise_feat[irow], **noise_kwargs))
ax.add_patch(plt.Rectangle((0, 0), noise_feat[icol], noise_feat[irow], **floor_kwargs))
if save_path is not None:
fig.savefig(save_path)