Built color selection infrastructure and assigned stage-specific colors (WIP).
Added 2 plots to methods (WIP).
This commit is contained in:
@@ -1,34 +1,123 @@
|
||||
import glob
|
||||
import plotstyle_plt
|
||||
import glob
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
from itertools import product
|
||||
from thunderhopper.modeltools import load_data
|
||||
from IPython import embed
|
||||
|
||||
def prepare_fig(nrows, ncols, width=8, height=None, rheight=2,
|
||||
left=0.01, right=0.95, bottom=0.01, top=0.95,
|
||||
wspace=0.4, hspace=0.4):
|
||||
if height is None:
|
||||
height = rheight * nrows
|
||||
fig = plt.figure(figsize=(width, height))
|
||||
grid = fig.add_gridspec(nrows=nrows, ncols=ncols, wspace=wspace, hspace=hspace,
|
||||
left=left, right=right, top=top, bottom=bottom)
|
||||
axes = np.zeros((nrows, ncols), dtype=object)
|
||||
for i, j in product(range(nrows), range(ncols)):
|
||||
axes[i, j] = fig.add_subplot(grid[i, j])
|
||||
axes[i, j].set_facecolor('none')
|
||||
return fig, axes
|
||||
|
||||
def xlimits(ax, time, minval=None, maxval=None, pad=0.05):
|
||||
limits = [minval, maxval]
|
||||
if minval is None:
|
||||
limits[0] = time[0]
|
||||
if maxval is None:
|
||||
limits[1] = time[-1]
|
||||
if pad is not None and minval is None:
|
||||
limits[0] -= (limits[1] - limits[0]) * pad
|
||||
if pad is not None and maxval is None:
|
||||
limits[1] += (limits[1] - limits[0]) * pad
|
||||
return ax.set_xlim(limits)
|
||||
|
||||
def ylimits(ax, signal, minval=None, maxval=None, pad=0.05):
|
||||
limits = [minval, maxval]
|
||||
if minval is None:
|
||||
limits[0] = signal.min()
|
||||
if maxval is None:
|
||||
limits[1] = signal.max()
|
||||
if pad is not None and minval is None:
|
||||
limits[0] -= (limits[1] - limits[0]) * pad
|
||||
if pad is not None and maxval is None:
|
||||
limits[1] += (limits[1] - limits[0]) * pad
|
||||
return ax.set_ylim(limits)
|
||||
|
||||
def super_xlabel(label, fig, high_ax, low_ax, **kwargs):
|
||||
x = (low_ax.get_position().x0 + high_ax.get_position().x1) / 2
|
||||
fig.supxlabel(label, x=x, **kwargs)
|
||||
return None
|
||||
|
||||
def super_ylabel(label, fig, high_ax, low_ax, **kwargs):
|
||||
y = (low_ax.get_position().y0 + high_ax.get_position().y1) / 2
|
||||
fig.supylabel(label, y=y, **kwargs)
|
||||
return None
|
||||
|
||||
def hide_axis(ax, side='bottom'):
|
||||
ax.spines[side].set_visible(False)
|
||||
params = {side: False, 'label' + side: False}
|
||||
ax.tick_params(axis='x' if side in ['top', 'bottom'] else 'y',
|
||||
which='both', **params)
|
||||
return None
|
||||
|
||||
def plot_line(ax, time, signal, ymin=None, ymax=None, xmin=None, xmax=None,
|
||||
xpad=None, ypad=0.05, **kwargs):
|
||||
ax.plot(time, signal, **kwargs)
|
||||
xlimits(ax, time, minval=xmin, maxval=xmax, pad=xpad)
|
||||
ylimits(ax, signal, minval=ymin, maxval=ymax, pad=ypad)
|
||||
return None
|
||||
|
||||
def plot_barcode(ax, time, binary, offset=0.1, xmin=None, xmax=None, **kwargs):
|
||||
if xmin is None:
|
||||
xmin = time[0]
|
||||
if xmax is None:
|
||||
xmax = time[-1]
|
||||
lower, upper = 0, 1
|
||||
for i in range(binary.shape[1]):
|
||||
ax.fill_between(time, lower, upper, where=binary[:, i], **kwargs)
|
||||
if i < binary.shape[1] - 1:
|
||||
lower += offset + 1
|
||||
upper += offset + 1
|
||||
xlimits(ax, time, minval=xmin, maxval=xmax)
|
||||
ax.set_ylim(0, upper)
|
||||
ax.axis('off')
|
||||
return None
|
||||
|
||||
def indicate_zoom(fig, high_ax, low_ax, zoom_abs, **kwargs):
|
||||
y0 = low_ax.get_position().y0
|
||||
y1 = high_ax.get_position().y1
|
||||
transform = low_ax.transData + fig.transFigure.inverted()
|
||||
x0 = transform.transform((zoom_abs[0], 0))[0]
|
||||
x1 = transform.transform((zoom_abs[1], 0))[0]
|
||||
rect = plt.Rectangle((x0, y0), x1 - x0, y1 - y0,
|
||||
transform=fig.transFigure, **kwargs)
|
||||
fig.add_artist(rect)
|
||||
return None
|
||||
|
||||
|
||||
# GENERAL SETTINGS:
|
||||
data_paths = glob.glob('../data/processed/*.npz')
|
||||
stages = ['filt', 'env', 'log', 'inv',
|
||||
'conv', 'bi', 'feat']
|
||||
channel = 0
|
||||
target = 'Omocestus_rufipes'
|
||||
data_paths = glob.glob(f'../data/processed/{target}*.npz')
|
||||
stages = ['filt', 'env', 'log', 'inv', 'conv', 'bi', 'feat']
|
||||
save_name = '../figures/pathway_stages'
|
||||
|
||||
# PLOT SETTINGS:
|
||||
fig_kwargs = dict(
|
||||
figsize = np.array([16, 9]) * 0.75,
|
||||
layout = 'constrained',
|
||||
sharex = 'col',
|
||||
sharey = 'row'
|
||||
width=16 / 2.54 * 2,
|
||||
height=6 / 2.54 * 2,
|
||||
rheight=2 / 2.54 * 2,
|
||||
)
|
||||
zoom_rel = np.array([0.4, 0.6])
|
||||
colors = dict(
|
||||
filt='k',
|
||||
env='k',
|
||||
log='k',
|
||||
inv='k',
|
||||
conv='k',
|
||||
bi='k',
|
||||
feat='k'
|
||||
grid_kwargs = dict(
|
||||
wspace=0.15,
|
||||
hspace=0.3,
|
||||
left=0.1,
|
||||
right=0.95,
|
||||
bottom=0.1,
|
||||
top=0.95
|
||||
)
|
||||
linewidths = dict(
|
||||
colors = {s: c.item() for s, c in dict(np.load('../data/stage_colors.npz')).items()}
|
||||
lw_full = dict(
|
||||
filt=0.25,
|
||||
env=0.5,
|
||||
log=0.5,
|
||||
@@ -37,11 +126,24 @@ linewidths = dict(
|
||||
bi=1,
|
||||
feat=1
|
||||
)
|
||||
lw_zoom = dict(
|
||||
filt=0.25,
|
||||
env=0.75,
|
||||
log=0.75,
|
||||
inv=0.75,
|
||||
conv=0.5,
|
||||
bi=1,
|
||||
feat=0.75
|
||||
)
|
||||
zoom_rel = np.array([0.3, 0.4])
|
||||
zoom_kwargs = dict(
|
||||
color=(0.85, 0.85, 0.85),
|
||||
zorder=0,
|
||||
linewidth=0
|
||||
)
|
||||
|
||||
# EXECUTION:
|
||||
for data_path in data_paths:
|
||||
if 'Gomphocerippus' in data_path:
|
||||
continue
|
||||
print(f'Processing {data_path}')
|
||||
|
||||
# Load overall data:
|
||||
@@ -55,79 +157,80 @@ for data_path in data_paths:
|
||||
|
||||
|
||||
# PART I: PREPROCESSING STAGE
|
||||
fig, axes = plt.subplots(4, 2, **fig_kwargs)
|
||||
fig.supylabel('amplitude', fontsize=plt.rcParams['axes.labelsize'])
|
||||
fig.supxlabel('time [s]', fontsize=plt.rcParams['axes.labelsize'])
|
||||
fig, axes = prepare_fig(4, 2, **fig_kwargs, **grid_kwargs)
|
||||
super_xlabel('time [s]', fig, axes[0, 0], axes[0, -1])
|
||||
super_ylabel('amplitude', fig, axes[0, 0], axes[-1, 0])
|
||||
|
||||
# Bandpass-filtered signal:
|
||||
signal = data['filt'][:, channel]
|
||||
ax_full, ax_zoom = axes[0, :]
|
||||
c, lw = colors['filt'], linewidths['filt']
|
||||
ax_full.plot(t_full, signal, c=c, lw=lw)
|
||||
ax_zoom.plot(t_zoom, signal[zoom_mask], c=c, lw=lw)
|
||||
ax_full.set_ylim(signal.min(), signal.max())
|
||||
plot_line(ax_full, t_full, data['filt'], c=colors['filt'], lw=lw_full['filt'])
|
||||
plot_line(ax_zoom, t_zoom, data['filt'][zoom_mask], c=colors['filt'], lw=lw_zoom['filt'])
|
||||
hide_axis(ax_full, 'bottom')
|
||||
hide_axis(ax_zoom, 'bottom')
|
||||
hide_axis(ax_zoom, 'left')
|
||||
|
||||
# Signal envelope:
|
||||
signal = data['env'][:, channel]
|
||||
ax_full, ax_zoom = axes[1, :]
|
||||
c, lw = colors['env'], linewidths['env']
|
||||
ax_full.plot(t_full, signal, c=c, lw=lw)
|
||||
ax_zoom.plot(t_zoom, signal[zoom_mask], c=c, lw=lw)
|
||||
ax_full.set_ylim(0, signal.max())
|
||||
plot_line(ax_full, t_full, data['env'], ymin=0, c=colors['env'], lw=lw_full['env'])
|
||||
plot_line(ax_zoom, t_zoom, data['env'][zoom_mask], ymin=0, c=colors['env'], lw=lw_zoom['env'])
|
||||
hide_axis(ax_full, 'bottom')
|
||||
hide_axis(ax_zoom, 'bottom')
|
||||
hide_axis(ax_zoom, 'left')
|
||||
|
||||
# Logarithmic envelope:
|
||||
signal = data['log'][:, channel]
|
||||
ax_full, ax_zoom = axes[2, :]
|
||||
c, lw = colors['log'], linewidths['log']
|
||||
ax_full.plot(t_full, signal, c=c, lw=lw)
|
||||
ax_zoom.plot(t_zoom, signal[zoom_mask], c=c, lw=lw)
|
||||
ax_full.set_ylim(signal.min(), 0)
|
||||
plot_line(ax_full, t_full, data['log'], ymax=0, c=colors['log'], lw=lw_full['log'])
|
||||
plot_line(ax_zoom, t_zoom, data['log'][zoom_mask], ymax=0, c=colors['log'], lw=lw_zoom['log'])
|
||||
hide_axis(ax_full, 'bottom')
|
||||
hide_axis(ax_zoom, 'bottom')
|
||||
hide_axis(ax_zoom, 'left')
|
||||
|
||||
# Adapted envelope:
|
||||
signal = data['inv'][:, channel]
|
||||
ax_full, ax_zoom = axes[3, :]
|
||||
c, lw = colors['inv'], linewidths['inv']
|
||||
ax_full.plot(t_full, signal, c=c, lw=lw)
|
||||
ax_zoom.plot(t_zoom, signal[zoom_mask], c=c, lw=lw)
|
||||
ax_full.set_ylim(signal.min(), signal.max())
|
||||
plot_line(ax_full, t_full, data['inv'], c=colors['inv'], lw=lw_full['inv'])
|
||||
plot_line(ax_zoom, t_zoom, data['inv'][zoom_mask], c=colors['inv'], lw=lw_zoom['inv'])
|
||||
hide_axis(ax_zoom, 'left')
|
||||
|
||||
# Posthoc adjustments:
|
||||
ax_full.set_xlim(t_full[0], t_full[-1])
|
||||
ax_zoom.set_xlim(t_zoom[0], t_zoom[-1])
|
||||
indicate_zoom(fig, axes[0, 0], axes[-1, 0], zoom_abs, **zoom_kwargs)
|
||||
indicate_zoom(fig, axes[0, 1], axes[-1, 1], zoom_abs, **zoom_kwargs)
|
||||
if save_name is not None:
|
||||
fig.savefig(f'{save_name}_pre.pdf')
|
||||
|
||||
|
||||
# PART II: FEATURE EXTRACTION STAGE:
|
||||
fig, axes = plt.subplots(3, 2, **fig_kwargs)
|
||||
fig.supylabel('amplitude', fontsize=plt.rcParams['axes.labelsize'])
|
||||
fig.supxlabel('time [s]', fontsize=plt.rcParams['axes.labelsize'])
|
||||
fig, axes = prepare_fig(3, 2, **fig_kwargs, **grid_kwargs)
|
||||
super_xlabel('time [s]', fig, axes[0, 0], axes[0, -1])
|
||||
super_ylabel('amplitude', fig, axes[0, 0], axes[-1, 0])
|
||||
|
||||
# Convolutional filter responses:
|
||||
signal = data['conv'][:, :, channel]
|
||||
ax_full, ax_zoom = axes[0, :]
|
||||
c, lw = colors['conv'], linewidths['conv']
|
||||
ax_full.plot(t_full, signal, c=c, lw=lw)
|
||||
ax_zoom.plot(t_zoom, signal[zoom_mask, :], c=c, lw=lw)
|
||||
ax_full.set_ylim(signal.min(), signal.max())
|
||||
plot_line(ax_full, t_full, data['conv'], c=colors['conv'], lw=lw_full['conv'])
|
||||
plot_line(ax_zoom, t_zoom, data['conv'][zoom_mask, :], c=colors['conv'], lw=lw_zoom['conv'])
|
||||
hide_axis(ax_full, 'bottom')
|
||||
hide_axis(ax_zoom, 'bottom')
|
||||
hide_axis(ax_zoom, 'left')
|
||||
|
||||
# Binary responses:
|
||||
signal = data['bi'][:, :, channel]
|
||||
ax_full, ax_zoom = axes[1, :]
|
||||
c, lw = colors['bi'], linewidths['bi']
|
||||
ax_full.plot(t_full, signal, c=c, lw=lw)
|
||||
ax_zoom.plot(t_zoom, signal[zoom_mask, :], c=c, lw=lw)
|
||||
ax_full.set_ylim(signal.min(), signal.max())
|
||||
plot_barcode(ax_full, t_full, data['bi'], color=colors['bi'])
|
||||
plot_barcode(ax_zoom, t_zoom, data['bi'][zoom_mask, :], color=colors['bi'])
|
||||
|
||||
# Finalized features:
|
||||
signal = data['feat'][:, :, channel]
|
||||
ax_full, ax_zoom = axes[2, :]
|
||||
c, lw = colors['feat'], linewidths['feat']
|
||||
ax_full.plot(t_full, signal, c=c, lw=lw)
|
||||
ax_zoom.plot(t_zoom, signal[zoom_mask, :], c=c, lw=lw)
|
||||
ax_full.set_ylim(0, 1)
|
||||
plot_line(ax_full, t_full, data['feat'], ymin=0, ymax=1, c=colors['feat'], lw=lw_full['feat'])
|
||||
plot_line(ax_zoom, t_zoom, data['feat'][zoom_mask, :], ymin=0, ymax=1, c=colors['feat'], lw=lw_zoom['feat'])
|
||||
hide_axis(ax_zoom, 'left')
|
||||
|
||||
# Posthoc adjustments:
|
||||
ax_full.set_xlim(t_full[0], t_full[-1])
|
||||
ax_zoom.set_xlim(t_zoom[0], t_zoom[-1])
|
||||
indicate_zoom(fig, axes[0, 0], axes[-1, 0], zoom_abs, **zoom_kwargs)
|
||||
indicate_zoom(fig, axes[0, 1], axes[-1, 1], zoom_abs, **zoom_kwargs)
|
||||
if save_name is not None:
|
||||
fig.savefig(f'{save_name}_feat.pdf')
|
||||
plt.show()
|
||||
|
||||
|
||||
|
||||
@@ -8,15 +8,18 @@ mpl.rcParams['font.style'] = 'normal'
|
||||
mpl.rcParams['font.family'] = 'serif'
|
||||
mpl.rcParams['mathtext.fontset'] = 'cm'
|
||||
mpl.rcParams['mathtext.default'] = 'regular'
|
||||
|
||||
# Font sizes:
|
||||
mpl.rcParams['font.size'] = 14
|
||||
mpl.rcParams['figure.titlesize'] = 15
|
||||
mpl.rcParams['figure.labelsize'] = 14
|
||||
mpl.rcParams['axes.labelsize'] = 14
|
||||
mpl.rcParams['axes.titlesize'] = 14
|
||||
mpl.rcParams['xtick.labelsize'] = 13
|
||||
mpl.rcParams['ytick.labelsize'] = 13
|
||||
mpl.rcParams['legend.fontsize'] = 14
|
||||
mpl.rcParams['legend.title_fontsize'] = 14
|
||||
|
||||
# Font weights:
|
||||
single_weight = ['normal', 'bold'][0]
|
||||
mpl.rcParams['font.weight'] = single_weight
|
||||
@@ -32,6 +35,7 @@ mpl.rcParams['figure.raise_window'] = True
|
||||
# mpl.rcParams['savefig.dpi'] = 500
|
||||
|
||||
# Axes parameters:
|
||||
mpl.rcParams['axes.linewidth'] = 1.5
|
||||
mpl.rcParams['axes.spines.top'] = False
|
||||
mpl.rcParams['axes.spines.right'] = False
|
||||
mpl.rcParams['axes.xmargin'] = 0
|
||||
@@ -39,10 +43,20 @@ mpl.rcParams['axes.ymargin'] = 0
|
||||
mpl.rcParams['axes.autolimit_mode'] = 'round_numbers'
|
||||
mpl.rcParams['axes.labelpad'] = 3
|
||||
|
||||
# Tick parameters:
|
||||
mpl.rcParams['xtick.major.pad'] = 5
|
||||
mpl.rcParams['ytick.major.pad'] = 5
|
||||
# Major tick parameters:
|
||||
mpl.rcParams['xtick.major.size'] = 5
|
||||
mpl.rcParams['xtick.major.width'] = 1.5
|
||||
mpl.rcParams['xtick.major.pad'] = 3.5
|
||||
mpl.rcParams['ytick.major.size'] = 5
|
||||
mpl.rcParams['ytick.major.width'] = 1.5
|
||||
mpl.rcParams['ytick.major.pad'] = 3.5
|
||||
|
||||
# Minor tick parameters:
|
||||
mpl.rcParams['xtick.minor.size'] = 2
|
||||
mpl.rcParams['xtick.minor.width'] = 1
|
||||
mpl.rcParams['ytick.minor.size'] = 2
|
||||
mpl.rcParams['ytick.minor.width'] = 1
|
||||
|
||||
# Legend parameters:
|
||||
mpl.rcParams['legend.frameon'] = False
|
||||
mpl.rcParams['legend.scatterpoints'] = 3
|
||||
mpl.rcParams['legend.scatterpoints'] = 3
|
||||
|
||||
61
python/save_snippet_data.py
Normal file
61
python/save_snippet_data.py
Normal file
@@ -0,0 +1,61 @@
|
||||
import numpy as np
|
||||
from thunderhopper.filetools import search_files, crop_paths
|
||||
from thunderhopper.model import configuration, process_signal
|
||||
from thunderhopper.modeltools import load_data
|
||||
from IPython import embed
|
||||
|
||||
## SETTINGS:
|
||||
|
||||
# General:
|
||||
overwrite = True
|
||||
input_folder = '../data/raw/'
|
||||
output_folder = '../data/processed/'
|
||||
stages = ['raw', 'filt', 'env', 'log', 'inv', 'conv', 'bi', 'feat', 'norm']
|
||||
if True:
|
||||
# Overwrites edited:
|
||||
stages.append('songs')
|
||||
|
||||
# Interactivity:
|
||||
reload_saved = False
|
||||
gui = False
|
||||
|
||||
# Processing:
|
||||
env_rate = 44100.0
|
||||
feat_rate = 44100.0
|
||||
sigmas = [0.001, 0.002, 0.004, 0.008, 0.016, 0.032]
|
||||
types = [1, -1, 2, -2, 3, -3, 4, -4, 5, -5,
|
||||
6, -6, 7, -7, 8, -8, 9, -9, 10, -10]
|
||||
config = configuration(env_rate, feat_rate, types=types, sigmas=sigmas)
|
||||
config.update({
|
||||
'channel': 0,
|
||||
'rate_ratio': None,
|
||||
'env_fcut': 250,
|
||||
'inv_fcut': 5,
|
||||
'feat_thresh': np.load('../data/kernel_thresholds.npy') * 0.1,
|
||||
'feat_fcut': 0.5,
|
||||
'label_channels': 0,
|
||||
'label_thresh': 0.5,
|
||||
})
|
||||
|
||||
## PREPARATION:
|
||||
|
||||
# Fetch WAV recording files:
|
||||
input_paths = search_files(ext='wav', dir=input_folder)
|
||||
path_names = crop_paths(input_paths)
|
||||
|
||||
# PROCESSING:
|
||||
|
||||
# Run processing pipeline:
|
||||
for path, name in zip(input_paths, path_names):
|
||||
print('Processing:', name)
|
||||
|
||||
# Fetch and store representations:
|
||||
save = None if output_folder is None else output_folder + f'{name}.npz'
|
||||
process_signal(config, stages, path, save=save,
|
||||
label_edit=gui, overwrite=overwrite)
|
||||
|
||||
# Cross-control:
|
||||
if reload_saved:
|
||||
data, params = load_data(save, stages, ['songs', 'noise'])
|
||||
embed()
|
||||
print('Done.')
|
||||
195
python/save_stage_colors.py
Normal file
195
python/save_stage_colors.py
Normal file
@@ -0,0 +1,195 @@
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
from tkinter.colorchooser import askcolor
|
||||
from IPython import embed
|
||||
|
||||
|
||||
def is_hex_color(color):
|
||||
if isinstance(color, str) and color.startswith('#') and len(color) == 7:
|
||||
try:
|
||||
int(color[1:], 16)
|
||||
return True
|
||||
except ValueError:
|
||||
return False
|
||||
return False
|
||||
|
||||
|
||||
def is_rgb_color(color):
|
||||
if isinstance(color, (tuple, list)) and len(color) == 3:
|
||||
return all(isinstance(c, int) and 0 <= c <= 255 for c in color)
|
||||
return False
|
||||
|
||||
|
||||
def hex_to_rgb(color):
|
||||
color = color.lstrip('#')
|
||||
return tuple(int(color[i:i+2], 16) for i in (0, 2, 4))
|
||||
|
||||
|
||||
def rgb_to_hex(color):
|
||||
return '#{:02x}{:02x}{:02x}'.format(*color)
|
||||
|
||||
|
||||
def whiten_color(color, factors):
|
||||
is_hex = is_hex_color(color)
|
||||
if is_hex:
|
||||
color = hex_to_rgb(color)
|
||||
elif not is_rgb_color(color):
|
||||
raise ValueError('Color format must be hex string or RGB tuple.')
|
||||
|
||||
whitened = tuple(min(255, int(c + (255 - c) * f)) for c, f in zip(color, factors))
|
||||
return rgb_to_hex(whitened) if is_hex else whitened
|
||||
|
||||
|
||||
def color_selector(n=5, colors=None, save=None, labels=None, hex=True,
|
||||
back_color='w', edge_color='k', edge_width=2):
|
||||
|
||||
|
||||
def pick_color(color='#ffffff', hex=True):
|
||||
color = askcolor(initialcolor=color)[hex]
|
||||
return color
|
||||
|
||||
|
||||
def update_color(artists, color):
|
||||
for i, artist in enumerate(artists):
|
||||
if isinstance(artist, plt.Rectangle):
|
||||
# Update patch colors:
|
||||
artist.set_fc(color)
|
||||
if i in [0, 1]:
|
||||
artist.set_ec(color)
|
||||
elif isinstance(artist, plt.Line2D):
|
||||
# Update marker colors:
|
||||
artist.set_mfc(color)
|
||||
if i in [0, 1]:
|
||||
artist.set_mec(color)
|
||||
fig.canvas.draw()
|
||||
return None
|
||||
|
||||
|
||||
# Prepare colors:
|
||||
if colors is None:
|
||||
colors = ['#ffffff' for _ in range(n)]
|
||||
start_colors = colors.copy()
|
||||
n = len(colors)
|
||||
plt.ion()
|
||||
|
||||
# Prepare graphical interface:
|
||||
fig, ax = plt.subplots(figsize=(10, 2), layout='constrained')
|
||||
ax.set_facecolor(back_color)
|
||||
ax.set_xlim(0, n)
|
||||
ax.set_ylim(0, 4)
|
||||
ax.axis('off')
|
||||
|
||||
# Add different color indicators:
|
||||
artist_groups, group_inds = {}, {}
|
||||
for i, color in enumerate(colors):
|
||||
|
||||
# Color on background, uniform edge:
|
||||
p1 = ax.add_patch(plt.Rectangle((i, 0), 1, 1, fc=color, ec=color))
|
||||
l1 = ax.plot(i + 0.5, 1.5, marker='o', ms=20, mfc=color, mec=color)[0]
|
||||
|
||||
# Color on background, black edge:
|
||||
p2 = ax.add_patch(plt.Rectangle((i, 3), 1, 1, fc=color, ec=edge_color, lw=edge_width))
|
||||
l2 = ax.plot(i + 0.5, 2.5, marker='o', ms=20, mfc=color, mec=edge_color, mew=edge_width)[0]
|
||||
|
||||
# Update artist mappings:
|
||||
handles = [p1, l1, l2, p2]
|
||||
artist_groups[i] = handles
|
||||
for handle in handles:
|
||||
handle.set_picker(True)
|
||||
group_inds[handle] = i
|
||||
|
||||
# Initialize memory variables for avoiding assignments:
|
||||
swap_groups, swap_colors = [None, None], [None, None]
|
||||
focus_group, focus_color = [None], [None, None]
|
||||
|
||||
|
||||
# Interactivity:
|
||||
def on_key(event):
|
||||
|
||||
# Abort and retry:
|
||||
if event.key == 'q':
|
||||
print('\nRestarting with original settings...')
|
||||
plt.close(fig)
|
||||
return color_selector(n, start_colors, save, labels, hex,
|
||||
back_color, edge_color, edge_width)
|
||||
|
||||
# Exit without saving:
|
||||
elif event.key == 'escape':
|
||||
print('\nExiting without saving...')
|
||||
plt.close(fig)
|
||||
return None
|
||||
|
||||
# Accept and save colors to file:
|
||||
elif event.key == 'enter' and save is not None:
|
||||
if labels is not None:
|
||||
# Convert to file dictionary and write to npz file:
|
||||
data = {str(label): col for label, col in zip(labels, colors)}
|
||||
np.savez(f'{save}.npz', **data)
|
||||
else:
|
||||
# Write array (n, 3) or (n,) to npy file:
|
||||
np.save(f'{save}.npy', np.array(colors))
|
||||
print(f'\nSaved {n} colors to file: {save}')
|
||||
plt.close(fig)
|
||||
return None
|
||||
|
||||
|
||||
def on_pick(event):
|
||||
|
||||
# Pick color for chosen group:
|
||||
if event.mouseevent.button == 1:
|
||||
# Update memory variables:
|
||||
focus_group[0] = group_inds[event.artist]
|
||||
focus_color[0] = colors[focus_group[0]]
|
||||
|
||||
# Trigger color picker dialogue:
|
||||
focus_color[1] = pick_color(focus_color[0], hex)
|
||||
if focus_color[1] is not None:
|
||||
# Apply, update, and report only if dialogue was not cancelled:
|
||||
update_color(artist_groups[focus_group[0]], focus_color[1])
|
||||
colors[focus_group[0]] = focus_color[1]
|
||||
print(f'\nUpdated color {focus_group[0] + 1} / {n}: {focus_color[1]}')
|
||||
|
||||
# Select 1st color to swap:
|
||||
elif event.mouseevent.button == 3:
|
||||
# Update memory variables and report:
|
||||
swap_groups[0] = group_inds[event.artist]
|
||||
swap_colors[0] = colors[swap_groups[0]]
|
||||
print(f'\nSelected 1st swap color: {swap_groups[0] + 1} / {n}')
|
||||
|
||||
# Select 2nd color to swap and execute:
|
||||
elif swap_groups[0] is not None and event.mouseevent.button == 2:
|
||||
# Update memory variables and report:
|
||||
swap_groups[1] = group_inds[event.artist]
|
||||
swap_colors[1] = colors[swap_groups[1]]
|
||||
print(f'Selected 2nd swap color: {swap_groups[1] + 1} / {n}')
|
||||
print(f'Swapping colors: {swap_groups[0] + 1} <-> {swap_groups[1] + 1}')
|
||||
|
||||
# Swap colors:
|
||||
for i in [0, 1]:
|
||||
# Apply to artist group and update global color list:
|
||||
update_color(artist_groups[swap_groups[i]], swap_colors[1 - i])
|
||||
colors[swap_groups[i]] = swap_colors[1 - i]
|
||||
|
||||
# Reset memory variables:
|
||||
swap_groups[0], swap_groups[1] = None, None
|
||||
swap_colors[0], swap_colors[1] = None, None
|
||||
|
||||
# Establish interactivity:
|
||||
plt.connect('key_press_event', on_key)
|
||||
plt.connect('pick_event', on_pick)
|
||||
plt.ioff()
|
||||
return colors
|
||||
|
||||
|
||||
# Settings:
|
||||
stages = ['filt', 'env', 'log', 'inv', 'conv', 'bi', 'feat']
|
||||
file_name = None#'../data/stage_colors'
|
||||
colors = None
|
||||
if True:
|
||||
colors = dict(np.load('../data/stage_colors.npz'))
|
||||
colors = [colors[stage].item() for stage in colors.keys()]
|
||||
|
||||
# Execution:
|
||||
colors = color_selector(len(stages), colors, save=file_name, labels=stages, hex=True)
|
||||
plt.show()
|
||||
embed()
|
||||
Reference in New Issue
Block a user