Built color selection infrastructure and assigned stage-specific colors (WIP).

Added 2 plots to methods (WIP).
This commit is contained in:
j-hartling
2026-02-17 16:46:02 +01:00
parent d78dcf4f4a
commit 49b2bdcfca
16 changed files with 631 additions and 173 deletions

View File

@@ -1,34 +1,123 @@
import glob
import plotstyle_plt
import glob
import numpy as np
import matplotlib.pyplot as plt
from itertools import product
from thunderhopper.modeltools import load_data
from IPython import embed
def prepare_fig(nrows, ncols, width=8, height=None, rheight=2,
left=0.01, right=0.95, bottom=0.01, top=0.95,
wspace=0.4, hspace=0.4):
if height is None:
height = rheight * nrows
fig = plt.figure(figsize=(width, height))
grid = fig.add_gridspec(nrows=nrows, ncols=ncols, wspace=wspace, hspace=hspace,
left=left, right=right, top=top, bottom=bottom)
axes = np.zeros((nrows, ncols), dtype=object)
for i, j in product(range(nrows), range(ncols)):
axes[i, j] = fig.add_subplot(grid[i, j])
axes[i, j].set_facecolor('none')
return fig, axes
def xlimits(ax, time, minval=None, maxval=None, pad=0.05):
limits = [minval, maxval]
if minval is None:
limits[0] = time[0]
if maxval is None:
limits[1] = time[-1]
if pad is not None and minval is None:
limits[0] -= (limits[1] - limits[0]) * pad
if pad is not None and maxval is None:
limits[1] += (limits[1] - limits[0]) * pad
return ax.set_xlim(limits)
def ylimits(ax, signal, minval=None, maxval=None, pad=0.05):
limits = [minval, maxval]
if minval is None:
limits[0] = signal.min()
if maxval is None:
limits[1] = signal.max()
if pad is not None and minval is None:
limits[0] -= (limits[1] - limits[0]) * pad
if pad is not None and maxval is None:
limits[1] += (limits[1] - limits[0]) * pad
return ax.set_ylim(limits)
def super_xlabel(label, fig, high_ax, low_ax, **kwargs):
x = (low_ax.get_position().x0 + high_ax.get_position().x1) / 2
fig.supxlabel(label, x=x, **kwargs)
return None
def super_ylabel(label, fig, high_ax, low_ax, **kwargs):
y = (low_ax.get_position().y0 + high_ax.get_position().y1) / 2
fig.supylabel(label, y=y, **kwargs)
return None
def hide_axis(ax, side='bottom'):
ax.spines[side].set_visible(False)
params = {side: False, 'label' + side: False}
ax.tick_params(axis='x' if side in ['top', 'bottom'] else 'y',
which='both', **params)
return None
def plot_line(ax, time, signal, ymin=None, ymax=None, xmin=None, xmax=None,
xpad=None, ypad=0.05, **kwargs):
ax.plot(time, signal, **kwargs)
xlimits(ax, time, minval=xmin, maxval=xmax, pad=xpad)
ylimits(ax, signal, minval=ymin, maxval=ymax, pad=ypad)
return None
def plot_barcode(ax, time, binary, offset=0.1, xmin=None, xmax=None, **kwargs):
if xmin is None:
xmin = time[0]
if xmax is None:
xmax = time[-1]
lower, upper = 0, 1
for i in range(binary.shape[1]):
ax.fill_between(time, lower, upper, where=binary[:, i], **kwargs)
if i < binary.shape[1] - 1:
lower += offset + 1
upper += offset + 1
xlimits(ax, time, minval=xmin, maxval=xmax)
ax.set_ylim(0, upper)
ax.axis('off')
return None
def indicate_zoom(fig, high_ax, low_ax, zoom_abs, **kwargs):
y0 = low_ax.get_position().y0
y1 = high_ax.get_position().y1
transform = low_ax.transData + fig.transFigure.inverted()
x0 = transform.transform((zoom_abs[0], 0))[0]
x1 = transform.transform((zoom_abs[1], 0))[0]
rect = plt.Rectangle((x0, y0), x1 - x0, y1 - y0,
transform=fig.transFigure, **kwargs)
fig.add_artist(rect)
return None
# GENERAL SETTINGS:
data_paths = glob.glob('../data/processed/*.npz')
stages = ['filt', 'env', 'log', 'inv',
'conv', 'bi', 'feat']
channel = 0
target = 'Omocestus_rufipes'
data_paths = glob.glob(f'../data/processed/{target}*.npz')
stages = ['filt', 'env', 'log', 'inv', 'conv', 'bi', 'feat']
save_name = '../figures/pathway_stages'
# PLOT SETTINGS:
fig_kwargs = dict(
figsize = np.array([16, 9]) * 0.75,
layout = 'constrained',
sharex = 'col',
sharey = 'row'
width=16 / 2.54 * 2,
height=6 / 2.54 * 2,
rheight=2 / 2.54 * 2,
)
zoom_rel = np.array([0.4, 0.6])
colors = dict(
filt='k',
env='k',
log='k',
inv='k',
conv='k',
bi='k',
feat='k'
grid_kwargs = dict(
wspace=0.15,
hspace=0.3,
left=0.1,
right=0.95,
bottom=0.1,
top=0.95
)
linewidths = dict(
colors = {s: c.item() for s, c in dict(np.load('../data/stage_colors.npz')).items()}
lw_full = dict(
filt=0.25,
env=0.5,
log=0.5,
@@ -37,11 +126,24 @@ linewidths = dict(
bi=1,
feat=1
)
lw_zoom = dict(
filt=0.25,
env=0.75,
log=0.75,
inv=0.75,
conv=0.5,
bi=1,
feat=0.75
)
zoom_rel = np.array([0.3, 0.4])
zoom_kwargs = dict(
color=(0.85, 0.85, 0.85),
zorder=0,
linewidth=0
)
# EXECUTION:
for data_path in data_paths:
if 'Gomphocerippus' in data_path:
continue
print(f'Processing {data_path}')
# Load overall data:
@@ -55,79 +157,80 @@ for data_path in data_paths:
# PART I: PREPROCESSING STAGE
fig, axes = plt.subplots(4, 2, **fig_kwargs)
fig.supylabel('amplitude', fontsize=plt.rcParams['axes.labelsize'])
fig.supxlabel('time [s]', fontsize=plt.rcParams['axes.labelsize'])
fig, axes = prepare_fig(4, 2, **fig_kwargs, **grid_kwargs)
super_xlabel('time [s]', fig, axes[0, 0], axes[0, -1])
super_ylabel('amplitude', fig, axes[0, 0], axes[-1, 0])
# Bandpass-filtered signal:
signal = data['filt'][:, channel]
ax_full, ax_zoom = axes[0, :]
c, lw = colors['filt'], linewidths['filt']
ax_full.plot(t_full, signal, c=c, lw=lw)
ax_zoom.plot(t_zoom, signal[zoom_mask], c=c, lw=lw)
ax_full.set_ylim(signal.min(), signal.max())
plot_line(ax_full, t_full, data['filt'], c=colors['filt'], lw=lw_full['filt'])
plot_line(ax_zoom, t_zoom, data['filt'][zoom_mask], c=colors['filt'], lw=lw_zoom['filt'])
hide_axis(ax_full, 'bottom')
hide_axis(ax_zoom, 'bottom')
hide_axis(ax_zoom, 'left')
# Signal envelope:
signal = data['env'][:, channel]
ax_full, ax_zoom = axes[1, :]
c, lw = colors['env'], linewidths['env']
ax_full.plot(t_full, signal, c=c, lw=lw)
ax_zoom.plot(t_zoom, signal[zoom_mask], c=c, lw=lw)
ax_full.set_ylim(0, signal.max())
plot_line(ax_full, t_full, data['env'], ymin=0, c=colors['env'], lw=lw_full['env'])
plot_line(ax_zoom, t_zoom, data['env'][zoom_mask], ymin=0, c=colors['env'], lw=lw_zoom['env'])
hide_axis(ax_full, 'bottom')
hide_axis(ax_zoom, 'bottom')
hide_axis(ax_zoom, 'left')
# Logarithmic envelope:
signal = data['log'][:, channel]
ax_full, ax_zoom = axes[2, :]
c, lw = colors['log'], linewidths['log']
ax_full.plot(t_full, signal, c=c, lw=lw)
ax_zoom.plot(t_zoom, signal[zoom_mask], c=c, lw=lw)
ax_full.set_ylim(signal.min(), 0)
plot_line(ax_full, t_full, data['log'], ymax=0, c=colors['log'], lw=lw_full['log'])
plot_line(ax_zoom, t_zoom, data['log'][zoom_mask], ymax=0, c=colors['log'], lw=lw_zoom['log'])
hide_axis(ax_full, 'bottom')
hide_axis(ax_zoom, 'bottom')
hide_axis(ax_zoom, 'left')
# Adapted envelope:
signal = data['inv'][:, channel]
ax_full, ax_zoom = axes[3, :]
c, lw = colors['inv'], linewidths['inv']
ax_full.plot(t_full, signal, c=c, lw=lw)
ax_zoom.plot(t_zoom, signal[zoom_mask], c=c, lw=lw)
ax_full.set_ylim(signal.min(), signal.max())
plot_line(ax_full, t_full, data['inv'], c=colors['inv'], lw=lw_full['inv'])
plot_line(ax_zoom, t_zoom, data['inv'][zoom_mask], c=colors['inv'], lw=lw_zoom['inv'])
hide_axis(ax_zoom, 'left')
# Posthoc adjustments:
ax_full.set_xlim(t_full[0], t_full[-1])
ax_zoom.set_xlim(t_zoom[0], t_zoom[-1])
indicate_zoom(fig, axes[0, 0], axes[-1, 0], zoom_abs, **zoom_kwargs)
indicate_zoom(fig, axes[0, 1], axes[-1, 1], zoom_abs, **zoom_kwargs)
if save_name is not None:
fig.savefig(f'{save_name}_pre.pdf')
# PART II: FEATURE EXTRACTION STAGE:
fig, axes = plt.subplots(3, 2, **fig_kwargs)
fig.supylabel('amplitude', fontsize=plt.rcParams['axes.labelsize'])
fig.supxlabel('time [s]', fontsize=plt.rcParams['axes.labelsize'])
fig, axes = prepare_fig(3, 2, **fig_kwargs, **grid_kwargs)
super_xlabel('time [s]', fig, axes[0, 0], axes[0, -1])
super_ylabel('amplitude', fig, axes[0, 0], axes[-1, 0])
# Convolutional filter responses:
signal = data['conv'][:, :, channel]
ax_full, ax_zoom = axes[0, :]
c, lw = colors['conv'], linewidths['conv']
ax_full.plot(t_full, signal, c=c, lw=lw)
ax_zoom.plot(t_zoom, signal[zoom_mask, :], c=c, lw=lw)
ax_full.set_ylim(signal.min(), signal.max())
plot_line(ax_full, t_full, data['conv'], c=colors['conv'], lw=lw_full['conv'])
plot_line(ax_zoom, t_zoom, data['conv'][zoom_mask, :], c=colors['conv'], lw=lw_zoom['conv'])
hide_axis(ax_full, 'bottom')
hide_axis(ax_zoom, 'bottom')
hide_axis(ax_zoom, 'left')
# Binary responses:
signal = data['bi'][:, :, channel]
ax_full, ax_zoom = axes[1, :]
c, lw = colors['bi'], linewidths['bi']
ax_full.plot(t_full, signal, c=c, lw=lw)
ax_zoom.plot(t_zoom, signal[zoom_mask, :], c=c, lw=lw)
ax_full.set_ylim(signal.min(), signal.max())
plot_barcode(ax_full, t_full, data['bi'], color=colors['bi'])
plot_barcode(ax_zoom, t_zoom, data['bi'][zoom_mask, :], color=colors['bi'])
# Finalized features:
signal = data['feat'][:, :, channel]
ax_full, ax_zoom = axes[2, :]
c, lw = colors['feat'], linewidths['feat']
ax_full.plot(t_full, signal, c=c, lw=lw)
ax_zoom.plot(t_zoom, signal[zoom_mask, :], c=c, lw=lw)
ax_full.set_ylim(0, 1)
plot_line(ax_full, t_full, data['feat'], ymin=0, ymax=1, c=colors['feat'], lw=lw_full['feat'])
plot_line(ax_zoom, t_zoom, data['feat'][zoom_mask, :], ymin=0, ymax=1, c=colors['feat'], lw=lw_zoom['feat'])
hide_axis(ax_zoom, 'left')
# Posthoc adjustments:
ax_full.set_xlim(t_full[0], t_full[-1])
ax_zoom.set_xlim(t_zoom[0], t_zoom[-1])
indicate_zoom(fig, axes[0, 0], axes[-1, 0], zoom_abs, **zoom_kwargs)
indicate_zoom(fig, axes[0, 1], axes[-1, 1], zoom_abs, **zoom_kwargs)
if save_name is not None:
fig.savefig(f'{save_name}_feat.pdf')
plt.show()

View File

@@ -8,15 +8,18 @@ mpl.rcParams['font.style'] = 'normal'
mpl.rcParams['font.family'] = 'serif'
mpl.rcParams['mathtext.fontset'] = 'cm'
mpl.rcParams['mathtext.default'] = 'regular'
# Font sizes:
mpl.rcParams['font.size'] = 14
mpl.rcParams['figure.titlesize'] = 15
mpl.rcParams['figure.labelsize'] = 14
mpl.rcParams['axes.labelsize'] = 14
mpl.rcParams['axes.titlesize'] = 14
mpl.rcParams['xtick.labelsize'] = 13
mpl.rcParams['ytick.labelsize'] = 13
mpl.rcParams['legend.fontsize'] = 14
mpl.rcParams['legend.title_fontsize'] = 14
# Font weights:
single_weight = ['normal', 'bold'][0]
mpl.rcParams['font.weight'] = single_weight
@@ -32,6 +35,7 @@ mpl.rcParams['figure.raise_window'] = True
# mpl.rcParams['savefig.dpi'] = 500
# Axes parameters:
mpl.rcParams['axes.linewidth'] = 1.5
mpl.rcParams['axes.spines.top'] = False
mpl.rcParams['axes.spines.right'] = False
mpl.rcParams['axes.xmargin'] = 0
@@ -39,10 +43,20 @@ mpl.rcParams['axes.ymargin'] = 0
mpl.rcParams['axes.autolimit_mode'] = 'round_numbers'
mpl.rcParams['axes.labelpad'] = 3
# Tick parameters:
mpl.rcParams['xtick.major.pad'] = 5
mpl.rcParams['ytick.major.pad'] = 5
# Major tick parameters:
mpl.rcParams['xtick.major.size'] = 5
mpl.rcParams['xtick.major.width'] = 1.5
mpl.rcParams['xtick.major.pad'] = 3.5
mpl.rcParams['ytick.major.size'] = 5
mpl.rcParams['ytick.major.width'] = 1.5
mpl.rcParams['ytick.major.pad'] = 3.5
# Minor tick parameters:
mpl.rcParams['xtick.minor.size'] = 2
mpl.rcParams['xtick.minor.width'] = 1
mpl.rcParams['ytick.minor.size'] = 2
mpl.rcParams['ytick.minor.width'] = 1
# Legend parameters:
mpl.rcParams['legend.frameon'] = False
mpl.rcParams['legend.scatterpoints'] = 3
mpl.rcParams['legend.scatterpoints'] = 3

View File

@@ -0,0 +1,61 @@
import numpy as np
from thunderhopper.filetools import search_files, crop_paths
from thunderhopper.model import configuration, process_signal
from thunderhopper.modeltools import load_data
from IPython import embed
## SETTINGS:
# General:
overwrite = True
input_folder = '../data/raw/'
output_folder = '../data/processed/'
stages = ['raw', 'filt', 'env', 'log', 'inv', 'conv', 'bi', 'feat', 'norm']
if True:
# Overwrites edited:
stages.append('songs')
# Interactivity:
reload_saved = False
gui = False
# Processing:
env_rate = 44100.0
feat_rate = 44100.0
sigmas = [0.001, 0.002, 0.004, 0.008, 0.016, 0.032]
types = [1, -1, 2, -2, 3, -3, 4, -4, 5, -5,
6, -6, 7, -7, 8, -8, 9, -9, 10, -10]
config = configuration(env_rate, feat_rate, types=types, sigmas=sigmas)
config.update({
'channel': 0,
'rate_ratio': None,
'env_fcut': 250,
'inv_fcut': 5,
'feat_thresh': np.load('../data/kernel_thresholds.npy') * 0.1,
'feat_fcut': 0.5,
'label_channels': 0,
'label_thresh': 0.5,
})
## PREPARATION:
# Fetch WAV recording files:
input_paths = search_files(ext='wav', dir=input_folder)
path_names = crop_paths(input_paths)
# PROCESSING:
# Run processing pipeline:
for path, name in zip(input_paths, path_names):
print('Processing:', name)
# Fetch and store representations:
save = None if output_folder is None else output_folder + f'{name}.npz'
process_signal(config, stages, path, save=save,
label_edit=gui, overwrite=overwrite)
# Cross-control:
if reload_saved:
data, params = load_data(save, stages, ['songs', 'noise'])
embed()
print('Done.')

195
python/save_stage_colors.py Normal file
View File

@@ -0,0 +1,195 @@
import numpy as np
import matplotlib.pyplot as plt
from tkinter.colorchooser import askcolor
from IPython import embed
def is_hex_color(color):
if isinstance(color, str) and color.startswith('#') and len(color) == 7:
try:
int(color[1:], 16)
return True
except ValueError:
return False
return False
def is_rgb_color(color):
if isinstance(color, (tuple, list)) and len(color) == 3:
return all(isinstance(c, int) and 0 <= c <= 255 for c in color)
return False
def hex_to_rgb(color):
color = color.lstrip('#')
return tuple(int(color[i:i+2], 16) for i in (0, 2, 4))
def rgb_to_hex(color):
return '#{:02x}{:02x}{:02x}'.format(*color)
def whiten_color(color, factors):
is_hex = is_hex_color(color)
if is_hex:
color = hex_to_rgb(color)
elif not is_rgb_color(color):
raise ValueError('Color format must be hex string or RGB tuple.')
whitened = tuple(min(255, int(c + (255 - c) * f)) for c, f in zip(color, factors))
return rgb_to_hex(whitened) if is_hex else whitened
def color_selector(n=5, colors=None, save=None, labels=None, hex=True,
back_color='w', edge_color='k', edge_width=2):
def pick_color(color='#ffffff', hex=True):
color = askcolor(initialcolor=color)[hex]
return color
def update_color(artists, color):
for i, artist in enumerate(artists):
if isinstance(artist, plt.Rectangle):
# Update patch colors:
artist.set_fc(color)
if i in [0, 1]:
artist.set_ec(color)
elif isinstance(artist, plt.Line2D):
# Update marker colors:
artist.set_mfc(color)
if i in [0, 1]:
artist.set_mec(color)
fig.canvas.draw()
return None
# Prepare colors:
if colors is None:
colors = ['#ffffff' for _ in range(n)]
start_colors = colors.copy()
n = len(colors)
plt.ion()
# Prepare graphical interface:
fig, ax = plt.subplots(figsize=(10, 2), layout='constrained')
ax.set_facecolor(back_color)
ax.set_xlim(0, n)
ax.set_ylim(0, 4)
ax.axis('off')
# Add different color indicators:
artist_groups, group_inds = {}, {}
for i, color in enumerate(colors):
# Color on background, uniform edge:
p1 = ax.add_patch(plt.Rectangle((i, 0), 1, 1, fc=color, ec=color))
l1 = ax.plot(i + 0.5, 1.5, marker='o', ms=20, mfc=color, mec=color)[0]
# Color on background, black edge:
p2 = ax.add_patch(plt.Rectangle((i, 3), 1, 1, fc=color, ec=edge_color, lw=edge_width))
l2 = ax.plot(i + 0.5, 2.5, marker='o', ms=20, mfc=color, mec=edge_color, mew=edge_width)[0]
# Update artist mappings:
handles = [p1, l1, l2, p2]
artist_groups[i] = handles
for handle in handles:
handle.set_picker(True)
group_inds[handle] = i
# Initialize memory variables for avoiding assignments:
swap_groups, swap_colors = [None, None], [None, None]
focus_group, focus_color = [None], [None, None]
# Interactivity:
def on_key(event):
# Abort and retry:
if event.key == 'q':
print('\nRestarting with original settings...')
plt.close(fig)
return color_selector(n, start_colors, save, labels, hex,
back_color, edge_color, edge_width)
# Exit without saving:
elif event.key == 'escape':
print('\nExiting without saving...')
plt.close(fig)
return None
# Accept and save colors to file:
elif event.key == 'enter' and save is not None:
if labels is not None:
# Convert to file dictionary and write to npz file:
data = {str(label): col for label, col in zip(labels, colors)}
np.savez(f'{save}.npz', **data)
else:
# Write array (n, 3) or (n,) to npy file:
np.save(f'{save}.npy', np.array(colors))
print(f'\nSaved {n} colors to file: {save}')
plt.close(fig)
return None
def on_pick(event):
# Pick color for chosen group:
if event.mouseevent.button == 1:
# Update memory variables:
focus_group[0] = group_inds[event.artist]
focus_color[0] = colors[focus_group[0]]
# Trigger color picker dialogue:
focus_color[1] = pick_color(focus_color[0], hex)
if focus_color[1] is not None:
# Apply, update, and report only if dialogue was not cancelled:
update_color(artist_groups[focus_group[0]], focus_color[1])
colors[focus_group[0]] = focus_color[1]
print(f'\nUpdated color {focus_group[0] + 1} / {n}: {focus_color[1]}')
# Select 1st color to swap:
elif event.mouseevent.button == 3:
# Update memory variables and report:
swap_groups[0] = group_inds[event.artist]
swap_colors[0] = colors[swap_groups[0]]
print(f'\nSelected 1st swap color: {swap_groups[0] + 1} / {n}')
# Select 2nd color to swap and execute:
elif swap_groups[0] is not None and event.mouseevent.button == 2:
# Update memory variables and report:
swap_groups[1] = group_inds[event.artist]
swap_colors[1] = colors[swap_groups[1]]
print(f'Selected 2nd swap color: {swap_groups[1] + 1} / {n}')
print(f'Swapping colors: {swap_groups[0] + 1} <-> {swap_groups[1] + 1}')
# Swap colors:
for i in [0, 1]:
# Apply to artist group and update global color list:
update_color(artist_groups[swap_groups[i]], swap_colors[1 - i])
colors[swap_groups[i]] = swap_colors[1 - i]
# Reset memory variables:
swap_groups[0], swap_groups[1] = None, None
swap_colors[0], swap_colors[1] = None, None
# Establish interactivity:
plt.connect('key_press_event', on_key)
plt.connect('pick_event', on_pick)
plt.ioff()
return colors
# Settings:
stages = ['filt', 'env', 'log', 'inv', 'conv', 'bi', 'feat']
file_name = None#'../data/stage_colors'
colors = None
if True:
colors = dict(np.load('../data/stage_colors.npz'))
colors = [colors[stage].item() for stage in colors.keys()]
# Execution:
colors = color_selector(len(stages), colors, save=file_name, labels=stages, hex=True)
plt.show()
embed()