Implemented color shading in 2nd methods figure.

This commit is contained in:
j-hartling 2026-02-18 16:04:26 +01:00
parent 49b2bdcfca
commit 652621f782
6 changed files with 311 additions and 196 deletions

Binary file not shown.

Binary file not shown.

249
python/color_functions.py Normal file
View File

@ -0,0 +1,249 @@
import numpy as np
import matplotlib.pyplot as plt
from tkinter.colorchooser import askcolor
from IPython import embed
# FALL-THROUGH MODIFIERS:
def expand_hex_color(color):
if len(color) in [4, 5]:
return '#' + ''.join([c * 2 for c in color[1:]])
elif len(color) in [7, 9]:
return color
raise ValueError(f'Unexpected digit count for hex color string: {color}')
def norm_rgb_color(color):
return tuple(c / 255 if isinstance(c, int) else c for c in color)
def unnorm_rgb_color(color):
return tuple(int(c * 255) if isinstance(c, float) else c for c in color)
# FORMAT VALIDATION:
def is_hex_color(color, alpha=None, shorthand=None):
long = {True: [9], False: [7], None: [7, 9]}[alpha]
short = {True: [5], False: [4], None: [4, 5]}[alpha]
valid = {True: short, False: long, None: long + short}[shorthand]
if isinstance(color, str) and color.startswith('#') and len(color) in valid:
if len(color) in short:
color = expand_hex_color(color)
try:
int(color[1:], 16)
return True
except ValueError:
return False
return False
def is_rgb_color(color, alpha=None, norm=True):
valid = {True: [4], False: [3], None: [3, 4]}[alpha]
if isinstance(color, (tuple, list, np.ndarray)) and len(color) in valid:
form = {True: float, False: (int, np.integer), None: (int, np.integer, float)}[norm]
cap = 1 if norm else 255
return all(isinstance(c, form) and 0 <= c <= cap for c in color)
return False
# FORMAT CONVERSION:
def hex_to_rgb(color):
if not is_hex_color(color, alpha=None, shorthand=None):
raise ValueError(f'Invalid hex color format: {color}')
if len(color) in [4, 5]:
color = expand_hex_color(color)
color = color.lstrip('#')
return tuple(int(color[i:i+2], 16) for i in range(0, len(color), 2))
def rgb_to_hex(color):
if not is_rgb_color(color, alpha=None, norm=None):
raise ValueError(f'Invalid RGB/RGBA color format: {color}')
color = unnorm_rgb_color(color)
return '#' + ''.join('{:02x}'.format(c) for c in color)
# STORAGE AND RETRIEVAL:
def save_colors(colors, path, labels=None):
if labels is not None:
data = {str(label): col for label, col in zip(labels, colors)}
np.savez(path, **data)
else:
np.save(path, np.array(colors))
return None
def load_colors(path):
if path.endswith('.npy'):
return np.load(path)
elif path.endswith('.npz'):
colors = dict(np.load(path))
return {k: (c.item() if c.size == 1 else c) for k, c in colors.items()}
raise ValueError(f'Expected .npy or .npz file extension: {path}')
# ADVANCED FUNCTIONALITY:
def shade_colors(color, factors, norm=True):
is_hex = is_hex_color(color, alpha=False)
if is_hex:
color = hex_to_rgb(color)
norm = False
elif not is_rgb_color(color, alpha=False, norm=norm):
msg = f'Color must be 6-digit hex string or RGB (not RGBA) tuple: {color}'
raise ValueError(msg)
color = np.array(color)
light_shade = 1 if norm else 255
dark_shade = 0
colors = []
for factor in (factors if np.ndim(factors) > 0 else (factors,)):
shade_color = light_shade
if factor < 0:
factor = abs(factor)
shade_color = dark_shade
shaded = color * (1 - factor) + factor * shade_color
if not norm:
shaded = np.floor(shaded).astype(int)
colors.append(shaded)
if is_hex:
colors = [rgb_to_hex(col) for col in colors]
return colors if len(factors) > 1 else colors[0]
def pick_color(color='#ffffff', hex=True):
color = askcolor(initialcolor=color)[hex]
return color
def color_selector(n=5, colors=None, save=None, labels=None, hex=True,
back_color='w', edge_color='k', edge_width=2):
def update_color(artists, color):
for i, artist in enumerate(artists):
if isinstance(artist, plt.Rectangle):
# Update patch colors:
artist.set_fc(color)
if i in [0, 1]:
artist.set_ec(color)
elif isinstance(artist, plt.Line2D):
# Update marker colors:
artist.set_mfc(color)
if i in [0, 1]:
artist.set_mec(color)
fig.canvas.draw()
return None
# Prepare colors:
if colors is None:
colors = ['#ffffff' for _ in range(n)]
elif isinstance(colors, dict):
colors = list(colors.values())
start_colors = colors.copy()
n = len(colors)
plt.ion()
# Prepare graphical interface:
fig, ax = plt.subplots(figsize=(10, 2), layout='constrained')
ax.set_facecolor(back_color)
ax.set_xlim(0, n)
ax.set_ylim(0, 4)
ax.axis('off')
# Add different color indicators:
artist_groups, group_inds = {}, {}
for i, color in enumerate(colors):
# Color on background, uniform edge:
p1 = ax.add_patch(plt.Rectangle((i, 0), 1, 1, fc=color, ec=color))
l1 = ax.plot(i + 0.5, 1.5, marker='o', ms=20, mfc=color, mec=color)[0]
# Color on background, black edge:
p2 = ax.add_patch(plt.Rectangle((i, 3), 1, 1, fc=color, ec=edge_color, lw=edge_width))
l2 = ax.plot(i + 0.5, 2.5, marker='o', ms=20, mfc=color, mec=edge_color, mew=edge_width)[0]
# Update artist mappings:
handles = [p1, l1, l2, p2]
artist_groups[i] = handles
for handle in handles:
handle.set_picker(True)
group_inds[handle] = i
# Initialize memory variables for avoiding assignments:
swap_groups, swap_colors = [None, None], [None, None]
focus_group, focus_color = [None], [None, None]
# Interactivity:
def on_key(event):
# Abort and retry:
if event.key == 'q':
print('\nRestarting with original settings...')
plt.close(fig)
return color_selector(n, start_colors, save, labels, hex,
back_color, edge_color, edge_width)
# Exit without saving:
elif event.key == 'escape':
print('\nExiting without saving...')
plt.close(fig)
return None
# Accept and save colors to file:
elif event.key == 'enter' and save is not None:
save_colors(colors, save, labels)
print(f'\nSaved {n} colors to file: {save}')
plt.close(fig)
return None
def on_pick(event):
# Pick color for chosen group:
if event.mouseevent.button == 1:
# Update memory variables:
focus_group[0] = group_inds[event.artist]
focus_color[0] = colors[focus_group[0]]
# Trigger color picker dialogue:
focus_color[1] = pick_color(focus_color[0], hex)
if focus_color[1] is not None:
# Apply, update, and report only if dialogue was not cancelled:
update_color(artist_groups[focus_group[0]], focus_color[1])
colors[focus_group[0]] = focus_color[1]
print(f'\nUpdated color {focus_group[0] + 1} / {n}: {focus_color[1]}')
# Select 1st color to swap:
elif event.mouseevent.button == 3:
# Update memory variables and report:
swap_groups[0] = group_inds[event.artist]
swap_colors[0] = colors[swap_groups[0]]
print(f'\nSelected 1st swap color: {swap_groups[0] + 1} / {n}')
# Select 2nd color to swap and execute:
elif swap_groups[0] is not None and event.mouseevent.button == 2:
# Update memory variables and report:
swap_groups[1] = group_inds[event.artist]
swap_colors[1] = colors[swap_groups[1]]
print(f'Selected 2nd swap color: {swap_groups[1] + 1} / {n}')
print(f'Swapping colors: {swap_groups[0] + 1} <-> {swap_groups[1] + 1}')
# Swap colors:
for i in [0, 1]:
# Apply to artist group and update global color list:
update_color(artist_groups[swap_groups[i]], swap_colors[1 - i])
colors[swap_groups[i]] = swap_colors[1 - i]
# Reset memory variables:
swap_groups[0], swap_groups[1] = None, None
swap_colors[0], swap_colors[1] = None, None
# Establish interactivity:
plt.connect('key_press_event', on_key)
plt.connect('pick_event', on_pick)
plt.ioff()
return colors

View File

@ -4,6 +4,7 @@ import numpy as np
import matplotlib.pyplot as plt
from itertools import product
from thunderhopper.modeltools import load_data
from color_functions import load_colors, shade_colors
from IPython import embed
def prepare_fig(nrows, ncols, width=8, height=None, rheight=2,
@ -63,26 +64,27 @@ def hide_axis(ax, side='bottom'):
def plot_line(ax, time, signal, ymin=None, ymax=None, xmin=None, xmax=None,
xpad=None, ypad=0.05, **kwargs):
ax.plot(time, signal, **kwargs)
handles = ax.plot(time, signal, **kwargs)
xlimits(ax, time, minval=xmin, maxval=xmax, pad=xpad)
ylimits(ax, signal, minval=ymin, maxval=ymax, pad=ypad)
return None
return handles
def plot_barcode(ax, time, binary, offset=0.1, xmin=None, xmax=None, **kwargs):
def plot_barcode(ax, time, binary, offset=0.5, xmin=None, xmax=None, **kwargs):
if xmin is None:
xmin = time[0]
if xmax is None:
xmax = time[-1]
lower, upper = 0, 1
lower, upper, handles = 0, 1, []
for i in range(binary.shape[1]):
ax.fill_between(time, lower, upper, where=binary[:, i], **kwargs)
h = ax.fill_between(time, lower, upper, where=binary[:, i], **kwargs)
handles.append(h)
if i < binary.shape[1] - 1:
lower += offset + 1
upper += offset + 1
xlimits(ax, time, minval=xmin, maxval=xmax)
ax.set_ylim(0, upper)
ax.axis('off')
return None
return handles
def indicate_zoom(fig, high_ax, low_ax, zoom_abs, **kwargs):
y0 = low_ax.get_position().y0
@ -95,6 +97,10 @@ def indicate_zoom(fig, high_ax, low_ax, zoom_abs, **kwargs):
fig.add_artist(rect)
return None
def assign_colors(handles, types, colors):
for handle, type_id in zip(handles, types):
handle.set_color(colors[str(int(type_id))])
# GENERAL SETTINGS:
target = 'Omocestus_rufipes'
@ -116,7 +122,7 @@ grid_kwargs = dict(
bottom=0.1,
top=0.95
)
colors = {s: c.item() for s, c in dict(np.load('../data/stage_colors.npz')).items()}
colors = load_colors('../data/stage_colors.npz')
lw_full = dict(
filt=0.25,
env=0.5,
@ -141,6 +147,16 @@ zoom_kwargs = dict(
zorder=0,
linewidth=0
)
# kernels = np.array([
# [-2, 0.016],
# [3, 0.032]
# ])
kern_types = np.array([1, -1, 2, -2, 3, -3, 4, -4])
kern_sigmas = np.array([0.008, 0.016, 0.032])
kernels = np.array([[k, s] for k in kern_types for s in kern_sigmas])
conv_colors = load_colors('../data/conv_colors.npz')
bi_colors = load_colors('../data/bi_colors.npz')
feat_colors = load_colors('../data/feat_colors.npz')
# EXECUTION:
for data_path in data_paths:
@ -150,6 +166,11 @@ for data_path in data_paths:
data, config = load_data(data_path, stages)
t_full = np.arange(data['filt'].shape[0]) / config['rate']
# Select kernel subset:
kern_inds = [np.nonzero((config['k_specs'] == k).all(1))[0][0] for k in kernels]
kern_inds = np.array(kern_inds)
kernel_specs = config['k_specs'][kern_inds]
# Establish zoom frame:
zoom_abs = zoom_rel * t_full[-1]
zoom_mask = (t_full >= zoom_abs[0]) & (t_full <= zoom_abs[1])
@ -207,21 +228,30 @@ for data_path in data_paths:
# Convolutional filter responses:
ax_full, ax_zoom = axes[0, :]
plot_line(ax_full, t_full, data['conv'], c=colors['conv'], lw=lw_full['conv'])
plot_line(ax_zoom, t_zoom, data['conv'][zoom_mask, :], c=colors['conv'], lw=lw_zoom['conv'])
signal = data['conv'][:, kern_inds]
handles = plot_line(ax_full, t_full, signal, lw=lw_full['conv'])
assign_colors(handles, kernel_specs[:, 0], conv_colors)
handles = plot_line(ax_zoom, t_zoom, signal[zoom_mask, :], lw=lw_zoom['conv'])
assign_colors(handles, kernel_specs[:, 0], conv_colors)
hide_axis(ax_full, 'bottom')
hide_axis(ax_zoom, 'bottom')
hide_axis(ax_zoom, 'left')
# Binary responses:
ax_full, ax_zoom = axes[1, :]
plot_barcode(ax_full, t_full, data['bi'], color=colors['bi'])
plot_barcode(ax_zoom, t_zoom, data['bi'][zoom_mask, :], color=colors['bi'])
signal = data['bi'][:, kern_inds]
handles = plot_barcode(ax_full, t_full, signal)
assign_colors(handles, kernel_specs[:, 0], bi_colors)
handles = plot_barcode(ax_zoom, t_zoom, signal[zoom_mask, :])
assign_colors(handles, kernel_specs[:, 0], bi_colors)
# Finalized features:
ax_full, ax_zoom = axes[2, :]
plot_line(ax_full, t_full, data['feat'], ymin=0, ymax=1, c=colors['feat'], lw=lw_full['feat'])
plot_line(ax_zoom, t_zoom, data['feat'][zoom_mask, :], ymin=0, ymax=1, c=colors['feat'], lw=lw_zoom['feat'])
signal = data['feat'][:, kern_inds]
handles = plot_line(ax_full, t_full, signal, ymin=0, ymax=1, c=colors['feat'], lw=lw_full['feat'])
assign_colors(handles, kernel_specs[:, 0], feat_colors)
handles = plot_line(ax_zoom, t_zoom, signal[zoom_mask, :], ymin=0, ymax=1, c=colors['feat'], lw=lw_zoom['feat'])
assign_colors(handles, kernel_specs[:, 0], feat_colors)
hide_axis(ax_zoom, 'left')
# Posthoc adjustments:

View File

@ -0,0 +1,16 @@
import numpy as np
from color_functions import load_colors, shade_colors
# Settings:
stages = ['conv', 'bi', 'feat']
kern_types = np.array([1, -1, 2, -2, 3, -3, 4, -4])
shade_factors = np.linspace(-0.5, 0.5, kern_types.size)
# Main colors:
stage_colors = load_colors('../data/stage_colors.npz')
# Execution:
for stage in stages:
colors = shade_colors(stage_colors[stage], shade_factors)
colors = {str(k): c for k, c in zip(kern_types, colors)}
np.savez(f'../data/{stage}_colors.npz', **colors)

View File

@ -1,195 +1,15 @@
import numpy as np
import matplotlib.pyplot as plt
from tkinter.colorchooser import askcolor
from color_functions import load_colors, color_selector
from IPython import embed
def is_hex_color(color):
if isinstance(color, str) and color.startswith('#') and len(color) == 7:
try:
int(color[1:], 16)
return True
except ValueError:
return False
return False
def is_rgb_color(color):
if isinstance(color, (tuple, list)) and len(color) == 3:
return all(isinstance(c, int) and 0 <= c <= 255 for c in color)
return False
def hex_to_rgb(color):
color = color.lstrip('#')
return tuple(int(color[i:i+2], 16) for i in (0, 2, 4))
def rgb_to_hex(color):
return '#{:02x}{:02x}{:02x}'.format(*color)
def whiten_color(color, factors):
is_hex = is_hex_color(color)
if is_hex:
color = hex_to_rgb(color)
elif not is_rgb_color(color):
raise ValueError('Color format must be hex string or RGB tuple.')
whitened = tuple(min(255, int(c + (255 - c) * f)) for c, f in zip(color, factors))
return rgb_to_hex(whitened) if is_hex else whitened
def color_selector(n=5, colors=None, save=None, labels=None, hex=True,
back_color='w', edge_color='k', edge_width=2):
def pick_color(color='#ffffff', hex=True):
color = askcolor(initialcolor=color)[hex]
return color
def update_color(artists, color):
for i, artist in enumerate(artists):
if isinstance(artist, plt.Rectangle):
# Update patch colors:
artist.set_fc(color)
if i in [0, 1]:
artist.set_ec(color)
elif isinstance(artist, plt.Line2D):
# Update marker colors:
artist.set_mfc(color)
if i in [0, 1]:
artist.set_mec(color)
fig.canvas.draw()
return None
# Prepare colors:
if colors is None:
colors = ['#ffffff' for _ in range(n)]
start_colors = colors.copy()
n = len(colors)
plt.ion()
# Prepare graphical interface:
fig, ax = plt.subplots(figsize=(10, 2), layout='constrained')
ax.set_facecolor(back_color)
ax.set_xlim(0, n)
ax.set_ylim(0, 4)
ax.axis('off')
# Add different color indicators:
artist_groups, group_inds = {}, {}
for i, color in enumerate(colors):
# Color on background, uniform edge:
p1 = ax.add_patch(plt.Rectangle((i, 0), 1, 1, fc=color, ec=color))
l1 = ax.plot(i + 0.5, 1.5, marker='o', ms=20, mfc=color, mec=color)[0]
# Color on background, black edge:
p2 = ax.add_patch(plt.Rectangle((i, 3), 1, 1, fc=color, ec=edge_color, lw=edge_width))
l2 = ax.plot(i + 0.5, 2.5, marker='o', ms=20, mfc=color, mec=edge_color, mew=edge_width)[0]
# Update artist mappings:
handles = [p1, l1, l2, p2]
artist_groups[i] = handles
for handle in handles:
handle.set_picker(True)
group_inds[handle] = i
# Initialize memory variables for avoiding assignments:
swap_groups, swap_colors = [None, None], [None, None]
focus_group, focus_color = [None], [None, None]
# Interactivity:
def on_key(event):
# Abort and retry:
if event.key == 'q':
print('\nRestarting with original settings...')
plt.close(fig)
return color_selector(n, start_colors, save, labels, hex,
back_color, edge_color, edge_width)
# Exit without saving:
elif event.key == 'escape':
print('\nExiting without saving...')
plt.close(fig)
return None
# Accept and save colors to file:
elif event.key == 'enter' and save is not None:
if labels is not None:
# Convert to file dictionary and write to npz file:
data = {str(label): col for label, col in zip(labels, colors)}
np.savez(f'{save}.npz', **data)
else:
# Write array (n, 3) or (n,) to npy file:
np.save(f'{save}.npy', np.array(colors))
print(f'\nSaved {n} colors to file: {save}')
plt.close(fig)
return None
def on_pick(event):
# Pick color for chosen group:
if event.mouseevent.button == 1:
# Update memory variables:
focus_group[0] = group_inds[event.artist]
focus_color[0] = colors[focus_group[0]]
# Trigger color picker dialogue:
focus_color[1] = pick_color(focus_color[0], hex)
if focus_color[1] is not None:
# Apply, update, and report only if dialogue was not cancelled:
update_color(artist_groups[focus_group[0]], focus_color[1])
colors[focus_group[0]] = focus_color[1]
print(f'\nUpdated color {focus_group[0] + 1} / {n}: {focus_color[1]}')
# Select 1st color to swap:
elif event.mouseevent.button == 3:
# Update memory variables and report:
swap_groups[0] = group_inds[event.artist]
swap_colors[0] = colors[swap_groups[0]]
print(f'\nSelected 1st swap color: {swap_groups[0] + 1} / {n}')
# Select 2nd color to swap and execute:
elif swap_groups[0] is not None and event.mouseevent.button == 2:
# Update memory variables and report:
swap_groups[1] = group_inds[event.artist]
swap_colors[1] = colors[swap_groups[1]]
print(f'Selected 2nd swap color: {swap_groups[1] + 1} / {n}')
print(f'Swapping colors: {swap_groups[0] + 1} <-> {swap_groups[1] + 1}')
# Swap colors:
for i in [0, 1]:
# Apply to artist group and update global color list:
update_color(artist_groups[swap_groups[i]], swap_colors[1 - i])
colors[swap_groups[i]] = swap_colors[1 - i]
# Reset memory variables:
swap_groups[0], swap_groups[1] = None, None
swap_colors[0], swap_colors[1] = None, None
# Establish interactivity:
plt.connect('key_press_event', on_key)
plt.connect('pick_event', on_pick)
plt.ioff()
return colors
# Settings:
stages = ['filt', 'env', 'log', 'inv', 'conv', 'bi', 'feat']
file_name = None#'../data/stage_colors'
colors = None
if True:
colors = dict(np.load('../data/stage_colors.npz'))
colors = [colors[stage].item() for stage in colors.keys()]
colors = load_colors('../data/stage_colors.npz')
# Execution:
colors = color_selector(len(stages), colors, save=file_name, labels=stages, hex=True)
colors = color_selector(len(stages), colors, file_name, labels=stages, hex=True)
plt.show()
embed()