# Listing metadata: 295 lines, 8.4 KiB, Python.
import plotstyle_plt
|
|
import string
|
|
import numpy as np
|
|
import matplotlib.pyplot as plt
|
|
from thunderhopper.filetools import search_files
|
|
from thunderhopper.modeltools import load_data
|
|
from thunderhopper.filtertools import find_kern_specs
|
|
from color_functions import load_colors
|
|
from plot_functions import hide_ticks, ylabel, super_xlabel, letter_subplots,\
|
|
ylimits, title_subplot
|
|
from misc_functions import exclude_zero_scale, reduce_kernel_set
|
|
from IPython import embed
|
|
|
|
# GENERAL SETTINGS:
# One species is plotted per run; change the trailing index to switch.
target_species = (
    'Chorthippus_biguttulus',
    'Chorthippus_mollis',
    'Chrysochraon_dispar',
    'Euchorthippus_declivus',
    'Gomphocerippus_rufus',
    'Omocestus_rufipes',
    'Pseudochorthippus_parallelus',
)[5]

# Normalization variants, one row of panels each (top to bottom).
modes = ['unnormed', 'norm-base', 'norm-min', 'norm-max']

# Condensed invariance results: "full" and "short" analysis variants
# (shown as the "Including"/"Excluding x_dB" columns of the figure).
full_folder = '../data/inv/full/condensed/'
short_folder = '../data/inv/short/condensed/'
save_path = '../figures/fig_invariance_full_short.pdf'

# Forwarded to load_data(); presumably selects which stored arrays and
# config entries get loaded -- confirm against thunderhopper.modeltools.
load_kwargs = {
    'files': ['scales', 'mean_feat', 'sd_feat'],
    'keywords': ['thresh'],
}

# ANALYSIS SETTINGS:
# Strip the zero-scale entry from the loaded data before condensing.
exclude_zero = True

# Key-selection arguments for exclude_zero_scale / reduce_kernel_set.
scale_subset_kwargs = {'combis': [['mean', 'sd'], ['feat']]}
kern_subset_kwargs = {
    'combis': [['mean', 'sd'], ['feat']],
    'keys': ['thresh_abs'],
}

# Relative thresholds 0, 0.5, ..., 3; one curve per entry in every panel.
thresh_rel = 0.5 * np.arange(7)

# (lower, upper) percentile pairs shaded around each median curve.
# Append rows such as [0, 100] for additional bands.
percentiles = np.array([[25, 75]])

# SUBSET SETTINGS:
# Kernel selection criteria forwarded to find_kern_specs().
# Full type range for reference: 1, -1, 2, -2, ..., 10, -10.
types = np.array([1, 2, 3])
sigmas = np.array([0.001, 0.002, 0.004, 0.008, 0.016, 0.032])
kernels = None

# Subset the kernel set only when at least one criterion is given.
reduce_kernels = not all(criterion is None for criterion in (kernels, types, sigmas))

# GRAPH SETTINGS:
# 32 x 32 cm figure, converted to inches for matplotlib.
fig_kwargs = {'figsize': (32 / 2.54, 32 / 2.54)}

# Outer 1 x 2 grid filling the whole figure: one subfigure per analysis.
super_grid_kwargs = {
    'nrows': 1, 'ncols': 2,
    'wspace': 0, 'hspace': 0,
    'left': 0, 'right': 1,
    'bottom': 0, 'top': 1,
}
# Grid cell of each analysis variant inside the outer grid.
subfig_specs = {'full': (0, 0), 'short': (0, 1)}

# Fraction of a subfigure's width taken by its axes column.
col_width = 0.85
col_rest = 1 - col_width

# Full column: axes flush right, the whole margin on the left
# (presumably room for the y-labels, which only this column draws).
full_grid_kwargs = {
    'nrows': len(modes), 'ncols': 1,
    'wspace': 0, 'hspace': 0.1,
    'left': col_rest, 'right': 1,
    'bottom': 0.05, 'top': 0.9,
}
# Short column: identical spacing, margin split evenly left and right.
short_grid_kwargs = {
    **full_grid_kwargs,
    'left': col_rest / 2,
    'right': 1 - col_rest / 2,
}

# PLOT SETTINGS:
# Font sizes; the *_tex sizes are applied to the math-text labels/titles.
fs = {
    'lab_norm': 16, 'lab_tex': 20,
    'letter': 22,
    'tit_norm': 16, 'tit_tex': 20,
    'bar': 16,
}

# Colour tables loaded from disk (not referenced elsewhere in this script).
colors = load_colors('../data/stage_colors.npz')
feat_colors = load_colors('../data/feat_colors_all.npz')

# Line widths for feature curves and plateau markers.
lw = {'feat': 3, 'plateau': 1.5}
line_kwargs = {'lw': lw['feat']}
# Opacity of the percentile bands drawn around each curve.
fill_kwargs = {'alpha': 0.15}

# Shared x-label below both columns.
xlabels = {'super': r'scale $\alpha$'}

# Math-text y-label per normalization mode (full column only).
ylabels = {
    'unnormed': r'$\mu_{f_i}$',
    'norm-base': r'$\mu_{f_i}\,/\,\mu_{f_i}\,[\,\eta\,]$',
    'norm-min': r'$\mu_{f_i}\,/\,\min\,[\,\mu_{f_i}\,]$',
    'norm-max': r'$\mu_{f_i}\,/\,\max\,[\,\mu_{f_i}\,]$',
}

xlab_kwargs = {'y': 0, 'fontsize': fs['lab_norm'], 'ha': 'center', 'va': 'bottom'}
ylab_kwargs = {'x': 0, 'fontsize': fs['lab_tex'], 'ha': 'center', 'va': 'top'}

# y-limits per mode; a None bound is filled in from the data while
# plotting (this dict is updated in place at that point).
ylims = {
    'unnormed': [0, 1],
    'norm-base': [0, None],
    'norm-min': [0, None],
    'norm-max': [0, 1],
}

# Major-tick spacing per mode; the locator call using it is commented out.
yloc = dict.fromkeys(['unnormed', 'norm-base', 'norm-min', 'norm-max'], 0.5)

# Column titles, attached to the top panel of each column.
title_kwargs = {
    'x': 0.5, 'y': 1,
    'ha': 'center', 'va': 'bottom',
    'fontsize': fs['tit_norm'],
}
titles = {
    'full': r'Including $x_\text{dB}$',
    'short': r'Excluding $x_\text{dB}$',
}

# Panel-letter placement, forwarded to letter_subplots().
letter_kwargs = {
    'xref': 0.01, 'y': 1,
    'ha': 'left', 'va': 'center',
    'fontsize': fs['letter'],
}
# Row-wise interleaved lettering: full gets a, c, e, ...; short b, d, f, ...
letters = {
    'full': string.ascii_lowercase[::2],
    'short': string.ascii_lowercase[1::2],
}

# Plateau annotation settings (not referenced elsewhere in this script).
plateau_settings = {
    'low': 0.05, 'high': 0.95,
    'first': True, 'last': True,
    'condense': None,
}
plateau_line_kwargs = {'lw': lw['plateau'], 'ls': '--', 'zorder': 1}
plateau_dot_kwargs = {
    'marker': 'o',
    'markersize': 8,
    'markeredgewidth': 1,
    # Do not clip markers to the axes area.
    'clip_on': False,
}

# EXECUTION:

# Prepare overall graph: one subfigure per analysis variant, side by side.
fig = plt.figure(**fig_kwargs)
super_grid = fig.add_gridspec(**super_grid_kwargs)


def _prepare_column(key, grid_kwargs, with_ylabels):
    """Build the stacked panel column of one analysis variant.

    Parameters
    ----------
    key : str
        'full' or 'short'; selects the outer grid cell, the column title
        and the panel letters.
    grid_kwargs : dict
        Gridspec arguments of the column (one row per entry of ``modes``).
    with_ylabels : bool
        Draw the per-mode y-labels if True, otherwise call
        hide_ticks(..., 'left') on every panel instead.

    Returns
    -------
    subfig, grid, axes
        The subfigure, its gridspec and an object array of axes ordered
        like ``modes``.
    """
    # Indexing with the (row, col) tuple is the same as unpacking it.
    subfig = fig.add_subfigure(super_grid[subfig_specs[key]])
    grid = subfig.add_gridspec(**grid_kwargs)
    axes = np.zeros((len(modes),), dtype=object)
    bottom_row = grid_kwargs['nrows'] - 1
    for row, mode in enumerate(modes):
        ax = subfig.add_subplot(grid[row, 0])
        axes[row] = ax
        ax.set_xscale('symlog', linthresh=0.01, linscale=0.5)
        if with_ylabels:
            ylabel(ax, ylabels[mode], transform=subfig.transSubfigure, **ylab_kwargs)
        else:
            hide_ticks(ax, 'left')
        if row == 0:
            title_subplot(ax, titles[key], **title_kwargs)
        if row < bottom_row:
            # Every panel except the bottom one gets its bottom ticks hidden.
            hide_ticks(ax, 'bottom')
    letter_subplots(axes, letters[key], ref=subfig, **letter_kwargs)
    return subfig, grid, axes


# Prepare full analysis axes (left, with y-labels), then short ones (right):
full_subfig, full_grid, full_axes = _prepare_column('full', full_grid_kwargs, True)
short_subfig, short_grid, short_axes = _prepare_column('short', short_grid_kwargs, False)

# One x-label for both bottom panels (placement handled by super_xlabel).
super_xlabel(xlabels['super'], fig, full_axes[-1], short_axes[-1],
             left_fig=full_subfig, right_fig=short_subfig, **xlab_kwargs)

# Plotting helpers:


def _first_match(folder, mode):
    """Return the first file in *folder* for ``target_species`` and *mode*.

    Raises
    ------
    FileNotFoundError
        If search_files() finds nothing, instead of a bare IndexError.
    """
    matches = search_files(target_species, incl=mode, dir=folder)
    if len(matches) == 0:
        raise FileNotFoundError(
            f'No {mode!r} file for {target_species!r} in {folder!r}.')
    return matches[0]


def _plot_condensed(ax, scales, median, spread, limits):
    """Draw one median curve plus percentile bands per relative threshold.

    Parameters
    ----------
    ax : matplotlib Axes
        Target panel.
    scales : array
        x-values; first and last entry also set the x-limits.
    median : array indexed [scale, threshold]
        Kernel median of the recording-averaged feature.
    spread : array indexed [pair, lower/upper, scale, threshold]
        Result of ``np.nanpercentile(..., percentiles, axis=1)``.
    limits : sequence of two numbers
        y-limits, shared by the full and short panel of one mode.
    """
    for i, thresh in enumerate(thresh_rel):
        curve, = ax.plot(scales, median[:, i], label=f'{thresh:g}', **line_kwargs)
        for lower, upper in spread[..., i]:
            # Colour the band after its curve explicitly: matplotlib takes
            # line and fill colours from separate property cycles, which
            # drift apart once several percentile pairs are drawn.
            ax.fill_between(scales, lower, upper, color=curve.get_color(), **fill_kwargs)
    ax.set_xlim(scales[0], scales[-1])
    ax.set_ylim(limits)


# Run through normalization modes (one row of panels each):
for mode, full_ax, short_ax in zip(modes, full_axes, short_axes):

    # Load invariance data:
    full_path = _first_match(full_folder, mode)
    short_path = _first_match(short_folder, mode)
    full_data, config = load_data(full_path, **load_kwargs)
    short_data, _ = load_data(short_path, **load_kwargs)

    # Reduce datasets (same kernel subset for both analyses):
    if reduce_kernels:
        kern_inds = find_kern_specs(config['k_specs'], kernels, types, sigmas)
        full_data = reduce_kernel_set(full_data, kern_inds, **kern_subset_kwargs)
        short_data = reduce_kernel_set(short_data, kern_inds, **kern_subset_kwargs)
        config['k_specs'] = config['k_specs'][kern_inds, :]
        config['kernels'] = config['kernels'][:, kern_inds]
    if exclude_zero:
        full_data = exclude_zero_scale(full_data, **scale_subset_kwargs)
        short_data = exclude_zero_scale(short_data, **scale_subset_kwargs)

    # Average over recordings (last axis):
    full_measure = full_data['mean_feat'].mean(axis=-1)
    short_measure = short_data['mean_feat'].mean(axis=-1)

    # Condense over kernels (axis 1):
    full_median = np.nanmedian(full_measure, axis=1)
    full_spread = np.nanpercentile(full_measure, percentiles, axis=1)
    short_median = np.nanmedian(short_measure, axis=1)
    short_spread = np.nanpercentile(short_measure, percentiles, axis=1)

    # Determine shared ylims: open (None) bounds are derived from both
    # medians via ylimits() so the full and short panels match.
    if None in ylims[mode]:
        min_val, max_val = ylims[mode]
        full_limits = ylimits(full_median, minval=min_val, maxval=max_val)
        short_limits = ylimits(short_median, minval=min_val, maxval=max_val)
        ylims[mode] = [min(full_limits[0], short_limits[0]),
                       max(full_limits[1], short_limits[1])]
    if not np.all(np.isfinite(ylims[mode])):
        # Axes.set_ylim rejects NaN/Inf limits; fail here with context.
        raise ValueError(f'Non-finite y-limits {ylims[mode]} for mode {mode!r}.')

    # Plot full and short analysis results with identical styling:
    _plot_condensed(full_ax, full_data['scales'], full_median, full_spread, ylims[mode])
    _plot_condensed(short_ax, short_data['scales'], short_median, short_spread, ylims[mode])

plt.show()