diff --git a/figures/pathway_stages_feat.pdf b/figures/pathway_stages_feat.pdf index e9c1090..24a39b9 100644 Binary files a/figures/pathway_stages_feat.pdf and b/figures/pathway_stages_feat.pdf differ diff --git a/figures/pathway_stages_pre.pdf b/figures/pathway_stages_pre.pdf index 328353c..36ce3c1 100644 Binary files a/figures/pathway_stages_pre.pdf and b/figures/pathway_stages_pre.pdf differ diff --git a/python/color_functions.py b/python/color_functions.py new file mode 100644 index 0000000..1050886 --- /dev/null +++ b/python/color_functions.py @@ -0,0 +1,249 @@ +import numpy as np +import matplotlib.pyplot as plt +from tkinter.colorchooser import askcolor +from IPython import embed + +# FALL-THROUGH MODIFIERS: + +def expand_hex_color(color): + if len(color) in [4, 5]: + return '#' + ''.join([c * 2 for c in color[1:]]) + elif len(color) in [7, 9]: + return color + raise ValueError(f'Unexpected digit count for hex color string: {color}') + + +def norm_rgb_color(color): + return tuple(c / 255 if isinstance(c, int) else c for c in color) + + +def unnorm_rgb_color(color): + return tuple(int(c * 255) if isinstance(c, float) else c for c in color) + +# FORMAT VALIDATION: + +def is_hex_color(color, alpha=None, shorthand=None): + long = {True: [9], False: [7], None: [7, 9]}[alpha] + short = {True: [5], False: [4], None: [4, 5]}[alpha] + valid = {True: short, False: long, None: long + short}[shorthand] + if isinstance(color, str) and color.startswith('#') and len(color) in valid: + if len(color) in short: + color = expand_hex_color(color) + try: + int(color[1:], 16) + return True + except ValueError: + return False + return False + + +def is_rgb_color(color, alpha=None, norm=True): + valid = {True: [4], False: [3], None: [3, 4]}[alpha] + if isinstance(color, (tuple, list, np.ndarray)) and len(color) in valid: + form = {True: float, False: (int, np.integer), None: (int, np.integer, float)}[norm] + cap = 1 if norm else 255 + return all(isinstance(c, form) and 0 <= c <= cap for c in color) + return False + +# FORMAT CONVERSION: + +def hex_to_rgb(color): + if not is_hex_color(color, alpha=None, shorthand=None): + raise ValueError(f'Invalid hex color format: {color}') + if len(color) in [4, 5]: + color = expand_hex_color(color) + color = color.lstrip('#') + return tuple(int(color[i:i+2], 16) for i in range(0, len(color), 2)) + + +def rgb_to_hex(color): + if not is_rgb_color(color, alpha=None, norm=None): + raise ValueError(f'Invalid RGB/RGBA color format: {color}') + color = unnorm_rgb_color(color) + return '#' + ''.join('{:02x}'.format(c) for c in color) + +# STORAGE AND RETRIEVAL: + +def save_colors(colors, path, labels=None): + if labels is not None: + data = {str(label): col for label, col in zip(labels, colors)} + np.savez(path, **data) + else: + np.save(path, np.array(colors)) + return None + + +def load_colors(path): + if path.endswith('.npy'): + return np.load(path) + elif path.endswith('.npz'): + colors = dict(np.load(path)) + return {k: (c.item() if c.size == 1 else c) for k, c in colors.items()} + raise ValueError(f'Expected .npy or .npz file extension: {path}') + +# ADVANCED FUNCTIONALITY: + +def shade_colors(color, factors, norm=True): + is_hex = is_hex_color(color, alpha=False) + if is_hex: + color = hex_to_rgb(color) + norm = False + elif not is_rgb_color(color, alpha=False, norm=norm): + msg = f'Color must be 6-digit hex string or RGB (not RGBA) tuple: {color}' + raise ValueError(msg) + + color = np.array(color) + light_shade = 1 if norm else 255 + dark_shade = 0 + + colors = [] + for factor in (factors if np.ndim(factors) > 0 else (factors,)): + shade_color = light_shade + if factor < 0: + factor = abs(factor) + shade_color = dark_shade + shaded = color * (1 - factor) + factor * shade_color + if not norm: + shaded = np.floor(shaded).astype(int) + colors.append(shaded) + + if is_hex: + colors = [rgb_to_hex(col) for col in colors] + return colors if len(factors) > 1 else colors[0] + + +def pick_color(color='#ffffff', hex=True): + color = askcolor(initialcolor=color)[hex] + return color + + +def color_selector(n=5, colors=None, save=None, labels=None, hex=True, + back_color='w', edge_color='k', edge_width=2): + + + def update_color(artists, color): + for i, artist in enumerate(artists): + if isinstance(artist, plt.Rectangle): + # Update patch colors: + artist.set_fc(color) + if i in [0, 1]: + artist.set_ec(color) + elif isinstance(artist, plt.Line2D): + # Update marker colors: + artist.set_mfc(color) + if i in [0, 1]: + artist.set_mec(color) + fig.canvas.draw() + return None + + + # Prepare colors: + if colors is None: + colors = ['#ffffff' for _ in range(n)] + elif isinstance(colors, dict): + colors = list(colors.values()) + start_colors = colors.copy() + n = len(colors) + plt.ion() + + # Prepare graphical interface: + fig, ax = plt.subplots(figsize=(10, 2), layout='constrained') + ax.set_facecolor(back_color) + ax.set_xlim(0, n) + ax.set_ylim(0, 4) + ax.axis('off') + + # Add different color indicators: + artist_groups, group_inds = {}, {} + for i, color in enumerate(colors): + + # Color on background, uniform edge: + p1 = ax.add_patch(plt.Rectangle((i, 0), 1, 1, fc=color, ec=color)) + l1 = ax.plot(i + 0.5, 1.5, marker='o', ms=20, mfc=color, mec=color)[0] + + # Color on background, black edge: + p2 = ax.add_patch(plt.Rectangle((i, 3), 1, 1, fc=color, ec=edge_color, lw=edge_width)) + l2 = ax.plot(i + 0.5, 2.5, marker='o', ms=20, mfc=color, mec=edge_color, mew=edge_width)[0] + + # Update artist mappings: + handles = [p1, l1, l2, p2] + artist_groups[i] = handles + for handle in handles: + handle.set_picker(True) + group_inds[handle] = i + + # Initialize memory variables for avoiding assignments: + swap_groups, swap_colors = [None, None], [None, None] + focus_group, focus_color = [None], [None, None] + + + # Interactivity: + def on_key(event): + + # Abort and retry: + if event.key == 'q': + print('\nRestarting with original settings...') + plt.close(fig) + return color_selector(n, start_colors, save, labels, hex, + back_color, edge_color, edge_width) + + # Exit without saving: + elif event.key == 'escape': + print('\nExiting without saving...') + plt.close(fig) + return None + + # Accept and save colors to file: + elif event.key == 'enter' and save is not None: + save_colors(colors, save, labels) + print(f'\nSaved {n} colors to file: {save}') + plt.close(fig) + return None + + + def on_pick(event): + + # Pick color for chosen group: + if event.mouseevent.button == 1: + # Update memory variables: + focus_group[0] = group_inds[event.artist] + focus_color[0] = colors[focus_group[0]] + + # Trigger color picker dialogue: + focus_color[1] = pick_color(focus_color[0], hex) + if focus_color[1] is not None: + # Apply, update, and report only if dialogue was not cancelled: + update_color(artist_groups[focus_group[0]], focus_color[1]) + colors[focus_group[0]] = focus_color[1] + print(f'\nUpdated color {focus_group[0] + 1} / {n}: {focus_color[1]}') + + # Select 1st color to swap: + elif event.mouseevent.button == 3: + # Update memory variables and report: + swap_groups[0] = group_inds[event.artist] + swap_colors[0] = colors[swap_groups[0]] + print(f'\nSelected 1st swap color: {swap_groups[0] + 1} / {n}') + + # Select 2nd color to swap and execute: + elif swap_groups[0] is not None and event.mouseevent.button == 2: + # Update memory variables and report: + swap_groups[1] = group_inds[event.artist] + swap_colors[1] = colors[swap_groups[1]] + print(f'Selected 2nd swap color: {swap_groups[1] + 1} / {n}') + print(f'Swapping colors: {swap_groups[0] + 1} <-> {swap_groups[1] + 1}') + + # Swap colors: + for i in [0, 1]: + # Apply to artist group and update global color list: + update_color(artist_groups[swap_groups[i]], swap_colors[1 - i]) + colors[swap_groups[i]] = swap_colors[1 - i] + + # Reset memory variables: + swap_groups[0], swap_groups[1] = None, None + swap_colors[0], swap_colors[1] = None, None + + # Establish interactivity: + plt.connect('key_press_event', on_key) + plt.connect('pick_event', on_pick) + plt.ioff() + return colors diff --git a/python/fig_pathway_stages.py b/python/fig_pathway_stages.py index bf1ebef..cfb03de 100644 --- a/python/fig_pathway_stages.py +++ b/python/fig_pathway_stages.py @@ -4,6 +4,7 @@ import numpy as np import matplotlib.pyplot as plt from itertools import product from thunderhopper.modeltools import load_data +from color_functions import load_colors, shade_colors from IPython import embed def prepare_fig(nrows, ncols, width=8, height=None, rheight=2, @@ -63,26 +64,27 @@ def hide_axis(ax, side='bottom'): def plot_line(ax, time, signal, ymin=None, ymax=None, xmin=None, xmax=None, xpad=None, ypad=0.05, **kwargs): - ax.plot(time, signal, **kwargs) + handles = ax.plot(time, signal, **kwargs) xlimits(ax, time, minval=xmin, maxval=xmax, pad=xpad) ylimits(ax, signal, minval=ymin, maxval=ymax, pad=ypad) - return None + return handles -def plot_barcode(ax, time, binary, offset=0.1, xmin=None, xmax=None, **kwargs): +def plot_barcode(ax, time, binary, offset=0.5, xmin=None, xmax=None, **kwargs): if xmin is None: xmin = time[0] if xmax is None: xmax = time[-1] - lower, upper = 0, 1 + lower, upper, handles = 0, 1, [] for i in range(binary.shape[1]): - ax.fill_between(time, lower, upper, where=binary[:, i], **kwargs) + h = ax.fill_between(time, lower, upper, where=binary[:, i], **kwargs) + handles.append(h) if i < binary.shape[1] - 1: lower += offset + 1 upper += offset + 1 xlimits(ax, time, minval=xmin, maxval=xmax) ax.set_ylim(0, upper) ax.axis('off') - return None + return handles def indicate_zoom(fig, high_ax, low_ax, zoom_abs, **kwargs): y0 = low_ax.get_position().y0 @@ -95,6 +97,10 @@ def indicate_zoom(fig, high_ax, low_ax, zoom_abs, **kwargs): fig.add_artist(rect) return None +def assign_colors(handles, types, colors): + for handle, type_id in zip(handles, types): + handle.set_color(colors[str(int(type_id))]) + # GENERAL SETTINGS: target = 'Omocestus_rufipes' @@ -116,7 +122,7 @@ grid_kwargs = dict( bottom=0.1, top=0.95 ) -colors = {s: c.item() for s, c in dict(np.load('../data/stage_colors.npz')).items()} +colors = load_colors('../data/stage_colors.npz') lw_full = dict( filt=0.25, env=0.5, @@ -141,6 +147,16 @@ zoom_kwargs = dict( zorder=0, linewidth=0 ) +# kernels = np.array([ +# [-2, 0.016], +# [3, 0.032] +# ]) +kern_types = np.array([1, -1, 2, -2, 3, -3, 4, -4]) +kern_sigmas = np.array([0.008, 0.016, 0.032]) +kernels = np.array([[k, s] for k in kern_types for s in kern_sigmas]) +conv_colors = load_colors('../data/conv_colors.npz') +bi_colors = load_colors('../data/bi_colors.npz') +feat_colors = load_colors('../data/feat_colors.npz') # EXECUTION: for data_path in data_paths: @@ -150,6 +166,11 @@ for data_path in data_paths: data, config = load_data(data_path, stages) t_full = np.arange(data['filt'].shape[0]) / config['rate'] + # Select kernel subset: + kern_inds = [np.nonzero((config['k_specs'] == k).all(1))[0][0] for k in kernels] + kern_inds = np.array(kern_inds) + kernel_specs = config['k_specs'][kern_inds] + # Establish zoom frame: zoom_abs = zoom_rel * t_full[-1] zoom_mask = (t_full >= zoom_abs[0]) & (t_full <= zoom_abs[1]) @@ -207,21 +228,30 @@ for data_path in data_paths: # Convolutional filter responses: ax_full, ax_zoom = axes[0, :] - plot_line(ax_full, t_full, data['conv'], c=colors['conv'], lw=lw_full['conv']) - plot_line(ax_zoom, t_zoom, data['conv'][zoom_mask, :], c=colors['conv'], lw=lw_zoom['conv']) + signal = data['conv'][:, kern_inds] + handles = plot_line(ax_full, t_full, signal, lw=lw_full['conv']) + assign_colors(handles, kernel_specs[:, 0], conv_colors) + handles = plot_line(ax_zoom, t_zoom, signal[zoom_mask, :], lw=lw_zoom['conv']) + assign_colors(handles, kernel_specs[:, 0], conv_colors) hide_axis(ax_full, 'bottom') hide_axis(ax_zoom, 'bottom') hide_axis(ax_zoom, 'left') # Binary responses: ax_full, ax_zoom = axes[1, :] - plot_barcode(ax_full, t_full, data['bi'], color=colors['bi']) - plot_barcode(ax_zoom, t_zoom, data['bi'][zoom_mask, :], color=colors['bi']) + signal = data['bi'][:, kern_inds] + handles = plot_barcode(ax_full, t_full, signal) + assign_colors(handles, kernel_specs[:, 0], bi_colors) + handles = plot_barcode(ax_zoom, t_zoom, signal[zoom_mask, :]) + assign_colors(handles, kernel_specs[:, 0], bi_colors) # Finalized features: ax_full, ax_zoom = axes[2, :] - plot_line(ax_full, t_full, data['feat'], ymin=0, ymax=1, c=colors['feat'], lw=lw_full['feat']) - plot_line(ax_zoom, t_zoom, data['feat'][zoom_mask, :], ymin=0, ymax=1, c=colors['feat'], lw=lw_zoom['feat']) + signal = data['feat'][:, kern_inds] + handles = plot_line(ax_full, t_full, signal, ymin=0, ymax=1, c=colors['feat'], lw=lw_full['feat']) + assign_colors(handles, kernel_specs[:, 0], feat_colors) + handles = plot_line(ax_zoom, t_zoom, signal[zoom_mask, :], ymin=0, ymax=1, c=colors['feat'], lw=lw_zoom['feat']) + assign_colors(handles, kernel_specs[:, 0], feat_colors) hide_axis(ax_zoom, 'left') # Posthoc adjustments: diff --git a/python/save_kernel_colors.py b/python/save_kernel_colors.py new file mode 100644 index 0000000..8a642e7 --- /dev/null +++ b/python/save_kernel_colors.py @@ -0,0 +1,16 @@ +import numpy as np +from color_functions import load_colors, shade_colors + +# Settings: +stages = ['conv', 'bi', 'feat'] +kern_types = np.array([1, -1, 2, -2, 3, -3, 4, -4]) +shade_factors = np.linspace(-0.5, 0.5, kern_types.size) + +# Main colors: +stage_colors = load_colors('../data/stage_colors.npz') + +# Execution: +for stage in stages: + colors = shade_colors(stage_colors[stage], shade_factors) + colors = {str(k): c for k, c in zip(kern_types, colors)} + np.savez(f'../data/{stage}_colors.npz', **colors) diff --git a/python/save_stage_colors.py b/python/save_stage_colors.py index 5b20005..e75b9df 100644 --- a/python/save_stage_colors.py +++ b/python/save_stage_colors.py @@ -1,195 +1,15 @@ -import numpy as np import matplotlib.pyplot as plt -from tkinter.colorchooser import askcolor +from color_functions import load_colors, color_selector from IPython import embed - -def is_hex_color(color): - if isinstance(color, str) and color.startswith('#') and len(color) == 7: - try: - int(color[1:], 16) - return True - except ValueError: - return False - return False - - -def is_rgb_color(color): - if isinstance(color, (tuple, list)) and len(color) == 3: - return all(isinstance(c, int) and 0 <= c <= 255 for c in color) - return False - - -def hex_to_rgb(color): - color = color.lstrip('#') - return tuple(int(color[i:i+2], 16) for i in (0, 2, 4)) - - -def rgb_to_hex(color): - return '#{:02x}{:02x}{:02x}'.format(*color) - - -def whiten_color(color, factors): - is_hex = is_hex_color(color) - if is_hex: - color = hex_to_rgb(color) - elif not is_rgb_color(color): - raise ValueError('Color format must be hex string or RGB tuple.') - - whitened = tuple(min(255, int(c + (255 - c) * f)) for c, f in zip(color, factors)) - return rgb_to_hex(whitened) if is_hex else whitened - - -def color_selector(n=5, colors=None, save=None, labels=None, hex=True, - back_color='w', edge_color='k', edge_width=2): - - - def pick_color(color='#ffffff', hex=True): - color = askcolor(initialcolor=color)[hex] - return color - - - def update_color(artists, color): - for i, artist in enumerate(artists): - if isinstance(artist, plt.Rectangle): - # Update patch colors: - artist.set_fc(color) - if i in [0, 1]: - artist.set_ec(color) - elif isinstance(artist, plt.Line2D): - # Update marker colors: - artist.set_mfc(color) - if i in [0, 1]: - artist.set_mec(color) - fig.canvas.draw() - return None - - - # Prepare colors: - if colors is None: - colors = ['#ffffff' for _ in range(n)] - start_colors = colors.copy() - n = len(colors) - plt.ion() - - # Prepare graphical interface: - fig, ax = plt.subplots(figsize=(10, 2), layout='constrained') - ax.set_facecolor(back_color) - ax.set_xlim(0, n) - ax.set_ylim(0, 4) - ax.axis('off') - - # Add different color indicators: - artist_groups, group_inds = {}, {} - for i, color in enumerate(colors): - - # Color on background, uniform edge: - p1 = ax.add_patch(plt.Rectangle((i, 0), 1, 1, fc=color, ec=color)) - l1 = ax.plot(i + 0.5, 1.5, marker='o', ms=20, mfc=color, mec=color)[0] - - # Color on background, black edge: - p2 = ax.add_patch(plt.Rectangle((i, 3), 1, 1, fc=color, ec=edge_color, lw=edge_width)) - l2 = ax.plot(i + 0.5, 2.5, marker='o', ms=20, mfc=color, mec=edge_color, mew=edge_width)[0] - - # Update artist mappings: - handles = [p1, l1, l2, p2] - artist_groups[i] = handles - for handle in handles: - handle.set_picker(True) - group_inds[handle] = i - - # Initialize memory variables for avoiding assignments: - swap_groups, swap_colors = [None, None], [None, None] - focus_group, focus_color = [None], [None, None] - - - # Interactivity: - def on_key(event): - - # Abort and retry: - if event.key == 'q': - print('\nRestarting with original settings...') - plt.close(fig) - return color_selector(n, start_colors, save, labels, hex, - back_color, edge_color, edge_width) - - # Exit without saving: - elif event.key == 'escape': - print('\nExiting without saving...') - plt.close(fig) - return None - - # Accept and save colors to file: - elif event.key == 'enter' and save is not None: - if labels is not None: - # Convert to file dictionary and write to npz file: - data = {str(label): col for label, col in zip(labels, colors)} - np.savez(f'{save}.npz', **data) - else: - # Write array (n, 3) or (n,) to npy file: - np.save(f'{save}.npy', np.array(colors)) - print(f'\nSaved {n} colors to file: {save}') - plt.close(fig) - return None - - - def on_pick(event): - - # Pick color for chosen group: - if event.mouseevent.button == 1: - # Update memory variables: - focus_group[0] = group_inds[event.artist] - focus_color[0] = colors[focus_group[0]] - - # Trigger color picker dialogue: - focus_color[1] = pick_color(focus_color[0], hex) - if focus_color[1] is not None: - # Apply, update, and report only if dialogue was not cancelled: - update_color(artist_groups[focus_group[0]], focus_color[1]) - colors[focus_group[0]] = focus_color[1] - print(f'\nUpdated color {focus_group[0] + 1} / {n}: {focus_color[1]}') - - # Select 1st color to swap: - elif event.mouseevent.button == 3: - # Update memory variables and report: - swap_groups[0] = group_inds[event.artist] - swap_colors[0] = colors[swap_groups[0]] - print(f'\nSelected 1st swap color: {swap_groups[0] + 1} / {n}') - - # Select 2nd color to swap and execute: - elif swap_groups[0] is not None and event.mouseevent.button == 2: - # Update memory variables and report: - swap_groups[1] = group_inds[event.artist] - swap_colors[1] = colors[swap_groups[1]] - print(f'Selected 2nd swap color: {swap_groups[1] + 1} / {n}') - print(f'Swapping colors: {swap_groups[0] + 1} <-> {swap_groups[1] + 1}') - - # Swap colors: - for i in [0, 1]: - # Apply to artist group and update global color list: - update_color(artist_groups[swap_groups[i]], swap_colors[1 - i]) - colors[swap_groups[i]] = swap_colors[1 - i] - - # Reset memory variables: - swap_groups[0], swap_groups[1] = None, None - swap_colors[0], swap_colors[1] = None, None - - # Establish interactivity: - plt.connect('key_press_event', on_key) - plt.connect('pick_event', on_pick) - plt.ioff() - return colors - - # Settings: stages = ['filt', 'env', 'log', 'inv', 'conv', 'bi', 'feat'] file_name = None#'../data/stage_colors' colors = None if True: - colors = dict(np.load('../data/stage_colors.npz')) - colors = [colors[stage].item() for stage in colors.keys()] + colors = load_colors('../data/stage_colors.npz') # Execution: -colors = color_selector(len(stages), colors, save=file_name, labels=stages, hex=True) +colors = color_selector(len(stages), colors, file_name, labels=stages, hex=True) plt.show() embed()