diff --git a/python/fig_invariance_log-hp.py b/python/fig_invariance_log-hp.py new file mode 100644 index 0000000..ed42101 --- /dev/null +++ b/python/fig_invariance_log-hp.py @@ -0,0 +1,133 @@ +import numpy as np +import matplotlib.pyplot as plt +from itertools import product + +def prepare_fig(nrows, ncols, width=8, height=None, rheight=2, + left=0.01, right=0.95, bottom=0.01, top=0.95, + wspace=0.4, hspace=0.4): + if height is None: + height = rheight * nrows + fig = plt.figure(figsize=(width, height)) + grid = fig.add_gridspec(nrows=nrows, ncols=ncols, wspace=wspace, hspace=hspace, + left=left, right=right, top=top, bottom=bottom) + axes = np.zeros((nrows, ncols), dtype=object) + for i, j in product(range(nrows), range(ncols)): + axes[i, j] = fig.add_subplot(grid[i, j]) + axes[i, j].set_facecolor('none') + return fig, axes + +def xlimits(ax, time, minval=None, maxval=None, pad=0.05): + limits = [minval, maxval] + if minval is None: + limits[0] = time[0] + if maxval is None: + limits[1] = time[-1] + if pad is not None and minval is None: + limits[0] -= (limits[1] - limits[0]) * pad + if pad is not None and maxval is None: + limits[1] += (limits[1] - limits[0]) * pad + return ax.set_xlim(limits) + +def ylimits(ax, signal, minval=None, maxval=None, pad=0.05): + limits = [minval, maxval] + if minval is None: + limits[0] = signal.min() + if maxval is None: + limits[1] = signal.max() + if pad is not None and minval is None: + limits[0] -= (limits[1] - limits[0]) * pad + if pad is not None and maxval is None: + limits[1] += (limits[1] - limits[0]) * pad + return ax.set_ylim(limits) + +def ylabel(ax, label, x=-0.23, fontsize=20): + ax.set_ylabel(label, fontsize=fontsize, rotation=0, ha='left', va='center') + ax.yaxis.set_label_coords(x, 0.5) + return None + +def super_xlabel(label, fig, high_ax, low_ax, y=0.005, **kwargs): + x = (low_ax.get_position().x0 + high_ax.get_position().x1) / 2 + fig.supxlabel(label, x=x, y=y, **kwargs) + return None + +def super_ylabel(label, fig, high_ax, low_ax, x=0.005, **kwargs): + y = (low_ax.get_position().y0 + high_ax.get_position().y1) / 2 + fig.supylabel(label, x=x, y=y, **kwargs) + return None + +def hide_axis(ax, side='bottom'): + ax.spines[side].set_visible(False) + params = {side: False, 'label' + side: False} + ax.tick_params(axis='x' if side in ['top', 'bottom'] else 'y', + which='both', **params) + return None + +def plot_line(ax, time, signal, ymin=None, ymax=None, xmin=None, xmax=None, + xpad=None, ypad=0.05, yloc=None, **kwargs): + handles = ax.plot(time, signal, **kwargs) + xlimits(ax, time, minval=xmin, maxval=xmax, pad=xpad) + ylimits(ax, signal, minval=ymin, maxval=ymax, pad=ypad) + ax.yaxis.set_major_locator(plt.MultipleLocator(yloc)) + return handles + +def plot_barcode(ax, time, binary, offset=0.5, xmin=None, xmax=None, **kwargs): + if xmin is None: + xmin = time[0] + if xmax is None: + xmax = time[-1] + lower, upper, handles = 0, 1, [] + for i in range(binary.shape[1]): + h = ax.fill_between(time, lower, upper, where=binary[:, i], **kwargs) + handles.append(h) + if i < binary.shape[1] - 1: + lower += offset + 1 + upper += offset + 1 + xlimits(ax, time, minval=xmin, maxval=xmax) + ax.set_ylim(0, upper) + hide_axis(ax, 'bottom') + hide_axis(ax, 'left') + return handles + +def indicate_zoom(fig, high_ax, low_ax, zoom_abs, **kwargs): + y0 = low_ax.get_position().y0 + y1 = high_ax.get_position().y1 + transform = low_ax.transData + fig.transFigure.inverted() + x0 = transform.transform((zoom_abs[0], 0))[0] + x1 = transform.transform((zoom_abs[1], 0))[0] + rect = plt.Rectangle((x0, y0), x1 - x0, y1 - y0, + transform=fig.transFigure, **kwargs) + fig.add_artist(rect) + return None + +def assign_colors(handles, types, colors): + for handle, type_id in zip(handles, types): + handle.set_color(colors[str(int(type_id))]) + return None + +def reorder_traces(handles, signal, zlow=2, zhigh=2.5): + inds = np.argsort(signal.std(axis=0)) + zorders = np.linspace(zlow, zhigh, len(inds))[::-1] + for ind, z in zip(inds, zorders): + handles[ind].set_zorder(z) + return None + +def choose_kernels(kern_specs, features, kern_types, per_type=2, thresh=0.01): + mean_feat = features.mean(axis=0) + feat_diff = np.abs(mean_feat[:, None] - mean_feat[None, :]) + feat_diff[features.max(axis=0) < thresh, :] = np.nan + feat_diff = np.nanmean(feat_diff, axis=0) + + ranking = np.argsort(feat_diff) + kern_inds = [] + for type_id in kern_types: + type_inds = np.nonzero(kern_specs[:, 0] == type_id)[0] + rank_inds = np.nonzero(np.isin(ranking, type_inds))[0][-per_type:] + kern_inds.extend(ranking[rank_inds]) + return np.array(kern_inds) + +def letter_subplots(axes, labels='abcd', x=0.02, y=1, ha='left', va='bottom', + fontsize=16, fontweight='bold', **kwargs): + for ax, label in zip(axes, labels): + ax.text(x, y, label, transform=ax.transAxes, ha=ha, va=va, + fontsize=fontsize, fontweight=fontweight, **kwargs) + return None \ No newline at end of file diff --git a/python/fig_pathway_stages.py b/python/fig_pathway_stages.py index 5df544e..4eb877c 100644 --- a/python/fig_pathway_stages.py +++ b/python/fig_pathway_stages.py @@ -2,153 +2,154 @@ import plotstyle_plt import glob import numpy as np import matplotlib.pyplot as plt -from itertools import product from thunderhopper.modeltools import load_data from color_functions import load_colors +from plot_functions import prepare_fig, hide_axis, letter_subplots,\ + ylabel, super_xlabel, plot_line, plot_barcode,\ + indicate_zoom, assign_colors, reorder_traces from IPython import embed -def prepare_fig(nrows, ncols, width=8, height=None, rheight=2, - left=0.01, right=0.95, bottom=0.01, top=0.95, - wspace=0.4, hspace=0.4): - if height is None: - height = rheight * nrows - fig = plt.figure(figsize=(width, height)) - grid = fig.add_gridspec(nrows=nrows, ncols=ncols, wspace=wspace, hspace=hspace, - left=left, right=right, top=top, bottom=bottom) - axes = np.zeros((nrows, ncols), dtype=object) - for i, j in product(range(nrows), range(ncols)): - axes[i, j] = fig.add_subplot(grid[i, j]) - axes[i, j].set_facecolor('none') - return fig, axes +# def prepare_fig(nrows, ncols, width=8, height=None, rheight=2, +# left=0.01, right=0.95, bottom=0.01, top=0.95, +# wspace=0.4, hspace=0.4): +# if height is None: +# height = rheight * nrows +# fig = plt.figure(figsize=(width, height)) +# grid = fig.add_gridspec(nrows=nrows, ncols=ncols, wspace=wspace, hspace=hspace, +# left=left, right=right, top=top, bottom=bottom) +# axes = np.zeros((nrows, ncols), dtype=object) +# for i, j in product(range(nrows), range(ncols)): +# axes[i, j] = fig.add_subplot(grid[i, j]) +# axes[i, j].set_facecolor('none') +# return fig, axes -def xlimits(ax, time, minval=None, maxval=None, pad=0.05): - limits = [minval, maxval] - if minval is None: - limits[0] = time[0] - if maxval is None: - limits[1] = time[-1] - if pad is not None and minval is None: - limits[0] -= (limits[1] - limits[0]) * pad - if pad is not None and maxval is None: - limits[1] += (limits[1] - limits[0]) * pad - return ax.set_xlim(limits) +# def xlimits(ax, time, minval=None, maxval=None, pad=0.05): +# limits = [minval, maxval] +# if minval is None: +# limits[0] = time[0] +# if maxval is None: +# limits[1] = time[-1] +# if pad is not None and minval is None: +# limits[0] -= (limits[1] - limits[0]) * pad +# if pad is not None and maxval is None: +# limits[1] += (limits[1] - limits[0]) * pad +# return ax.set_xlim(limits) -def ylimits(ax, signal, minval=None, maxval=None, pad=0.05): - limits = [minval, maxval] - if minval is None: - limits[0] = signal.min() - if maxval is None: - limits[1] = signal.max() - if pad is not None and minval is None: - limits[0] -= (limits[1] - limits[0]) * pad - if pad is not None and maxval is None: - limits[1] += (limits[1] - limits[0]) * pad - return ax.set_ylim(limits) +# def ylimits(ax, signal, minval=None, maxval=None, pad=0.05): +# limits = [minval, maxval] +# if minval is None: +# limits[0] = signal.min() +# if maxval is None: +# limits[1] = signal.max() +# if pad is not None and minval is None: +# limits[0] -= (limits[1] - limits[0]) * pad +# if pad is not None and maxval is None: +# limits[1] += (limits[1] - limits[0]) * pad +# return ax.set_ylim(limits) -def ylabel(ax, label, x=-0.23, fontsize=20): - ax.set_ylabel(label, fontsize=fontsize, rotation=0, ha='left', va='center') - ax.yaxis.set_label_coords(x, 0.5) - return None +# def ylabel(ax, label, x=-0.23, fontsize=20): +# ax.set_ylabel(label, fontsize=fontsize, rotation=0, ha='left', va='center') +# ax.yaxis.set_label_coords(x, 0.5) +# return None -def super_xlabel(label, fig, high_ax, low_ax, y=0.005, **kwargs): - x = (low_ax.get_position().x0 + high_ax.get_position().x1) / 2 - fig.supxlabel(label, x=x, y=y, **kwargs) - return None +# def super_xlabel(label, fig, high_ax, low_ax, y=0.005, **kwargs): +# x = (low_ax.get_position().x0 + high_ax.get_position().x1) / 2 +# fig.supxlabel(label, x=x, y=y, **kwargs) +# return None -def super_ylabel(label, fig, high_ax, low_ax, x=0.005, **kwargs): - y = (low_ax.get_position().y0 + high_ax.get_position().y1) / 2 - fig.supylabel(label, x=x, y=y, **kwargs) - return None +# def super_ylabel(label, fig, high_ax, low_ax, x=0.005, **kwargs): +# y = (low_ax.get_position().y0 + high_ax.get_position().y1) / 2 +# fig.supylabel(label, x=x, y=y, **kwargs) +# return None -def hide_axis(ax, side='bottom'): - ax.spines[side].set_visible(False) - params = {side: False, 'label' + side: False} - ax.tick_params(axis='x' if side in ['top', 'bottom'] else 'y', - which='both', **params) - return None +# def hide_axis(ax, side='bottom'): +# ax.spines[side].set_visible(False) +# params = {side: False, 'label' + side: False} +# ax.tick_params(axis='x' if side in ['top', 'bottom'] else 'y', +# which='both', **params) +# return None -def plot_line(ax, time, signal, ymin=None, ymax=None, xmin=None, xmax=None, - xpad=None, ypad=0.05, yloc=None, **kwargs): - handles = ax.plot(time, signal, **kwargs) - xlimits(ax, time, minval=xmin, maxval=xmax, pad=xpad) - ylimits(ax, signal, minval=ymin, maxval=ymax, pad=ypad) - ax.yaxis.set_major_locator(plt.MultipleLocator(yloc)) - return handles +# def plot_line(ax, time, signal, ymin=None, ymax=None, xmin=None, xmax=None, +# xpad=None, ypad=0.05, yloc=None, **kwargs): +# handles = ax.plot(time, signal, **kwargs) +# xlimits(ax, time, minval=xmin, maxval=xmax, pad=xpad) +# ylimits(ax, signal, minval=ymin, maxval=ymax, pad=ypad) +# ax.yaxis.set_major_locator(plt.MultipleLocator(yloc)) +# return handles -def plot_barcode(ax, time, binary, offset=0.5, xmin=None, xmax=None, **kwargs): - if xmin is None: - xmin = time[0] - if xmax is None: - xmax = time[-1] - lower, upper, handles = 0, 1, [] - for i in range(binary.shape[1]): - h = ax.fill_between(time, lower, upper, where=binary[:, i], **kwargs) - handles.append(h) - if i < binary.shape[1] - 1: - lower += offset + 1 - upper += offset + 1 - xlimits(ax, time, minval=xmin, maxval=xmax) - ax.set_ylim(0, upper) - hide_axis(ax, 'bottom') - hide_axis(ax, 'left') - return handles +# def plot_barcode(ax, time, binary, offset=0.5, xmin=None, xmax=None, **kwargs): +# if xmin is None: +# xmin = time[0] +# if xmax is None: +# xmax = time[-1] +# lower, upper, handles = 0, 1, [] +# for i in range(binary.shape[1]): +# h = ax.fill_between(time, lower, upper, where=binary[:, i], **kwargs) +# handles.append(h) +# if i < binary.shape[1] - 1: +# lower += offset + 1 +# upper += offset + 1 +# xlimits(ax, time, minval=xmin, maxval=xmax) +# ax.set_ylim(0, upper) +# hide_axis(ax, 'bottom') +# hide_axis(ax, 'left') +# return handles -def indicate_zoom(fig, high_ax, low_ax, zoom_abs, **kwargs): - y0 = low_ax.get_position().y0 - y1 = high_ax.get_position().y1 - transform = low_ax.transData + fig.transFigure.inverted() - x0 = transform.transform((zoom_abs[0], 0))[0] - x1 = transform.transform((zoom_abs[1], 0))[0] - rect = plt.Rectangle((x0, y0), x1 - x0, y1 - y0, - transform=fig.transFigure, **kwargs) - fig.add_artist(rect) - return None +# def indicate_zoom(fig, high_ax, low_ax, zoom_abs, **kwargs): +# y0 = low_ax.get_position().y0 +# y1 = high_ax.get_position().y1 +# transform = low_ax.transData + fig.transFigure.inverted() +# x0 = transform.transform((zoom_abs[0], 0))[0] +# x1 = transform.transform((zoom_abs[1], 0))[0] +# rect = plt.Rectangle((x0, y0), x1 - x0, y1 - y0, +# transform=fig.transFigure, **kwargs) +# fig.add_artist(rect) +# return None -def assign_colors(handles, types, colors): - for handle, type_id in zip(handles, types): - handle.set_color(colors[str(int(type_id))]) - return None +# def assign_colors(handles, types, colors): +# for handle, type_id in zip(handles, types): +# handle.set_color(colors[str(int(type_id))]) +# return None -def reorder_traces(handles, signal, zlow=2, zhigh=2.5): - inds = np.argsort(signal.std(axis=0)) - zorders = np.linspace(zlow, zhigh, len(inds))[::-1] - for ind, z in zip(inds, zorders): - handles[ind].set_zorder(z) - return None +# def reorder_traces(handles, signal, zlow=2, zhigh=2.5): +# inds = np.argsort(signal.std(axis=0)) +# zorders = np.linspace(zlow, zhigh, len(inds))[::-1] +# for ind, z in zip(inds, zorders): +# handles[ind].set_zorder(z) +# return None -def choose_kernels(kern_specs, features, kern_types, per_type=2, thresh=0.01): - mean_feat = features.mean(axis=0) - feat_diff = np.abs(mean_feat[:, None] - mean_feat[None, :]) - feat_diff[features.max(axis=0) < thresh, :] = np.nan - feat_diff = np.nanmean(feat_diff, axis=0) +# def choose_kernels(kern_specs, features, kern_types, per_type=2, thresh=0.01): +# mean_feat = features.mean(axis=0) +# feat_diff = np.abs(mean_feat[:, None] - mean_feat[None, :]) +# feat_diff[features.max(axis=0) < thresh, :] = np.nan +# feat_diff = np.nanmean(feat_diff, axis=0) - ranking = np.argsort(feat_diff) - kern_inds = [] - for type_id in kern_types: - type_inds = np.nonzero(kern_specs[:, 0] == type_id)[0] - rank_inds = np.nonzero(np.isin(ranking, type_inds))[0][-per_type:] - kern_inds.extend(ranking[rank_inds]) - return np.array(kern_inds) +# ranking = np.argsort(feat_diff) +# kern_inds = [] +# for type_id in kern_types: +# type_inds = np.nonzero(kern_specs[:, 0] == type_id)[0] +# rank_inds = np.nonzero(np.isin(ranking, type_inds))[0][-per_type:] +# kern_inds.extend(ranking[rank_inds]) +# return np.array(kern_inds) -def letter_subplots(axes, labels='abcd', x=0.02, y=1, ha='left', va='bottom', - fontsize=16, fontweight='bold', **kwargs): - for ax, label in zip(axes, labels): - ax.text(x, y, label, transform=ax.transAxes, ha=ha, va=va, - fontsize=fontsize, fontweight=fontweight, **kwargs) - return None +# def letter_subplots(axes, labels='abcd', x=0.02, y=1, ha='left', va='bottom', +# fontsize=16, fontweight='bold', **kwargs): +# for ax, label in zip(axes, labels): +# ax.text(x, y, label, transform=ax.transAxes, ha=ha, va=va, +# fontsize=fontsize, fontweight=fontweight, **kwargs) +# return None # GENERAL SETTINGS: target = 'Omocestus_rufipes' data_paths = glob.glob(f'../data/processed/{target}*.npz') stages = ['filt', 'env', 'log', 'inv', 'conv', 'bi', 'feat'] -save_path = '../figures/' +save_path = None#'../figures/' # PLOT SETTINGS: fig_kwargs = dict( - width=16 / 2.54 * 2, - height=6 / 2.54 * 2, - rheight=2 / 2.54 * 2, + width=32, + height=12, ) grid_kwargs = dict( wspace=0.15, @@ -167,6 +168,12 @@ ylabels = dict( bi=r'$b_i$', feat=r'$f_i$' ) +ylab_kwargs = dict( + x=-0.23, + rotation=0, + ha='left', + va='center', +) colors = load_colors('../data/stage_colors.npz') lw_full = dict( filt=0.25, @@ -242,9 +249,10 @@ for data_path in data_paths: t_full = np.arange(data['filt'].shape[0]) / config['rate'] # Select kernel subset: - kern_inds = [np.nonzero((config['k_specs'] == k).all(1))[0][0] for k in kernels] + kern_specs = config['k_specs'] + kern_inds = [np.nonzero((kern_specs == k).all(1))[0][0] for k in kernels] kern_inds = np.array(kern_inds) - kernel_specs = config['k_specs'][kern_inds] + kern_specs = config['k_specs'][kern_inds, :] # Establish zoom frame: zoom_abs = zoom_rel * t_full[-1] @@ -258,7 +266,7 @@ for data_path in data_paths: # Bandpass-filtered signal: ax_full, ax_zoom = axes[0, :] - ylabel(ax_full, ylabels['filt']) + ylabel(ax_full, ylabels['filt'], **ylab_kwargs) plot_line(ax_full, t_full, data['filt'], c=colors['filt'], lw=lw_full['filt'], yloc=loc_full['filt']) plot_line(ax_zoom, t_zoom, data['filt'][zoom_mask], c=colors['filt'], lw=lw_zoom['filt'], yloc=loc_zoom['filt']) hide_axis(ax_full, 'bottom') @@ -266,7 +274,7 @@ for data_path in data_paths: # Signal envelope: ax_full, ax_zoom = axes[1, :] - ylabel(ax_full, ylabels['env']) + ylabel(ax_full, ylabels['env'], **ylab_kwargs) plot_line(ax_full, t_full, data['env'], ymin=0, c=colors['env'], lw=lw_full['env'], yloc=loc_full['env']) plot_line(ax_zoom, t_zoom, data['env'][zoom_mask], ymin=0, c=colors['env'], lw=lw_zoom['env'], yloc=loc_zoom['env']) hide_axis(ax_full, 'bottom') @@ -274,7 +282,7 @@ for data_path in data_paths: # Logarithmic envelope: ax_full, ax_zoom = axes[2, :] - ylabel(ax_full, ylabels['log']) + ylabel(ax_full, ylabels['log'], **ylab_kwargs) plot_line(ax_full, t_full, data['log'], ymax=0, c=colors['log'], lw=lw_full['log'], yloc=loc_full['log']) plot_line(ax_zoom, t_zoom, data['log'][zoom_mask], ymax=0, c=colors['log'], lw=lw_zoom['log'], yloc=loc_zoom['log']) hide_axis(ax_full, 'bottom') @@ -282,7 +290,7 @@ for data_path in data_paths: # Adapted envelope: ax_full, ax_zoom = axes[3, :] - ylabel(ax_full, ylabels['inv']) + ylabel(ax_full, ylabels['inv'], **ylab_kwargs) plot_line(ax_full, t_full, data['inv'], c=colors['inv'], lw=lw_full['inv'], yloc=loc_full['inv']) plot_line(ax_zoom, t_zoom, data['inv'][zoom_mask], c=colors['inv'], lw=lw_zoom['inv'], yloc=loc_zoom['inv']) @@ -302,34 +310,34 @@ for data_path in data_paths: # Convolutional filter responses: ax_full, ax_zoom = axes[0, :] - ylabel(ax_full, ylabels['conv']) + ylabel(ax_full, ylabels['conv'], **ylab_kwargs) signal = data['conv'][:, kern_inds] handles = plot_line(ax_full, t_full, signal, lw=lw_full['conv'], yloc=loc_full['conv']) - assign_colors(handles, kernel_specs[:, 0], conv_colors) + assign_colors(handles, kern_specs[:, 0], conv_colors) reorder_traces(handles, signal) handles = plot_line(ax_zoom, t_zoom, signal[zoom_mask, :], lw=lw_zoom['conv'], yloc=loc_zoom['conv']) - assign_colors(handles, kernel_specs[:, 0], conv_colors) + assign_colors(handles, kern_specs[:, 0], conv_colors) reorder_traces(handles, signal[zoom_mask, :]) hide_axis(ax_full, 'bottom') hide_axis(ax_zoom, 'bottom') # Binary responses: ax_full, ax_zoom = axes[1, :] - ylabel(ax_full, ylabels['bi']) + ylabel(ax_full, ylabels['bi'], **ylab_kwargs) signal = data['bi'][:, kern_inds] handles = plot_barcode(ax_full, t_full, signal, lw=lw_full['bi']) - assign_colors(handles, kernel_specs[:, 0], bi_colors) + assign_colors(handles, kern_specs[:, 0], bi_colors) handles = plot_barcode(ax_zoom, t_zoom, signal[zoom_mask, :], lw=lw_zoom['bi']) - assign_colors(handles, kernel_specs[:, 0], bi_colors) + assign_colors(handles, kern_specs[:, 0], bi_colors) # Finalized features: ax_full, ax_zoom = axes[2, :] - ylabel(ax_full, ylabels['feat']) + ylabel(ax_full, ylabels['feat'], **ylab_kwargs) signal = data['feat'][:, kern_inds] handles = plot_line(ax_full, t_full, signal, ymin=0, ymax=1, c=colors['feat'], lw=lw_full['feat'], yloc=loc_full['feat']) - assign_colors(handles, kernel_specs[:, 0], feat_colors) + assign_colors(handles, kern_specs[:, 0], feat_colors) handles = plot_line(ax_zoom, t_zoom, signal[zoom_mask, :], ymin=0, ymax=1, c=colors['feat'], lw=lw_zoom['feat'], yloc=loc_zoom['feat']) - assign_colors(handles, kernel_specs[:, 0], feat_colors) + assign_colors(handles, kern_specs[:, 0], feat_colors) # Posthoc adjustments: ax_full.set_xlim(t_full[0], t_full[-1]) @@ -340,5 +348,3 @@ for data_path in data_paths: if save_path is not None: fig.savefig(f'{save_path}fig_feat_stages.pdf') plt.show() - - diff --git a/python/plot_functions.py b/python/plot_functions.py new file mode 100644 index 0000000..c2a24f4 --- /dev/null +++ b/python/plot_functions.py @@ -0,0 +1,128 @@ +import string +import numpy as np +import matplotlib.pyplot as plt +from itertools import product + +def prepare_fig(nrows, ncols, width=8, height=None, rheight=2, unit=1/2.54, + left=0.01, right=0.95, bottom=0.01, top=0.95, + wspace=0.4, hspace=0.4): + if height is None: + height = rheight * nrows + fig = plt.figure(figsize=(width * unit, height * unit)) + grid = fig.add_gridspec(nrows=nrows, ncols=ncols, wspace=wspace, hspace=hspace, + left=left, right=right, top=top, bottom=bottom) + axes = np.zeros((nrows, ncols), dtype=object) + for i, j in product(range(nrows), range(ncols)): + axes[i, j] = fig.add_subplot(grid[i, j]) + axes[i, j].set_facecolor('none') + return fig, axes + +def hide_axis(ax, side='bottom'): + ax.spines[side].set_visible(False) + params = {side: False, 'label' + side: False} + ax.tick_params(axis='x' if side in ['top', 'bottom'] else 'y', + which='both', **params) + return None + +def letter_subplots(axes, labels=None, x=0.02, y=1, ha='left', va='bottom', + fontsize=16, fontweight='bold', **kwargs): + if labels is None: + labels = string.ascii_lowercase + for ax, label in zip(axes, labels): + ax.text(x, y, label, transform=ax.transAxes, ha=ha, va=va, + fontsize=fontsize, fontweight=fontweight, **kwargs) + return None + +def xlimits(ax, time, minval=None, maxval=None, pad=0.05): + limits = [minval, maxval] + if minval is None: + limits[0] = time[0] + if maxval is None: + limits[1] = time[-1] + span = limits[1] - limits[0] + if pad and minval is None: + limits[0] -= span * pad + if pad and maxval is None: + limits[1] += span * pad + return ax.set_xlim(limits) + +def ylimits(ax, signal, minval=None, maxval=None, pad=0.05): + limits = [minval, maxval] + if minval is None: + limits[0] = signal.min() + if maxval is None: + limits[1] = signal.max() + span = limits[1] - limits[0] + if pad and minval is None: + limits[0] -= span * pad + if pad and maxval is None: + limits[1] += span * pad + return ax.set_ylim(limits) + +def xlabel(ax, label, y=-0.1, fontsize=20, **kwargs): + ax.set_xlabel(label, fontsize=fontsize, **kwargs) + ax.xaxis.set_label_coords(0.5, y) + return None + +def ylabel(ax, label, x=-0.2, fontsize=20, **kwargs): + ax.set_ylabel(label, fontsize=fontsize, **kwargs) + ax.yaxis.set_label_coords(x, 0.5) + return None + +def super_xlabel(label, fig, high_ax, low_ax, y=0.005, **kwargs): + x = (low_ax.get_position().x0 + high_ax.get_position().x1) / 2 + fig.supxlabel(label, x=x, y=y, **kwargs) + return None + +def super_ylabel(label, fig, high_ax, low_ax, x=0.005, **kwargs): + y = (low_ax.get_position().y0 + high_ax.get_position().y1) / 2 + fig.supylabel(label, x=x, y=y, **kwargs) + return None + +def plot_line(ax, time, signal, ymin=None, ymax=None, xmin=None, xmax=None, + xpad=None, ypad=0.05, yloc=None, xloc=None, **kwargs): + handles = ax.plot(time, signal, **kwargs) + xlimits(ax, time, minval=xmin, maxval=xmax, pad=xpad) + ylimits(ax, signal, minval=ymin, maxval=ymax, pad=ypad) + if xloc is not None: + ax.xaxis.set_major_locator(plt.MultipleLocator(xloc)) + if yloc is not None: + ax.yaxis.set_major_locator(plt.MultipleLocator(yloc)) + return handles + +def plot_barcode(ax, time, binary, offset=0.5, xmin=None, xmax=None, **kwargs): + lower, upper, handles = 0, 1, [] + for i in range(binary.shape[1]): + h = ax.fill_between(time, lower, upper, where=binary[:, i], **kwargs) + handles.append(h) + if i < binary.shape[1] - 1: + lower += offset + 1 + upper += offset + 1 + xlimits(ax, time, minval=xmin, maxval=xmax, pad=0) + ax.set_ylim(0, upper) + hide_axis(ax, 'bottom') + hide_axis(ax, 'left') + return handles + +def indicate_zoom(fig, high_ax, low_ax, zoom_abs, **kwargs): + y0 = low_ax.get_position().y0 + y1 = high_ax.get_position().y1 + transform = low_ax.transData + fig.transFigure.inverted() + x0 = transform.transform((zoom_abs[0], 0))[0] + x1 = transform.transform((zoom_abs[1], 0))[0] + fig.add_artist(plt.Rectangle((x0, y0), x1 - x0, y1 - y0, + transform=fig.transFigure, **kwargs)) + return None + +def assign_colors(handles, types, colors): + for handle, type_id in zip(handles, types): + handle.set_color(colors[str(int(type_id))]) + return None + +def reorder_traces(handles, signal, zlow=2, zhigh=2.5): + inds = np.argsort(signal.std(axis=0)) + zorders = np.linspace(zlow, zhigh, len(inds))[::-1] + for ind, z in zip(inds, zorders): + handles[ind].set_zorder(z) + return None +