Polished most of 2 new methods figures.

This commit is contained in:
j-hartling
2026-02-19 16:34:37 +01:00
parent 652621f782
commit 5afd073de9
13 changed files with 195 additions and 223 deletions

View File

@@ -1,10 +1,12 @@
from math import log
import plotstyle_plt
import glob
import numpy as np
import matplotlib.pyplot as plt
from itertools import product
from thunderhopper.modeltools import load_data
from color_functions import load_colors, shade_colors
from color_functions import load_colors
from IPython import embed
def prepare_fig(nrows, ncols, width=8, height=None, rheight=2,
@@ -45,14 +47,19 @@ def ylimits(ax, signal, minval=None, maxval=None, pad=0.05):
limits[1] += (limits[1] - limits[0]) * pad
return ax.set_ylim(limits)
def super_xlabel(label, fig, high_ax, low_ax, **kwargs):
x = (low_ax.get_position().x0 + high_ax.get_position().x1) / 2
fig.supxlabel(label, x=x, **kwargs)
def ylabel(ax, label, x=-0.23, fontsize=20):
ax.set_ylabel(label, fontsize=fontsize, rotation=0, ha='left', va='center')
ax.yaxis.set_label_coords(x, 0.5)
return None
def super_ylabel(label, fig, high_ax, low_ax, **kwargs):
def super_xlabel(label, fig, high_ax, low_ax, y=0.005, **kwargs):
x = (low_ax.get_position().x0 + high_ax.get_position().x1) / 2
fig.supxlabel(label, x=x, y=y, **kwargs)
return None
def super_ylabel(label, fig, high_ax, low_ax, x=0.005, **kwargs):
y = (low_ax.get_position().y0 + high_ax.get_position().y1) / 2
fig.supylabel(label, y=y, **kwargs)
fig.supylabel(label, x=x, y=y, **kwargs)
return None
def hide_axis(ax, side='bottom'):
@@ -63,10 +70,11 @@ def hide_axis(ax, side='bottom'):
return None
def plot_line(ax, time, signal, ymin=None, ymax=None, xmin=None, xmax=None,
xpad=None, ypad=0.05, **kwargs):
xpad=None, ypad=0.05, yloc=None, **kwargs):
handles = ax.plot(time, signal, **kwargs)
xlimits(ax, time, minval=xmin, maxval=xmax, pad=xpad)
ylimits(ax, signal, minval=ymin, maxval=ymax, pad=ypad)
ax.yaxis.set_major_locator(plt.MultipleLocator(yloc))
return handles
def plot_barcode(ax, time, binary, offset=0.5, xmin=None, xmax=None, **kwargs):
@@ -83,7 +91,8 @@ def plot_barcode(ax, time, binary, offset=0.5, xmin=None, xmax=None, **kwargs):
upper += offset + 1
xlimits(ax, time, minval=xmin, maxval=xmax)
ax.set_ylim(0, upper)
ax.axis('off')
hide_axis(ax, 'bottom')
hide_axis(ax, 'left')
return handles
def indicate_zoom(fig, high_ax, low_ax, zoom_abs, **kwargs):
@@ -100,6 +109,21 @@ def indicate_zoom(fig, high_ax, low_ax, zoom_abs, **kwargs):
def assign_colors(handles, types, colors):
for handle, type_id in zip(handles, types):
handle.set_color(colors[str(int(type_id))])
return None
def reorder_traces(handles, signal, zlow=2, zhigh=2.5):
inds = np.argsort(signal.std(axis=0))
zorders = np.linspace(zlow, zhigh, len(inds))[::-1]
for ind, z in zip(inds, zorders):
handles[ind].set_zorder(z)
return None
def letter_subplots(axes, labels='abcd', x=0.02, y=1, ha='left', va='bottom',
fontsize=16, fontweight='bold', **kwargs):
for ax, label in zip(axes, labels):
ax.text(x, y, label, transform=ax.transAxes, ha=ha, va=va,
fontsize=fontsize, fontweight=fontweight, **kwargs)
return None
# GENERAL SETTINGS:
@@ -118,32 +142,57 @@ grid_kwargs = dict(
wspace=0.15,
hspace=0.3,
left=0.1,
right=0.95,
right=0.99,
bottom=0.1,
top=0.95
)
ylabels = dict(
filt=r'$x_{\text{filt}}$',
env=r'$x_{\text{env}}$',
log=r'$x_{\text{dB}}$',
inv=r'$x_{\text{adapt}}$',
conv=r'$c_i$',
bi=r'$b_i$',
feat=r'$f_i$'
)
colors = load_colors('../data/stage_colors.npz')
lw_full = dict(
filt=0.25,
env=0.5,
log=0.5,
inv=0.5,
conv=0.5,
bi=1,
feat=1
conv=0.25,
bi=0,
feat=2
)
lw_zoom = dict(
filt=0.25,
env=0.75,
log=0.75,
inv=0.75,
conv=0.5,
bi=1,
feat=0.75
)
filt=0.5,
env=1,
log=1,
inv=1,
conv=1.5,
bi=0,
feat=2
)
loc_full = dict(
filt=0.2,
env=0.1,
log=20,
inv=10,
conv=1,
feat=1
)
loc_zoom = dict(
filt=0.1,
env=0.02,
log=20,
inv=10,
conv=0.2,
feat=1
)
zoom_rel = np.array([0.3, 0.4])
zoom_kwargs = dict(
color=(0.85, 0.85, 0.85),
color=3 * (0.85,),
zorder=0,
linewidth=0
)
@@ -152,7 +201,7 @@ zoom_kwargs = dict(
# [3, 0.032]
# ])
kern_types = np.array([1, -1, 2, -2, 3, -3, 4, -4])
kern_sigmas = np.array([0.008, 0.016, 0.032])
kern_sigmas = np.array([0.008, 0.032])
kernels = np.array([[k, s] for k in kern_types for s in kern_sigmas])
conv_colors = load_colors('../data/conv_colors.npz')
bi_colors = load_colors('../data/bi_colors.npz')
@@ -180,43 +229,44 @@ for data_path in data_paths:
# PART I: PREPROCESSING STAGE
fig, axes = prepare_fig(4, 2, **fig_kwargs, **grid_kwargs)
super_xlabel('time [s]', fig, axes[0, 0], axes[0, -1])
super_ylabel('amplitude', fig, axes[0, 0], axes[-1, 0])
# Bandpass-filtered signal:
ax_full, ax_zoom = axes[0, :]
plot_line(ax_full, t_full, data['filt'], c=colors['filt'], lw=lw_full['filt'])
plot_line(ax_zoom, t_zoom, data['filt'][zoom_mask], c=colors['filt'], lw=lw_zoom['filt'])
ylabel(ax_full, ylabels['filt'])
plot_line(ax_full, t_full, data['filt'], c=colors['filt'], lw=lw_full['filt'], yloc=loc_full['filt'])
plot_line(ax_zoom, t_zoom, data['filt'][zoom_mask], c=colors['filt'], lw=lw_zoom['filt'], yloc=loc_zoom['filt'])
hide_axis(ax_full, 'bottom')
hide_axis(ax_zoom, 'bottom')
hide_axis(ax_zoom, 'left')
# Signal envelope:
ax_full, ax_zoom = axes[1, :]
plot_line(ax_full, t_full, data['env'], ymin=0, c=colors['env'], lw=lw_full['env'])
plot_line(ax_zoom, t_zoom, data['env'][zoom_mask], ymin=0, c=colors['env'], lw=lw_zoom['env'])
ylabel(ax_full, ylabels['env'])
plot_line(ax_full, t_full, data['env'], ymin=0, c=colors['env'], lw=lw_full['env'], yloc=loc_full['env'])
plot_line(ax_zoom, t_zoom, data['env'][zoom_mask], ymin=0, c=colors['env'], lw=lw_zoom['env'], yloc=loc_zoom['env'])
hide_axis(ax_full, 'bottom')
hide_axis(ax_zoom, 'bottom')
hide_axis(ax_zoom, 'left')
# Logarithmic envelope:
ax_full, ax_zoom = axes[2, :]
plot_line(ax_full, t_full, data['log'], ymax=0, c=colors['log'], lw=lw_full['log'])
plot_line(ax_zoom, t_zoom, data['log'][zoom_mask], ymax=0, c=colors['log'], lw=lw_zoom['log'])
ylabel(ax_full, ylabels['log'])
plot_line(ax_full, t_full, data['log'], ymax=0, c=colors['log'], lw=lw_full['log'], yloc=loc_full['log'])
plot_line(ax_zoom, t_zoom, data['log'][zoom_mask], ymax=0, c=colors['log'], lw=lw_zoom['log'], yloc=loc_zoom['log'])
hide_axis(ax_full, 'bottom')
hide_axis(ax_zoom, 'bottom')
hide_axis(ax_zoom, 'left')
# Adapted envelope:
ax_full, ax_zoom = axes[3, :]
plot_line(ax_full, t_full, data['inv'], c=colors['inv'], lw=lw_full['inv'])
plot_line(ax_zoom, t_zoom, data['inv'][zoom_mask], c=colors['inv'], lw=lw_zoom['inv'])
hide_axis(ax_zoom, 'left')
ylabel(ax_full, ylabels['inv'])
plot_line(ax_full, t_full, data['inv'], c=colors['inv'], lw=lw_full['inv'], yloc=loc_full['inv'])
plot_line(ax_zoom, t_zoom, data['inv'][zoom_mask], c=colors['inv'], lw=lw_zoom['inv'], yloc=loc_zoom['inv'])
# Posthoc adjustments:
ax_full.set_xlim(t_full[0], t_full[-1])
ax_zoom.set_xlim(t_zoom[0], t_zoom[-1])
indicate_zoom(fig, axes[0, 0], axes[-1, 0], zoom_abs, **zoom_kwargs)
indicate_zoom(fig, axes[0, 1], axes[-1, 1], zoom_abs, **zoom_kwargs)
letter_subplots(axes[:, 0])
# fig.align_ylabels(axes[:, 0])
if save_name is not None:
fig.savefig(f'{save_name}_pre.pdf')
@@ -224,41 +274,45 @@ for data_path in data_paths:
# PART II: FEATURE EXTRACTION STAGE:
fig, axes = prepare_fig(3, 2, **fig_kwargs, **grid_kwargs)
super_xlabel('time [s]', fig, axes[0, 0], axes[0, -1])
super_ylabel('amplitude', fig, axes[0, 0], axes[-1, 0])
# Convolutional filter responses:
ax_full, ax_zoom = axes[0, :]
ylabel(ax_full, ylabels['conv'])
signal = data['conv'][:, kern_inds]
handles = plot_line(ax_full, t_full, signal, lw=lw_full['conv'])
handles = plot_line(ax_full, t_full, signal, lw=lw_full['conv'], yloc=loc_full['conv'])
assign_colors(handles, kernel_specs[:, 0], conv_colors)
handles = plot_line(ax_zoom, t_zoom, signal[zoom_mask, :], lw=lw_zoom['conv'])
reorder_traces(handles, signal)
handles = plot_line(ax_zoom, t_zoom, signal[zoom_mask, :], lw=lw_zoom['conv'], yloc=loc_zoom['conv'])
assign_colors(handles, kernel_specs[:, 0], conv_colors)
reorder_traces(handles, signal[zoom_mask, :])
hide_axis(ax_full, 'bottom')
hide_axis(ax_zoom, 'bottom')
hide_axis(ax_zoom, 'left')
# Binary responses:
ax_full, ax_zoom = axes[1, :]
ylabel(ax_full, ylabels['bi'])
signal = data['bi'][:, kern_inds]
handles = plot_barcode(ax_full, t_full, signal)
handles = plot_barcode(ax_full, t_full, signal, lw=lw_full['bi'])
assign_colors(handles, kernel_specs[:, 0], bi_colors)
handles = plot_barcode(ax_zoom, t_zoom, signal[zoom_mask, :])
handles = plot_barcode(ax_zoom, t_zoom, signal[zoom_mask, :], lw=lw_zoom['bi'])
assign_colors(handles, kernel_specs[:, 0], bi_colors)
# Finalized features:
ax_full, ax_zoom = axes[2, :]
ylabel(ax_full, ylabels['feat'])
signal = data['feat'][:, kern_inds]
handles = plot_line(ax_full, t_full, signal, ymin=0, ymax=1, c=colors['feat'], lw=lw_full['feat'])
handles = plot_line(ax_full, t_full, signal, ymin=0, ymax=1, c=colors['feat'], lw=lw_full['feat'], yloc=loc_full['feat'])
assign_colors(handles, kernel_specs[:, 0], feat_colors)
handles = plot_line(ax_zoom, t_zoom, signal[zoom_mask, :], ymin=0, ymax=1, c=colors['feat'], lw=lw_zoom['feat'])
handles = plot_line(ax_zoom, t_zoom, signal[zoom_mask, :], ymin=0, ymax=1, c=colors['feat'], lw=lw_zoom['feat'], yloc=loc_zoom['feat'])
assign_colors(handles, kernel_specs[:, 0], feat_colors)
hide_axis(ax_zoom, 'left')
# Posthoc adjustments:
ax_full.set_xlim(t_full[0], t_full[-1])
ax_zoom.set_xlim(t_zoom[0], t_zoom[-1])
indicate_zoom(fig, axes[0, 0], axes[-1, 0], zoom_abs, **zoom_kwargs)
indicate_zoom(fig, axes[0, 1], axes[-1, 1], zoom_abs, **zoom_kwargs)
letter_subplots(axes[:, 0])
# fig.align_ylabels(axes[:, 0])
if save_name is not None:
fig.savefig(f'{save_name}_feat.pdf')
plt.show()

View File

@@ -5,9 +5,8 @@ mpl.rcParams['text.usetex'] = False
# Font style:
mpl.rcParams['font.style'] = 'normal'
mpl.rcParams['font.family'] = 'serif'
mpl.rcParams['font.family'] = 'sans-serif'
mpl.rcParams['mathtext.fontset'] = 'cm'
mpl.rcParams['mathtext.default'] = 'regular'
# Font sizes:
mpl.rcParams['font.size'] = 14
@@ -41,7 +40,7 @@ mpl.rcParams['axes.spines.right'] = False
mpl.rcParams['axes.xmargin'] = 0
mpl.rcParams['axes.ymargin'] = 0
mpl.rcParams['axes.autolimit_mode'] = 'round_numbers'
mpl.rcParams['axes.labelpad'] = 3
mpl.rcParams['axes.labelpad'] = 5
# Major tick parameters:
mpl.rcParams['xtick.major.size'] = 5

View File

@@ -4,7 +4,7 @@ from color_functions import load_colors, shade_colors
# Settings:
stages = ['conv', 'bi', 'feat']
kern_types = np.array([1, -1, 2, -2, 3, -3, 4, -4])
shade_factors = np.linspace(-0.5, 0.5, kern_types.size)
shade_factors = np.linspace(-0.6, 0.2, kern_types.size)
# Main colors:
stage_colors = load_colors('../data/stage_colors.npz')