Files
paper_2025/python/fig_invariance_thresh-lp_single.py
j-hartling e70d100655 Added loads of units in nearly all graphs.
Overhauled fig_invariance_full.pdf.
Added some legends, somewhere.
2026-04-28 19:43:05 +02:00

479 lines
15 KiB
Python

import plotstyle_plt
import numpy as np
import matplotlib.pyplot as plt
from itertools import product
from thunderhopper.filetools import search_files
from thunderhopper.modeltools import load_data
from thunderhopper.filtertools import find_kern_specs
from misc_functions import get_saturation
from color_functions import load_colors, shade_colors
from plot_functions import shift_subplot, hide_axis, ylimits, xlabel, ylabel,\
super_ylabel, plot_line, plot_barcode, strip_zeros,\
time_bar, letter_subplot, letter_subplots, title_subplot,\
set_clip_box
from IPython import embed
def add_snip_axes(fig, grid_kwargs, col_shift=None):
    """Create a full grid of snippet axes on a figure or subfigure.

    A gridspec is built from ``grid_kwargs`` and one subplot is added per
    cell. Subplots in the first column are shifted horizontally via
    ``shift_subplot``. The left axis is hidden for every column from the
    third onward, and the bottom axis is hidden for all subplots.

    Parameters
    ----------
    fig : Figure or SubFigure
        Container providing ``add_gridspec`` and ``add_subplot``.
    grid_kwargs : dict
        Keyword arguments forwarded to ``fig.add_gridspec``.
    col_shift : float, optional
        Horizontal shift (``dx``) applied to first-column subplots.
        ``None`` (default) uses the module-level ``snip_col_shift``, which
        is what the function always did before this parameter existed.

    Returns
    -------
    axes : ndarray of object, shape (nrows, ncols)
        The created axes, indexed as ``axes[row, col]``.
    """
    if col_shift is None:
        col_shift = snip_col_shift
    grid = fig.add_gridspec(**grid_kwargs)
    axes = np.zeros((grid.nrows, grid.ncols), dtype=object)
    for i, j in product(range(grid.nrows), range(grid.ncols)):
        axes[i, j] = fig.add_subplot(grid[i, j])
        if j == 0:
            shift_subplot(axes[i, j], dx=col_shift)
    # Plain loops: these calls are made for their side effects only.
    for ax in axes[:, 2:].flat:
        hide_axis(ax, 'left')
    for ax in axes.flat:
        hide_axis(ax, 'bottom')
    return axes
def plot_snippets(axes, time, snippets, ymin=None, ymax=None, ypad=0.05,
                  thresh=None, fill_kwargs=None, **kwargs):
    """Plot each column of ``snippets`` as a line on its own axis.

    Common y-limits are computed once by ``ylimits`` over the whole
    ``snippets`` array and passed to ``plot_line`` for every axis. If
    ``thresh`` is given, the area between ``thresh`` and the snippet is
    filled wherever the snippet exceeds it.

    Parameters
    ----------
    axes : iterable of Axes
        Target axes, one per column of ``snippets``.
    time : 1D array
        Shared x-values for all snippets.
    snippets : 2D array
        One column per axis (iterated via ``snippets.T``).
    ymin, ymax : float, optional
        Forwarded to ``ylimits`` as ``minval``/``maxval``.
    ypad : float
        Forwarded to ``ylimits`` as ``pad``.
    thresh : float, optional
        Threshold for the filled region; no fill if ``None``.
    fill_kwargs : dict, optional
        Passed to ``Axes.fill_between``. ``None`` means no extra options.
        Replaces the former shared mutable ``{}`` default.
    **kwargs
        Passed to ``plot_line``.

    Returns
    -------
    list
        Whatever ``plot_line`` returned for each axis, in axis order.
    """
    if fill_kwargs is None:
        fill_kwargs = {}
    ymin, ymax = ylimits(snippets, minval=ymin, maxval=ymax, pad=ypad)
    handles = []
    for ax, snippet in zip(axes, snippets.T):
        handles.append(plot_line(ax, time, snippet, ymin=ymin, ymax=ymax, **kwargs))
        if thresh is not None:
            ax.fill_between(time, thresh, snippet, where=(snippet > thresh), **fill_kwargs)
    return handles
def plot_bi_snippets(axes, time, binary, **kwargs):
    """Draw each column of a binary array as a barcode on its own axis.

    Parameters
    ----------
    axes : iterable of Axes
        Target axes, one per column of ``binary``.
    time : 1D array
        Shared x-values.
    binary : 2D array
        One column per axis. Each column is passed to ``plot_barcode`` as
        an ``(n, 1)`` array.
    **kwargs
        Passed to ``plot_barcode``.

    Returns
    -------
    None
    """
    # Distinct loop name: the original loop variable re-bound the parameter.
    for ax, column in zip(axes, binary.T):
        plot_barcode(ax, time, column[:, None], **kwargs)
    return None
def side_distributions(axes, snippets, inset_bounds, thresh, nbins=50,
                       limits=None, fill_kwargs=None, **kwargs):
    """Attach a sideways histogram of each snippet column to its axis.

    For each (axis, column) pair, an inset is created at ``inset_bounds``
    on the host axis. The column's probability density
    (``np.histogram(..., density=True)``) is drawn with density along the
    inset's x-axis and value along its y-axis. The part above ``thresh`` is
    filled. Each inset copies the host axis's current y-limits and has its
    own axis decorations switched off.

    Parameters
    ----------
    axes : iterable of Axes
        Host axes, one per column of ``snippets``.
    snippets : 2D array
        One column per host axis (iterated via ``snippets.T``).
    inset_bounds : sequence of 4 floats
        ``[x0, y0, width, height]`` forwarded to ``Axes.inset_axes``.
    thresh : float
        Values (bin centers) above this are filled.
    nbins : int
        Number of histogram bins.
    limits : sequence of 2 floats, optional
        Histogram range. Default: the data range, with each end widened
        outward by 10 % of its own magnitude.
    fill_kwargs : dict, optional
        Passed to ``fill_betweenx``. ``None`` means no extra options.
    **kwargs
        Passed to ``plot`` for the density line.

    Returns
    -------
    list of Axes
        The created inset axes.
    """
    if fill_kwargs is None:
        fill_kwargs = {}
    if limits is None:
        # Widen each end *outward*. The former ``[min, max] * 1.1`` matches
        # this (up to rounding) when min < 0 < max. However, it pulled a
        # positive minimum or a negative maximum inward, so np.histogram
        # silently dropped the most extreme samples.
        lo, hi = snippets.min(), snippets.max()
        limits = (lo - 0.1 * abs(lo), hi + 0.1 * abs(hi))
    edges = np.linspace(*limits, nbins + 1)
    centers = edges[:-1] + (edges[1] - edges[0]) / 2
    insets = []
    for ax, snippet in zip(axes, snippets.T):
        pdf, _ = np.histogram(snippet, edges, density=True)
        inset = ax.inset_axes(inset_bounds)
        handle = inset.plot(pdf, centers, **kwargs)[0]
        set_clip_box(handle, inset, bounds=[[-0.05, 0], [1.05, 1]])
        handle = inset.fill_betweenx(centers, pdf.min(), pdf,
                                     where=(centers > thresh), **fill_kwargs)
        set_clip_box(handle, inset, bounds=[[-0.05, 0], [1.05, 1]])
        inset.set_xlim(0, pdf.max())
        inset.set_ylim(ax.get_ylim())
        inset.axis('off')
        insets.append(inset)
    return insets
# GENERAL SETTINGS:
example_file = 'Omocestus_rufipes_DJN_32-40s724ms-48s779ms'
data_dir = '../data/inv/thresh_lp/'
# Locate the noise-variant data file of the example recording. Fail with a
# clear message instead of a bare IndexError when nothing matches.
data_matches = search_files(example_file, incl='noise', dir=data_dir)
if len(data_matches) == 0:
    raise FileNotFoundError(
        f"No file matching {example_file!r} (incl='noise') found in {data_dir!r}")
data_path = data_matches[0]
# Processing stages whose files are loaded (see load_kwargs) and plotted.
stages = ['conv', 'bi', 'feat']
# Forwarded to load_data for both the noise and the pure data set.
load_kwargs = dict(
    files=stages,
    keywords=['scales', 'snip', 'measure', 'thresh']
)
save_path = '../figures/fig_invariance_thresh_lp_single.pdf'
# If True, the zero scale is dropped from the scale-resolved measures.
exclude_zero = True
# GRAPH SETTINGS:
# Figure size: 32 x 32 cm, converted to inches.
fig_kwargs = {'figsize': (32/2.54, 32/2.54)}
# Outer grid spanning the whole figure. 'nrows' stays None here and is
# filled in once the number of thresholds is known.
super_grid_kwargs = {
    'nrows': None,
    'ncols': 3,
    'wspace': 0,
    'hspace': 0,
    'left': 0,
    'right': 1,
    'bottom': 0,
    'top': 1,
}
# Outer-grid rows occupied by the input panel and by each threshold panel.
input_rows, snip_rows = 1, 2
# Outer-grid slots of the subfigures. The 'snip' row bounds are an array,
# so they can be offset per threshold before being turned into a slice.
subfig_specs = {
    'input': (slice(input_rows), slice(-1)),
    'snip': [np.array([input_rows, input_rows + snip_rows]), slice(-1)],
    'big': (slice(None), -1),
}
# Horizontal shift applied to the first column of snippet axes.
snip_col_shift = -0.07
# Inner grid of each threshold subfigure: one row per stage. 'ncols' is
# set later from the number of example scales.
snip_grid_kwargs = {
    'nrows': len(stages),
    'ncols': None,
    'wspace': 0.3,
    'hspace': 0,
    'left': 0.23 - snip_col_shift,
    'right': 0.93,
    'bottom': 0.15,
    'top': 0.95,
    'height_ratios': [4, 1, 2],
}
# Inner grid of the input subfigure, horizontally aligned with the
# snippet grids.
input_grid_kwargs = {
    'nrows': 1,
    'ncols': None,
    'wspace': snip_grid_kwargs['wspace'],
    'hspace': 0,
    'left': snip_grid_kwargs['left'],
    'right': snip_grid_kwargs['right'],
    'bottom': 0.15,
    'top': 0.75,
}
# Inner grid of the analysis subfigure: two stacked panels.
big_grid_kwargs = {
    'nrows': 2,
    'ncols': 1,
    'wspace': 0,
    'hspace': 0.15,
    'left': 0.2,
    'right': 0.96,
    'bottom': 0.05,
    'top': 0.99,
}
# [x0, y0, width, height] of the side-distribution insets (inset_axes).
dist_inset_bounds = [1.02, 0, 0.2, 1]
# PLOT SETTINGS:
# Font sizes by text role.
fs = {
    'lab_norm': 16,
    'lab_tex': 20,
    'letter': 22,
    'tit_norm': 16,
    'tit_tex': 20,
    'bar': 16,
}
colors = load_colors('../data/stage_colors.npz')
# Endpoints of the per-threshold shading range passed to shade_colors.
shade_factors = [0.2, -0.2]
# Line widths by plot element.
lw = {
    'inv': 1.5,
    'conv': 1.5,
    'bi': 0.1,
    'feat': 3,
    'big': 4,
    'thresh': 1.5,
    'kern': 2.5,
    'plateau': 1.5,
}
# Axis label strings (LaTeX).
xlabels = {
    'alpha': 'scale $\\alpha$',
    'sigma': '$\\sigma_{\\text{adapt}}$',
}
ylabels = {
    'inv': '$x_{\\text{adapt}}$\n$[\\text{dB}]$',
    'conv': '$c_i$\n$[\\text{dB}]$',
    'bi': '$b_i$',
    'feat': '$f_i$',
    'big': '$\\mu_{f_i}$',
}
# Placement/formatting forwarded to the xlabel/ylabel helpers.
xlab_alpha_kwargs = {
    'y': 0.5,
    'fontsize': fs['lab_norm'],
    'ha': 'center',
    'va': 'bottom',
}
xlab_sigma_kwargs = {
    'y': 0,
    'fontsize': fs['lab_tex'],
    'ha': xlab_alpha_kwargs['ha'],
    'va': 'bottom',
}
ylab_snip_kwargs = {
    'x': 0.1,
    'fontsize': fs['lab_tex'],
    'rotation': 0,
    'ha': 'center',
    'va': 'center',
}
ylab_super_kwargs = {
    'x': 0,
    'fontsize': fs['lab_tex'],
    'ha': 'left',
    'va': 'center',
}
ylab_big_kwargs = {
    'x': 0,
    'fontsize': fs['lab_tex'],
    'ha': 'center',
    'va': 'top',
}
# Relative y-padding forwarded to ylimits.
ypad = {'inv': 0.05, 'conv': 0.05, 'big': 0.1}
# Major y-tick spacing; tuples hold (first column, second column) values.
yloc = {
    'inv': (2, 200),
    'conv': (0.02, 2),
    'bi': (1, 1),
    'feat': (1, 1),
    'big': 0.2,
}
title_kwargs = {
    'x': 0.5,
    'yref': 1,
    'ha': 'center',
    'va': 'top',
    'fontsize': fs['tit_norm'],
}
letter_snip_kwargs = {
    'x': 0,
    'y': 1,
    'ha': 'left',
    'va': 'top',
    'fontsize': fs['letter'],
}
letter_big_kwargs = {
    'xref': 0,
    'y': 1,
    'ha': 'left',
    'va': 'top',
    'fontsize': fs['letter'],
}
kern_kwargs = {'c': 'k', 'lw': lw['kern']}
dist_kwargs = {'c': 'k', 'lw': 1}
dist_fill_kwargs = {'color': colors['bi'], 'lw': 0.1}
thresh_kwargs = {
    'color': 'k',
    'lw': lw['thresh'],
    'ls': '--',
    'zorder': 3,
}
# Time scale bar of bar_time seconds (its label is printed in ms).
bar_time = 0.1
bar_kwargs = {
    'dur': bar_time,
    'y0': -0.5,
    'y1': -0.35,
    'xshift': 1,
    'color': 'k',
    'lw': 0,
    'clip_on': False,
    'text_pos': (-0.1, 0.5),
    'text_str': f'${int(1000 * bar_time)}\\,\\text{{ms}}$',
    'text_kwargs': {
        'fontsize': fs['bar'],
        'ha': 'right',
        'va': 'center',
    },
}
leg_kwargs = {
    'ncols': 2,
    'loc': 'center',
    'bbox_to_anchor': (0, 0.95, 1, 0.05),
    'frameon': False,
    'fontsize': fs['tit_norm'],
    'handlelength': 1.5,
    'columnspacing': 1,
}
# Forwarded to get_saturation.
plateau_settings = {
    'low': 0.05,
    'high': 0.95,
    'first': True,
    'last': True,
    'condense': None,
}
plateau_line_kwargs = {
    'lw': lw['plateau'],
    'ls': '--',
    'zorder': 1,
}
plateau_dot_kwargs = {
    'marker': 'o',
    'markersize': 8,
    'markeredgewidth': 1,
    'clip_on': False,
}
# Zoom window as fractions of the full snippet duration.
zoom_rel = np.array([0.5, 0.515])
# SUBSET SETTINGS:
# Candidate kernel specs (rows later matched by find_kern_specs); only the
# second row is kept. List fancy-indexing preserves the 2D (1, 2) shape.
kern_specs = np.array([
    [1, 0.008],
    [2, 0.004],
    [3, 0.002],
])[[1]]
# EXECUTION:
print(f'Processing {data_path}')
# Load invariance data for the noise variant and its 'pure' counterpart,
# whose path differs only by 'noise' -> 'pure':
noise_data, config = load_data(data_path, **load_kwargs)
pure_path = data_path.replace('noise', 'pure')
pure_data, _ = load_data(pure_path, **load_kwargs)
# Unpack shared variables:
scales, plot_scales = noise_data['scales'], noise_data['example_scales']
thresh_rel, thresh_abs = noise_data['thresh_rel'], noise_data['thresh_abs']
# Boolean time mask of the zoom window (zoom_rel scaled by the last time point):
n_frames = noise_data['snip_conv'].shape[0]
t_full = np.arange(n_frames) / config['env_rate']
zoom_abs = zoom_rel * t_full[-1]
zoom_inds = (t_full >= zoom_abs[0]) & (t_full <= zoom_abs[1])
# Reduce everything to the selected kernel and crop snippets to the zoom:
kern_ind = find_kern_specs(config['k_specs'], kerns=kern_specs)[0]
noise_data['snip_inv'] = noise_data['snip_inv'][zoom_inds, :]
noise_data['snip_conv'] = noise_data['snip_conv'][zoom_inds, kern_ind, :]
for _key in ('snip_bi', 'snip_feat'):
    noise_data[_key] = noise_data[_key][zoom_inds, kern_ind, :, :]
for _dataset in (noise_data, pure_data):
    _dataset['measure_feat'] = _dataset['measure_feat'][:, kern_ind, :]
config['kernels'] = config['kernels'][:, kern_ind]
thresh_abs = thresh_abs[:, kern_ind]
# Time axis of the cropped snippets:
t_full = np.arange(noise_data['snip_conv'].shape[0]) / config['env_rate']
if exclude_zero:
    # Drop the zero scale from every scale-resolved measure:
    nonzero = scales > 0
    scales = scales[nonzero]
    noise_data['measure_inv'] = noise_data['measure_inv'][nonzero]
    for _dataset in (noise_data, pure_data):
        _dataset['measure_feat'] = _dataset['measure_feat'][nonzero, :]
# Get threshold-specific colors: one shade of each stage color per threshold.
factors = np.linspace(*shade_factors, thresh_rel.size)
shaded = {stage: shade_colors(colors[stage], factors)
          for stage in ('conv', 'bi', 'feat')}
# Adjust grid parameters to loaded data (thresholds -> rows, scales -> columns):
super_grid_kwargs['nrows'] = input_rows + thresh_rel.size * snip_rows
for _grid_kwargs in (input_grid_kwargs, snip_grid_kwargs):
    _grid_kwargs['ncols'] = plot_scales.size
# Prepare overall graph: figure plus the outer grid hosting all subfigures.
fig = plt.figure(**fig_kwargs)
super_grid = fig.add_gridspec(**super_grid_kwargs)
# Prepare input snippet axes, one column per example scale:
input_subfig = fig.add_subfigure(super_grid[subfig_specs['input']])
input_axes = add_snip_axes(input_subfig, input_grid_kwargs).ravel()
# The first two columns get their own tick spacing:
first_ax, second_ax = input_axes[0], input_axes[1]
first_ax.yaxis.set_major_locator(plt.MultipleLocator(yloc['inv'][0]))
second_ax.yaxis.set_major_locator(plt.MultipleLocator(yloc['inv'][1]))
ylabel(first_ax, ylabels['inv'], transform=input_subfig.transSubfigure, **ylab_snip_kwargs)
# Title every column with its scale factor:
for col_ax, col_scale in zip(input_axes, plot_scales):
    title_subplot(col_ax, f'$\\alpha={strip_zeros(col_scale)}$', ref=input_subfig, **title_kwargs)
letter_subplot(input_subfig, 'a', **letter_snip_kwargs)
# Prepare snippet axes: one subfigure (rows = stages) per threshold.
snip_subfigs, snip_axes = [], []
snip_row_bounds, snip_cols = subfig_specs['snip']
for i in range(thresh_rel.size):
    # Offset this threshold's row span by snip_rows per threshold. An
    # explicit two-part index replaces the star-subscript super_grid[*spec],
    # which requires Python 3.11+.
    row_span = slice(*(snip_row_bounds + i * snip_rows))
    snip_subfig = fig.add_subfigure(super_grid[row_span, snip_cols])
    axes = add_snip_axes(snip_subfig, snip_grid_kwargs)
    # The second column keeps its left axis only in the top row.
    # (The former unused low_box/high_box position lookups were removed.)
    for ax in axes[1:, 1]:
        hide_axis(ax, 'left')
    super_ylabel(f'$\\Theta_i={strip_zeros(thresh_rel[i])}\\cdot\\sigma_{{\\eta}}$',
                 snip_subfig, axes[-1, 0], axes[0, 0], **ylab_super_kwargs)
    # Per-stage tick spacing for the first two columns, plus row labels:
    for (ax1, ax2), stage in zip(axes[:, :2], stages):
        ax1.yaxis.set_major_locator(plt.MultipleLocator(yloc[stage][0]))
        ax2.yaxis.set_major_locator(plt.MultipleLocator(yloc[stage][1]))
        ylabel(ax1, ylabels[stage], transform=snip_subfig.transSubfigure, **ylab_snip_kwargs)
    # Time scale bar only under the last threshold's subfigure:
    if i == thresh_rel.size - 1:
        axes[-1, -1].set_xlim(t_full[0], t_full[-1])
        time_bar(axes[-1, -1], **bar_kwargs)
    snip_subfigs.append(snip_subfig)
    snip_axes.append(axes)
letter_subplots(snip_subfigs, 'bcd', **letter_snip_kwargs)
# Prepare analysis axes: two stacked panels in the last outer-grid column.
big_subfig = fig.add_subfigure(super_grid[subfig_specs['big']])
big_grid = big_subfig.add_gridspec(**big_grid_kwargs)


def _add_big_axis(row, xlims, xlab, xlab_kwargs, letter):
    """Add and format the analysis axis in row ``row`` of ``big_grid``."""
    ax = big_subfig.add_subplot(big_grid[row, 0])
    ax.set_xlim(*xlims)
    # NOTE(review): both panels take linthresh from the smallest positive
    # alpha scale, including the sigma panel; confirm that is intended.
    ax.set_xscale('symlog', linthresh=scales[scales > 0][0], linscale=0.5)
    ylimits(pure_data['measure_feat'], ax, minval=0, pad=ypad['big'])
    ax.yaxis.set_major_locator(plt.MultipleLocator(yloc['big']))
    # NOTE(review): xlabel receives the subfigure itself as 'transform'
    # (ylabel gets .transSubfigure); presumably the helper accepts both.
    xlabel(ax, xlab, **xlab_kwargs, transform=big_subfig)
    ylabel(ax, ylabels['big'], transform=big_subfig.transSubfigure, **ylab_big_kwargs)
    letter_subplot(ax, letter, ref=big_subfig, **letter_big_kwargs)
    return ax


alpha_ax = _add_big_axis(0, (scales[0], scales[-1]),
                         xlabels['alpha'], xlab_alpha_kwargs, 'e')
sigma_ax = _add_big_axis(1, (1, noise_data['measure_inv'].max()),
                         xlabels['sigma'], xlab_sigma_kwargs, 'f')
# Plot intensity-adapted snippets, one column per example scale:
inv_snippets = noise_data['snip_inv']
plot_snippets(input_axes, t_full, inv_snippets,
              ypad=ypad['inv'], c=colors['inv'], lw=lw['inv'])
# Refit the first column's y-range to its own snippet:
ylimits(inv_snippets[:, 0], input_axes[0], pad=ypad['inv'])
# Overlay the kernel waveform on the 1st snippet, offset by half the duration:
kern_offset = 0.5 * t_full[-1]
input_axes[0].plot(config['k_times'] + kern_offset, config['kernels'], **kern_kwargs)
# Plot representation snippets per threshold:
for i, axes in enumerate(snip_axes):
    # Threshold-specific copy of the fill settings. The original code
    # overwrote the shared module-level dist_fill_kwargs in every iteration.
    fill_kwargs = {**dist_fill_kwargs, 'color': shaded['bi'][i]}
    # Plot kernel response snippets, filled where they exceed the threshold:
    plot_snippets(axes[0, :], t_full, noise_data['snip_conv'], thresh=thresh_abs[i],
                  ypad=ypad['conv'], fill_kwargs=fill_kwargs,
                  c=shaded['conv'][i], lw=lw['conv'])
    # Zoomed y-range for the first column, capped at thresh_abs[-1]:
    ylim_zoom = ylimits(noise_data['snip_conv'][:, 0], axes[0, 0],
                        pad=ypad['conv'], maxval=thresh_abs[-1])
    # Indicate absolute threshold value:
    handle = axes[0, 0].axhline(thresh_abs[i], **thresh_kwargs)
    set_clip_box(handle, axes[0, 0], bounds=[[0, 0], [1, 1.05]])
    # Plot kernel response distributions; the first column uses the zoomed range:
    side_distributions(axes[0, :1], noise_data['snip_conv'][:, :1], dist_inset_bounds,
                       thresh_abs[i], nbins=50, limits=ylim_zoom,
                       fill_kwargs=fill_kwargs, **dist_kwargs)
    side_distributions(axes[0, 1:], noise_data['snip_conv'][:, 1:], dist_inset_bounds,
                       thresh_abs[i], nbins=50, fill_kwargs=fill_kwargs, **dist_kwargs)
    # Plot binary snippets:
    plot_bi_snippets(axes[1, :], t_full, noise_data['snip_bi'][:, :, i],
                     color=shaded['bi'][i], lw=lw['bi'])
    # Plot feature snippets on a fixed [0, 1] range:
    handles = plot_snippets(axes[2, :], t_full, noise_data['snip_feat'][:, :, i],
                            ymin=0, ymax=1, c=shaded['feat'][i], lw=lw['feat'])
    for h, ax in zip(handles, axes[2, :]):
        set_clip_box(h[0], ax, bounds=[[0, -0.05], [1, 1.05]])
# Get saturation: second return value of get_saturation, one per threshold.
saturation_inds = [
    get_saturation(noise_data['measure_feat'][:, i], **plateau_settings)[1]
    for i in range(thresh_rel.size)
]
# Plot analysis results against scale (alpha_ax) and measured sigma (sigma_ax):
for ax, x in zip([alpha_ax, sigma_ax], [scales, noise_data['measure_inv']]):
    # Plot pure-song analysis results (dotted), one line per threshold:
    handles = ax.plot(x, pure_data['measure_feat'], lw=lw['big'], ls='dotted')
    for h, c in zip(handles, shaded['feat']):
        h.set_color(c)
    # Plot noise-song analysis results:
    handles = ax.plot(x, noise_data['measure_feat'], lw=lw['big'])
    for h, c in zip(handles, shaded['feat']):
        h.set_color(c)
    # Indicate threshold-specific saturation with a marker on the x-axis
    # (an opaque white marker beneath a semi-transparent colored one) and a
    # vertical line up to the curve:
    for i, ind in enumerate(saturation_inds):
        color = shaded['feat'][i]
        ax.plot(x[ind], 0, c='w', alpha=1, zorder=5.5, **plateau_dot_kwargs,
                transform=ax.get_xaxis_transform())
        ax.plot(x[ind], 0, mfc=color, mec='k', alpha=0.75, zorder=6,
                **plateau_dot_kwargs, transform=ax.get_xaxis_transform())
        ax.vlines(x[ind], ax.get_ylim()[0], noise_data['measure_feat'][ind, i],
                  color=color, **plateau_line_kwargs)
    # Add proxy legend (identity check: the same Axes object, not equality):
    if ax is alpha_ax:
        h1 = ax.plot([], [], c='k', lw=lw['big'], label='$\\alpha\\cdot s(t) + \\eta(t)$')[0]
        h2 = ax.plot([], [], c='k', lw=lw['big'], ls='dotted', label='$\\alpha\\cdot s(t)$')[0]
        ax.legend(handles=[h1, h2], **leg_kwargs)
if save_path is not None:
    fig.savefig(save_path)
plt.show()
print('Done.')
# Open an IPython shell for inspecting the results, but only when attached to
# a terminal. Previously embed() ran unconditionally, which stalls batch runs.
import sys
if sys.stdin is not None and sys.stdin.isatty():
    embed()