479 lines
15 KiB
Python
479 lines
15 KiB
Python
import plotstyle_plt
|
|
import numpy as np
|
|
import matplotlib.pyplot as plt
|
|
from itertools import product
|
|
from thunderhopper.filetools import search_files
|
|
from thunderhopper.modeltools import load_data
|
|
from thunderhopper.filtertools import find_kern_specs
|
|
from misc_functions import get_saturation
|
|
from color_functions import load_colors, shade_colors
|
|
from plot_functions import shift_subplot, hide_axis, ylimits, xlabel, ylabel,\
|
|
super_ylabel, plot_line, plot_barcode, strip_zeros,\
|
|
time_bar, letter_subplot, letter_subplots, title_subplot,\
|
|
set_clip_box
|
|
from IPython import embed
|
|
|
|
def add_snip_axes(fig, grid_kwargs, col_shift=None):
    """Fill a (sub)figure with a regular grid of snippet axes.

    Parameters
    ----------
    fig : matplotlib Figure or SubFigure
        Container that receives the new gridspec and axes.
    grid_kwargs : dict
        Keyword arguments forwarded to ``fig.add_gridspec``.
    col_shift : float, optional
        Horizontal offset passed as ``dx`` to ``shift_subplot`` for every
        axis in the first column. Defaults to the module-level
        ``snip_col_shift`` so existing callers keep their behavior.

    Returns
    -------
    axes : ndarray of Axes, shape (nrows, ncols)
        Object array of the created axes, indexed like the gridspec.
    """
    if col_shift is None:
        # Looked up at call time: the module-level setting is defined
        # further down in this script than this function.
        col_shift = snip_col_shift
    grid = fig.add_gridspec(**grid_kwargs)
    axes = np.zeros((grid.nrows, grid.ncols), dtype=object)
    for i, j in product(range(grid.nrows), range(grid.ncols)):
        axes[i, j] = fig.add_subplot(grid[i, j])
        if j == 0:
            shift_subplot(axes[i, j], dx=col_shift)
    # Only the first two columns keep their left axis.
    for ax in axes[:, 2:].flat:
        hide_axis(ax, 'left')
    # No snippet axis keeps its bottom axis.
    for ax in axes.flat:
        hide_axis(ax, 'bottom')
    return axes
|
|
|
|
def plot_snippets(axes, time, snippets, ymin=None, ymax=None, ypad=0.05,
                  thresh=None, fill_kwargs=None, **kwargs):
    """Plot one snippet per axis with shared y-limits.

    Optionally shades the parts of each snippet that exceed a threshold.

    Parameters
    ----------
    axes : iterable of Axes
        Target axes, paired in order with the columns of ``snippets``.
    time : array-like
        Shared time axis for all snippets.
    snippets : 2D array
        One snippet per column (iterated via ``snippets.T``).
    ymin, ymax : float, optional
        Passed to ``ylimits`` as ``minval`` / ``maxval``.
    ypad : float
        Padding passed to ``ylimits``.
    thresh : float, optional
        If given, the area between ``thresh`` and each snippet is filled
        wherever the snippet lies above it.
    fill_kwargs : dict, optional
        Keyword arguments for ``Axes.fill_between``.
    **kwargs
        Forwarded to ``plot_line``.

    Returns
    -------
    handles : list
        The return value of ``plot_line`` for each axis, in axis order.
    """
    # Fresh dict per call instead of a shared mutable default.
    if fill_kwargs is None:
        fill_kwargs = {}
    # One common y-range computed over all snippets together.
    ymin, ymax = ylimits(snippets, minval=ymin, maxval=ymax, pad=ypad)
    handles = []
    for ax, snippet in zip(axes, snippets.T):
        handles.append(plot_line(ax, time, snippet, ymin=ymin, ymax=ymax, **kwargs))
        if thresh is not None:
            ax.fill_between(time, thresh, snippet, where=(snippet > thresh), **fill_kwargs)
    return handles
|
|
|
|
def plot_bi_snippets(axes, time, binary, **kwargs):
    """Draw every column of ``binary`` as a barcode on its matching axis.

    ``axes`` is paired in order with the columns of ``binary``; surplus
    entries on either side are ignored. Each column is handed to
    ``plot_barcode`` as an (n, 1) array together with ``time`` and
    ``**kwargs``. Always returns None.
    """
    columns = binary.T
    for target_ax, trace in zip(axes, columns):
        column_2d = trace[:, np.newaxis]
        plot_barcode(target_ax, time, column_2d, **kwargs)
    return None
|
|
|
|
def side_distributions(axes, snippets, inset_bounds, thresh, nbins=50,
                       limits=None, fill_kwargs=None, **kwargs):
    """Attach a sideways density histogram of each snippet to its axis.

    For every (axis, snippet column) pair, an inset is created. The
    snippet's density histogram is drawn with bin centres on the vertical
    and density on the horizontal axis, and the bins above ``thresh`` are
    filled.

    Parameters
    ----------
    axes : iterable of Axes
        Parent axes, paired in order with the columns of ``snippets``.
    snippets : 2D array
        One snippet per column.
    inset_bounds : sequence of 4 floats
        ``[x0, y0, width, height]`` passed to ``Axes.inset_axes``.
    thresh : float
        Bins whose centre lies above this value are filled.
    nbins : int
        Number of histogram bins.
    limits : (float, float), optional
        Histogram range. Defaults to the overall data range of ``snippets``,
        widened outward by 10 % of each bound's magnitude.
    fill_kwargs : dict, optional
        Keyword arguments for ``fill_betweenx``.
    **kwargs
        Forwarded to ``inset.plot`` for the density curve.

    Returns
    -------
    insets : list of Axes
        The created inset axes (axis decorations switched off).
    """
    # Fresh dict per call instead of a shared mutable default.
    if fill_kwargs is None:
        fill_kwargs = {}
    if limits is None:
        # Always widen outward. Plain scaling by 1.1 would move a positive
        # minimum (or a negative maximum) inward and drop samples from the
        # histogram; for constant data it produced non-increasing bin edges.
        # For min <= 0 <= max this equals the previous ``* 1.1``.
        low = snippets.min()
        high = snippets.max()
        limits = (low * (1.1 if low < 0 else 0.9),
                  high * (1.1 if high > 0 else 0.9))
    edges = np.linspace(*limits, nbins + 1)
    centers = edges[:-1] + (edges[1] - edges[0]) / 2
    insets = []
    for ax, snippet in zip(axes, snippets.T):
        pdf, _ = np.histogram(snippet, edges, density=True)
        inset = ax.inset_axes(inset_bounds)
        # Density on x, value on y: the histogram is drawn sideways.
        handle = inset.plot(pdf, centers, **kwargs)[0]
        set_clip_box(handle, inset, bounds=[[-0.05, 0], [1.05, 1]])
        handle = inset.fill_betweenx(centers, pdf.min(), pdf, where=(centers > thresh), **fill_kwargs)
        set_clip_box(handle, inset, bounds=[[-0.05, 0], [1.05, 1]])
        inset.set_xlim(0, pdf.max())
        # Share the parent's vertical range so bins line up with the trace.
        inset.set_ylim(ax.get_ylim())
        inset.axis('off')
        insets.append(inset)
    return insets
|
|
|
|
|
|
# GENERAL SETTINGS:
example_file = 'Omocestus_rufipes_DJN_32-40s724ms-48s779ms'
data_dir = '../data/inv/thresh_lp/'
data_matches = search_files(example_file, incl='noise', dir=data_dir)
if len(data_matches) == 0:
    # Fail with a clear message instead of a bare IndexError on [0].
    raise FileNotFoundError(
        f"No file matching {example_file!r} (incl='noise') found in {data_dir}")
# If several files match, the first one returned by search_files is used.
data_path = data_matches[0]
# Processing stages loaded from file (one snippet row per stage below).
stages = ['conv', 'bi', 'feat']
load_kwargs = dict(
    files=stages,
    keywords=['scales', 'snip', 'measure', 'thresh']
)
save_path = '../figures/fig_invariance_thresh_lp_single.pdf'
# Drop the zero-scale condition from the analysis curves.
exclude_zero = True
|
|
|
|
# GRAPH SETTINGS:
# Figure size: 32 cm x 32 cm, converted to inches.
fig_kwargs = {'figsize': (32 / 2.54, 32 / 2.54)}

# Outer layout; 'nrows' is filled in once the number of thresholds is known.
super_grid_kwargs = {
    'nrows': None,
    'ncols': 3,
    'wspace': 0,
    'hspace': 0,
    'left': 0,
    'right': 1,
    'bottom': 0,
    'top': 1,
}

# Outer-grid rows occupied by the input panel and by each threshold panel.
input_rows = 1
snip_rows = 2

# Outer-grid regions of the subfigures: input snippets on top, one snippet
# block per threshold below, and the analysis column on the right. The
# 'snip' entry is a list holding the row range as an array, so a copy of
# it can have its row entry offset per threshold.
subfig_specs = {
    'input': (slice(input_rows), slice(-1)),
    'snip': [np.array([input_rows, input_rows + snip_rows]), slice(-1)],
    'big': (slice(None), -1),
}

# Horizontal shift of the first snippet column; the left margin below is
# widened by the same amount.
snip_col_shift = -0.07

snip_grid_kwargs = {
    'nrows': len(stages),
    'ncols': None,
    'wspace': 0.3,
    'hspace': 0,
    'left': 0.23 - snip_col_shift,
    'right': 0.93,
    'bottom': 0.15,
    'top': 0.95,
    # One entry per stage row.
    'height_ratios': [4, 1, 2],
}

# Input grid shares the horizontal layout of the snippet grids.
input_grid_kwargs = {
    'nrows': 1,
    'ncols': None,
    'wspace': snip_grid_kwargs['wspace'],
    'hspace': 0,
    'left': snip_grid_kwargs['left'],
    'right': snip_grid_kwargs['right'],
    'bottom': 0.15,
    'top': 0.75,
}

# Two stacked analysis panels in the right-hand column.
big_grid_kwargs = {
    'nrows': 2,
    'ncols': 1,
    'wspace': 0,
    'hspace': 0.15,
    'left': 0.2,
    'right': 0.96,
    'bottom': 0.05,
    'top': 0.99,
}

# Histogram inset placement [x0, y0, width, height] relative to its parent axis.
dist_inset_bounds = [1.02, 0, 0.2, 1]
|
|
|
|
# PLOT SETTINGS:
# Font sizes.
fs = dict(
    lab_norm=16,
    lab_tex=20,
    letter=22,
    tit_norm=16,
    tit_tex=20,
    bar=16,
)
colors = load_colors('../data/stage_colors.npz')
# Range of shade factors, spread across the thresholds.
shade_factors = [0.2, -0.2]
# Line widths.
lw = dict(
    inv=1.5,
    conv=1.5,
    bi=0.1,
    feat=3,
    big=4,
    thresh=1.5,
    kern=2.5,
    plateau=1.5,
)
xlabels = dict(
    alpha='scale $\\alpha$',
    sigma='$\\sigma_{\\text{adapt}}$',
)
ylabels = dict(
    inv='$x_{\\text{adapt}}$\n$[\\text{dB}]$',
    conv='$c_i$\n$[\\text{dB}]$',
    bi='$b_i$',
    feat='$f_i$',
    big='$\\mu_{f_i}$',
)
xlab_alpha_kwargs = dict(
    y=0.5,
    fontsize=fs['lab_norm'],
    ha='center',
    va='bottom',
)
xlab_sigma_kwargs = dict(
    y=0,
    fontsize=fs['lab_tex'],
    ha=xlab_alpha_kwargs['ha'],
    va='bottom',
)
ylab_snip_kwargs = dict(
    x=0.1,
    fontsize=fs['lab_tex'],
    rotation=0,
    ha='center',
    va='center',
)
ylab_super_kwargs = dict(
    x=0,
    fontsize=fs['lab_tex'],
    ha='left',
    va='center',
)
ylab_big_kwargs = dict(
    x=0,
    fontsize=fs['lab_tex'],
    ha='center',
    va='top',
)
ypad = dict(
    inv=0.05,
    conv=0.05,
    big=0.1
)
# Major tick spacing. Tuples hold (first column, second column) values for
# the snippet grids; 'big' applies to both analysis axes.
yloc = dict(
    inv=(2, 200),
    conv=(0.02, 2),
    bi=(1, 1),
    feat=(1, 1),
    big=0.2,
)
title_kwargs = dict(
    x=0.5,
    yref=1,
    ha='center',
    va='top',
    fontsize=fs['tit_norm'],
)
letter_snip_kwargs = dict(
    x=0,
    y=1,
    ha='left',
    va='top',
    fontsize=fs['letter'],
)
letter_big_kwargs = dict(
    xref=0,
    y=1,
    ha='left',
    va='top',
    fontsize=fs['letter'],
)
kern_kwargs = dict(
    c='k',
    lw=lw['kern'],
)
dist_kwargs = dict(
    c='k',
    lw=1,
)
# 'color' is overwritten per threshold when the snippets are drawn.
dist_fill_kwargs = dict(
    color=colors['bi'],
    lw=0.1,
)
thresh_kwargs = dict(
    color='k',
    lw=lw['thresh'],
    ls='--',
    zorder=3,
)
# Duration of the time scale bar; the label shows it in ms.
bar_time = 0.1
bar_kwargs = dict(
    dur=bar_time,
    y0=-0.5,
    y1=-0.35,
    xshift=1,
    color='k',
    lw=0,
    clip_on=False,
    text_pos=(-0.1, 0.5),
    # round() rather than int(): int() truncates products like 56.999...
    text_str=f'${round(1000 * bar_time)}\\,\\text{{ms}}$',
    text_kwargs=dict(
        fontsize=fs['bar'],
        ha='right',
        va='center',
    )
)
leg_kwargs = dict(
    ncols=2,
    loc='center',
    bbox_to_anchor=(0, 0.95, 1, 0.05),
    frameon=False,
    fontsize=fs['tit_norm'],
    handlelength=1.5,
    columnspacing=1,
)
# Keyword arguments for get_saturation.
plateau_settings = dict(
    low=0.05,
    high=0.95,
    first=True,
    last=True,
    condense=None,
)
plateau_line_kwargs = dict(
    lw=lw['plateau'],
    ls='--',
    zorder=1,
)
plateau_dot_kwargs = dict(
    marker='o',
    markersize=8,
    markeredgewidth=1,
    clip_on=False,
)
# Zoom window as fractions of the last snippet sample time.
zoom_rel = np.array([0.5, 0.515])
|
|
|
|
# SUBSET SETTINGS:
# Candidate kernel specifications, one row per kernel (passed to
# find_kern_specs; the column meaning is defined there).
kern_specs = np.array([
    [1, 0.008],
    [2, 0.004],
    [3, 0.002],
])
# Keep only the second candidate; index-array selection keeps the 2D shape.
kern_specs = kern_specs[np.array([1])]
|
|
|
|
# EXECUTION:
print(f'Processing {data_path}')

# Load invariance data:
# The matching noise-free ("pure") dataset is found by replacing 'noise'
# with 'pure' in the path. NOTE(review): str.replace substitutes every
# occurrence, so this assumes 'noise' appears only once in data_path.
noise_data, config = load_data(data_path, **load_kwargs)
pure_data, _ = load_data(data_path.replace('noise', 'pure'), **load_kwargs)

# Unpack shared variables:
scales = noise_data['scales']
plot_scales = noise_data['example_scales']
thresh_rel = noise_data['thresh_rel']
thresh_abs = noise_data['thresh_abs']

# Reduce to kernel subset and crop to zoom frame:
# Sample times of the full-length snippets (sample index / env_rate).
t_full = np.arange(noise_data['snip_conv'].shape[0]) / config['env_rate']
# Zoom window: zoom_rel fractions of the last sample time, as a boolean mask.
zoom_abs = zoom_rel * t_full[-1]
zoom_inds = (t_full >= zoom_abs[0]) & (t_full <= zoom_abs[1])
# First matching kernel index. Presumably a scalar, which drops the kernel
# axis from every array indexed with it below — TODO confirm.
kern_ind = find_kern_specs(config['k_specs'], kerns=kern_specs)[0]
noise_data['snip_inv'] = noise_data['snip_inv'][zoom_inds, :]
noise_data['snip_conv'] = noise_data['snip_conv'][zoom_inds, kern_ind, :]
noise_data['snip_bi'] = noise_data['snip_bi'][zoom_inds, kern_ind, :, :]
noise_data['snip_feat'] = noise_data['snip_feat'][zoom_inds, kern_ind, :, :]
noise_data['measure_feat'] = noise_data['measure_feat'][:, kern_ind, :]
pure_data['measure_feat'] = pure_data['measure_feat'][:, kern_ind, :]
config['kernels'] = config['kernels'][:, kern_ind]
thresh_abs = thresh_abs[:, kern_ind]
# Time axis of the cropped snippets; restarts at zero.
t_full = np.arange(noise_data['snip_conv'].shape[0]) / config['env_rate']
|
|
|
|
if exclude_zero:
    # Exclude zero scale: keep strictly positive scales together with the
    # matching rows of the per-scale analysis results.
    inds = scales > 0
    scales = scales[inds]
    noise_data['measure_inv'] = noise_data['measure_inv'][inds]
    for dataset in (noise_data, pure_data):
        dataset['measure_feat'] = dataset['measure_feat'][inds, :]
|
|
|
|
# Get threshold-specific colors:
# One shade factor per relative threshold, spread evenly over shade_factors.
factors = np.linspace(*shade_factors, thresh_rel.size)
shaded = {stage: shade_colors(colors[stage], factors)
          for stage in ('conv', 'bi', 'feat')}

# Adjust grid parameters to loaded data:
# Outer grid: the input rows plus snip_rows rows per threshold.
super_grid_kwargs['nrows'] = snip_rows * thresh_rel.size + input_rows
# Input and snippet grids both get one column per example scale.
input_grid_kwargs['ncols'] = snip_grid_kwargs['ncols'] = plot_scales.size
|
|
|
|
# Prepare overall graph:
fig = plt.figure(**fig_kwargs)
super_grid = fig.add_gridspec(**super_grid_kwargs)

# Prepare input snippet axes:
input_subfig = fig.add_subfigure(super_grid[subfig_specs['input']])
input_axes = add_snip_axes(input_subfig, input_grid_kwargs).ravel()
# First and second column get separate tick spacings.
# NOTE(review): assumes at least two example scales.
input_axes[0].yaxis.set_major_locator(plt.MultipleLocator(yloc['inv'][0]))
input_axes[1].yaxis.set_major_locator(plt.MultipleLocator(yloc['inv'][1]))
ylabel(input_axes[0], ylabels['inv'], transform=input_subfig.transSubfigure, **ylab_snip_kwargs)
# Column titles: one example scale alpha per column.
for ax, scale in zip(input_axes, plot_scales):
    title_subplot(ax, f'$\\alpha={strip_zeros(scale)}$', ref=input_subfig, **title_kwargs)
letter_subplot(input_subfig, 'a', **letter_snip_kwargs)
|
|
|
|
# Prepare snippet axes:
# One subfigure per threshold, each holding a (stage x example scale) grid.
snip_subfigs, snip_axes = [], []
for i in range(thresh_rel.size):
    # Offset the template row range by snip_rows per threshold; the copy
    # leaves the template in subfig_specs untouched.
    subfig_spec = subfig_specs['snip'].copy()
    subfig_spec[0] = slice(*(subfig_spec[0] + i * snip_rows))
    # tuple(...) is equivalent to the 3.11-only `super_grid[*subfig_spec]`.
    snip_subfig = fig.add_subfigure(super_grid[tuple(subfig_spec)])
    axes = add_snip_axes(snip_subfig, snip_grid_kwargs)
    # In the second column only the top (first-stage) row keeps its left axis.
    for ax in axes[1:, 1]:
        hide_axis(ax, 'left')
    super_ylabel(f'$\\Theta_i={strip_zeros(thresh_rel[i])}\\cdot\\sigma_{{\\eta}}$',
                 snip_subfig, axes[-1, 0], axes[0, 0], **ylab_super_kwargs)
    # Rows correspond to stages; first two columns get stage-specific ticks.
    for (ax1, ax2), stage in zip(axes[:, :2], stages):
        ax1.yaxis.set_major_locator(plt.MultipleLocator(yloc[stage][0]))
        ax2.yaxis.set_major_locator(plt.MultipleLocator(yloc[stage][1]))
        ylabel(ax1, ylabels[stage], transform=snip_subfig.transSubfigure, **ylab_snip_kwargs)
    if i == thresh_rel.size - 1:
        # Time scale bar only on the last threshold block.
        axes[-1, -1].set_xlim(t_full[0], t_full[-1])
        time_bar(axes[-1, -1], **bar_kwargs)
    snip_subfigs.append(snip_subfig)
    snip_axes.append(axes)
# NOTE(review): letters 'bcd' presume exactly three thresholds.
letter_subplots(snip_subfigs, 'bcd', **letter_snip_kwargs)
|
|
|
|
# Prepare analysis axes:
# Right-hand column: analysis results against alpha (top) and sigma (bottom).
big_subfig = fig.add_subfigure(super_grid[subfig_specs['big']])
big_grid = big_subfig.add_gridspec(**big_grid_kwargs)

alpha_ax = big_subfig.add_subplot(big_grid[0, 0])
alpha_ax.set_xlim(scales[0], scales[-1])
# Symmetric-log x-axis, linear below the smallest positive scale.
alpha_ax.set_xscale('symlog', linthresh=scales[scales > 0][0], linscale=0.5)
# y-limits start at 0 and are derived from the pure-song results only.
ylimits(pure_data['measure_feat'], alpha_ax, minval=0, pad=ypad['big'])
alpha_ax.yaxis.set_major_locator(plt.MultipleLocator(yloc['big']))
# NOTE(review): xlabel receives the subfigure itself as transform, while
# ylabel receives big_subfig.transSubfigure — confirm xlabel handles both.
xlabel(alpha_ax, xlabels['alpha'], **xlab_alpha_kwargs, transform=big_subfig)
ylabel(alpha_ax, ylabels['big'], transform=big_subfig.transSubfigure, **ylab_big_kwargs)
letter_subplot(alpha_ax, 'e', ref=big_subfig, **letter_big_kwargs)

sigma_ax = big_subfig.add_subplot(big_grid[1, 0])
sigma_ax.set_xlim(1, noise_data['measure_inv'].max())
# NOTE(review): linthresh is taken from the alpha scales, not from
# measure_inv (this axis's x data) — possibly copied from alpha_ax; confirm.
sigma_ax.set_xscale('symlog', linthresh=scales[scales > 0][0], linscale=0.5)
ylimits(pure_data['measure_feat'], sigma_ax, minval=0, pad=ypad['big'])
sigma_ax.yaxis.set_major_locator(plt.MultipleLocator(yloc['big']))
xlabel(sigma_ax, xlabels['sigma'], **xlab_sigma_kwargs, transform=big_subfig)
ylabel(sigma_ax, ylabels['big'], transform=big_subfig.transSubfigure, **ylab_big_kwargs)
letter_subplot(sigma_ax, 'f', ref=big_subfig, **letter_big_kwargs)
|
|
|
|
# Plot intensity-adapted snippets:
plot_snippets(input_axes, t_full, noise_data['snip_inv'],
              ypad=ypad['inv'], c=colors['inv'], lw=lw['inv'])
# Re-limit the first panel to its own snippet; must run after plot_snippets,
# which applied shared limits to all panels.
ylimits(noise_data['snip_inv'][:, 0], input_axes[0], pad=ypad['inv'])

# Indicate kernel waveform over 1st intensity-adapted snippet:
# Kernel times are shifted so kernel time zero sits at the snippet midpoint.
input_axes[0].plot(config['k_times'] + 0.5 * t_full[-1], config['kernels'], **kern_kwargs)
|
|
|
|
# Plot representation snippets per threshold:
for i, (subfig, axes) in enumerate(zip(snip_subfigs, snip_axes)):
    # Threshold-specific fill colour. NOTE: mutates the module-level
    # dist_fill_kwargs, which the calls below receive.
    dist_fill_kwargs['color'] = shaded['bi'][i]

    # Plot kernel response snippets:
    # Parts above the current absolute threshold are filled.
    plot_snippets(axes[0, :], t_full, noise_data['snip_conv'], thresh=thresh_abs[i],
                  ypad=ypad['conv'], fill_kwargs=dist_fill_kwargs, c=shaded['conv'][i], lw=lw['conv'])
    # First column gets its own y-limits (maxval = last absolute threshold);
    # the returned limits are reused as that column's histogram range below.
    ylim_zoom = ylimits(noise_data['snip_conv'][:, 0], axes[0, 0],
                        pad=ypad['conv'], maxval=thresh_abs[-1])

    # Indicate absolute threshold value:
    # Line in the first column only; clip bounds presumably in axes coords.
    handle = axes[0, 0].axhline(thresh_abs[i], **thresh_kwargs)
    set_clip_box(handle, axes[0, 0], bounds=[[0, 0], [1, 1.05]])

    # Plot kernel response distributions:
    # First column: histogram range matches its zoomed y-limits.
    side_distributions(axes[0, :1], noise_data['snip_conv'][:, :1], dist_inset_bounds,
                       thresh_abs[i], nbins=50, limits=ylim_zoom, fill_kwargs=dist_fill_kwargs, **dist_kwargs)
    # Remaining columns: side_distributions' default histogram range.
    side_distributions(axes[0, 1:], noise_data['snip_conv'][:, 1:], dist_inset_bounds,
                       thresh_abs[i], nbins=50, fill_kwargs=dist_fill_kwargs, **dist_kwargs)

    # Plot binary snippets:
    plot_bi_snippets(axes[1, :], t_full, noise_data['snip_bi'][:, :, i],
                     color=shaded['bi'][i], lw=lw['bi'])

    # Plot feature snippets:
    # y-limits derived from ymin=0, ymax=1; clip boxes widened below and above.
    handles = plot_snippets(axes[2, :], t_full, noise_data['snip_feat'][:, :, i],
                            ymin=0, ymax=1, c=shaded['feat'][i], lw=lw['feat'])
    [set_clip_box(h[0], ax, bounds=[[0, -0.05], [1, 1.05]]) for h, ax in zip(handles, axes[2, :])]
|
|
|
|
# Get saturation:
# Per threshold (column of measure_feat), keep the second value returned by
# get_saturation; it is used below as an index into the x values.
saturation_inds = [
    get_saturation(noise_data['measure_feat'][:, col], **plateau_settings)[1]
    for col in range(thresh_rel.size)
]
|
|
|
|
# Plot analysis results:
# Same curves on both panels: against the scale alpha and against sigma
# (noise_data['measure_inv']).
for ax, x in zip([alpha_ax, sigma_ax], [scales, noise_data['measure_inv']]):
    # Plot pure-song analysis results (dotted), one colour per threshold:
    handles = ax.plot(x, pure_data['measure_feat'], lw=lw['big'], ls='dotted')
    for h, c in zip(handles, shaded['feat']):
        h.set_color(c)

    # Plot noise-song analysis results (solid):
    handles = ax.plot(x, noise_data['measure_feat'], lw=lw['big'])
    for h, c in zip(handles, shaded['feat']):
        h.set_color(c)

    # Indicate threshold-specific saturation:
    for i, ind in enumerate(saturation_inds):
        color = shaded['feat'][i]
        # Marker at the axis bottom (x in data, y in axes coordinates): an
        # opaque white dot underneath a semi-transparent coloured dot.
        ax.plot(x[ind], 0, c='w', alpha=1, zorder=5.5, **plateau_dot_kwargs,
                transform=ax.get_xaxis_transform())
        ax.plot(x[ind], 0, mfc=color, mec='k', alpha=0.75, zorder=6,
                **plateau_dot_kwargs, transform=ax.get_xaxis_transform())
        # Vertical line from the lower y-limit up to the noise-song curve.
        ax.vlines(x[ind], ax.get_ylim()[0], noise_data['measure_feat'][ind, i],
                  color=color, **plateau_line_kwargs)

    # Add proxy legend (alpha panel only; identity check, not equality):
    if ax is alpha_ax:
        h1 = ax.plot([], [], c='k', lw=lw['big'], label='$\\alpha\\cdot s(t) + \\eta(t)$')[0]
        h2 = ax.plot([], [], c='k', lw=lw['big'], ls='dotted', label='$\\alpha\\cdot s(t)$')[0]
        ax.legend(handles=[h1, h2], **leg_kwargs)
|
|
|
|
# Save (if a target path is configured) before plt.show(), which may block
# until the figure window is closed.
if save_path is not None:
    fig.savefig(save_path)
plt.show()

print('Done.')
# NOTE(review): opens an interactive IPython shell at the end of every run —
# presumably a debugging aid; remove or guard it for unattended execution.
embed()
|